Mirror of https://github.com/pytorch/pytorch.git, synced 2025-11-13 07:45:35 +08:00

Compare commits: 256 commits (ciflow/tru ... ciflow/tru)
| SHA1 |
|---|
| 11681f2457 |
| b126065154 |
| bc882f8284 |
| edd365ed4a |
| 1366a2fa55 |
| 91f0c5a9da |
| 67390692c5 |
| 1debfd44fd |
| cdf0a9c21f |
| 115016f1a2 |
| 971e6ca434 |
| e8d411e7f7 |
| 2e5233d7bd |
| 514dd96376 |
| 9ae62fcc18 |
| ae71b0e163 |
| 5b6ff8148d |
| 1f7e4343e7 |
| b21856f5fc |
| 259ba0ecab |
| 051f1fe8e3 |
| ee387c43fe |
| 3a944661d6 |
| 56034074ca |
| 8def619bbe |
| 61883a5787 |
| d8ada1ee76 |
| fe841a1db4 |
| b65829b84f |
| b0e0ae97ba |
| f44a1ddcb2 |
| 184e2cbc89 |
| 416421c7c4 |
| bd99ae3315 |
| ce8672c24f |
| 402c465030 |
| 573a79fffa |
| 4945180468 |
| 1df723e6f5 |
| f9b81e23e4 |
| ffe6cc39c7 |
| db1f3f6901 |
| 43041f0a43 |
| dc00842b81 |
| f1a129a6d0 |
| fad48ffa62 |
| 3e7a66fae1 |
| 5f0a563dc8 |
| 678915d5f1 |
| daed97afff |
| 53947adb1f |
| c297b02f12 |
| bd24774f50 |
| 525eb9fab9 |
| 7886070fc5 |
| 87d17e9dee |
| 53422e6bc8 |
| c34b743eac |
| db250fa895 |
| 52231a7974 |
| cf71c53eae |
| f9caae42ed |
| 52a6b5a4cc |
| 94f6f79e27 |
| 5676de1157 |
| 2ca0b3f70a |
| b06453c7cf |
| f0fa39a7e4 |
| b5142f74f9 |
| a14452bfce |
| 619f329a4b |
| 7a48db0809 |
| 406f2943d2 |
| c3bc56c8b4 |
| b2be4d24c0 |
| 8d5cceeb6a |
| f6331192b4 |
| f8d408d24a |
| 5a85b6eaf8 |
| e3d6896d08 |
| 9d9e7c7b1c |
| 4c3721fe70 |
| 8ef4099313 |
| de773364be |
| 47da714b8b |
| 69ab1f93e4 |
| 232baa33b3 |
| 6f0182495f |
| 7da82b84e2 |
| cda7604434 |
| 6ca8cc6edf |
| bb37483464 |
| 2751b1d3c3 |
| fe0bb7cf60 |
| cf63b212e3 |
| 17e70ae459 |
| ad7db3617e |
| 5320ca3725 |
| 3e4faca130 |
| 0c2f206ded |
| 6cf21fa331 |
| cdc8460f2c |
| 86130aa2ca |
| 9491830c79 |
| 04a85b4c21 |
| a4437d76f0 |
| 7fea9e290b |
| 077f1daeaa |
| 3ea829a337 |
| 3966b5ad05 |
| f6a79b2a4a |
| 2fcf41dd8e |
| 31ccd8f13e |
| 59307ca1bc |
| c28475db7c |
| 74aec83841 |
| 52e744d68a |
| 3cfbf98ea9 |
| 47db55258b |
| 50af6f3393 |
| e545ba2d34 |
| a058bbdd6f |
| 2c78080ec0 |
| fe6615e397 |
| abf31db2cc |
| f70df30237 |
| 40836fbdef |
| 6f0708eae3 |
| 389e34330f |
| 444f741bd9 |
| 81be36a55a |
| 0ad5913584 |
| 92bd87ad5a |
| d282299076 |
| 88651e6af0 |
| 2111d19fed |
| f8ef995c40 |
| 264fe678a5 |
| 81b43285b9 |
| 1ac4d5a4a6 |
| b4ccfc39b1 |
| e41a65c5f4 |
| 25f8922f20 |
| 4197251fcd |
| f0826ff88f |
| e5f494eb9b |
| fcd1207841 |
| df4eb5104c |
| 236e76560a |
| 9161c3bb30 |
| e8c3e60649 |
| 1cf593ab09 |
| 8b194d35d0 |
| fbf258bb46 |
| 2c9473c84c |
| 0e9b283b5e |
| 6b03cfa431 |
| f427647328 |
| f682b0d74d |
| d1beb0f0f1 |
| 580ef872c5 |
| cc400925ef |
| 9001155ffe |
| 4f7fabc043 |
| 869dd37eca |
| b61cc19dae |
| 99fbad06d7 |
| 98ff76ed73 |
| a2aeb27bee |
| e5c65308a6 |
| 09998faa18 |
| 283b29e2bf |
| 91b8fad8b9 |
| f5edbfb3c1 |
| f434235a8f |
| 094cd55b7d |
| f5bab128af |
| 2b8100381d |
| 6e5dba5e1c |
| f3b29b9d81 |
| ebfcfe912e |
| 6861fe634e |
| 144a863515 |
| eaa298cbe9 |
| 7ace17a639 |
| efce4e8c20 |
| 500fbbeb77 |
| 573f2101b2 |
| ff302143ea |
| f2d143f719 |
| 16e588f41d |
| 651bacb2eb |
| 7d9e5ba5d0 |
| 4c90efdf38 |
| d592a952fc |
| 5be3d2dda5 |
| d2299b7838 |
| 46fa23c043 |
| 60e3695480 |
| ab0c1521d1 |
| 8bb1c9a8d8 |
| af9b4f0734 |
| 535912c5fe |
| da876b507e |
| c3665e33cf |
| 7664c7b8c2 |
| 12db8395b2 |
| 2111a9cf7a |
| b84eac19cb |
| b991793cb1 |
| 853e6cac6c |
| 889a253216 |
| 4c52ac8494 |
| 8fd9a260ee |
| 4a15f67bd1 |
| 39ac72e35d |
| 4b10d9a042 |
| 81079d584d |
| ddb70b719b |
| 8850e0ec91 |
| 13d101b384 |
| b662dd6a47 |
| b1ec7ceb47 |
| 048b4383ea |
| 3f85bbd18d |
| 76bd87e46e |
| be827bbde5 |
| 46ab2e6503 |
| 10b1549b3e |
| c648c13e7f |
| 275b500db9 |
| 08842bbbe6 |
| 751ec21dae |
| b98214f1d9 |
| a00ca04fea |
| 4392f7dc26 |
| 375940d2c8 |
| 110f8ce653 |
| 09a7f8d900 |
| 60ee217e25 |
| 749c9d2faf |
| f746ca4628 |
| e156940989 |
| a7f61aaeea |
| eac61d3e73 |
| a008d5e1c0 |
| b30da6f6c9 |
| f466f68388 |
| 838fcbe179 |
| e7818b6c52 |
| a69b16441e |
| 736ab6ff7c |
| 7063939a93 |
| e8de32034f |
| 23f9555cc0 |
| c4cdf38750 |
```diff
@@ -260,8 +260,8 @@ case "$tag" in
     HALIDE=yes
     TRITON=yes
     ;;
-  pytorch-linux-jammy-cuda13.0-py3.12-pallas)
-    CUDA_VERSION=13.0.0
+  pytorch-linux-jammy-cuda12.8-py3.12-pallas)
+    CUDA_VERSION=12.8.1
     ANACONDA_PYTHON_VERSION=3.12
     GCC_VERSION=11
     PALLAS=yes
```
```diff
@@ -30,7 +30,6 @@ into a tarball, with the following structure:
 More specifically, `build_magma.sh` copies over the relevant files from the `package_files` directory depending on the ROCm version.
 Outputted binaries should be in the `output` folder.
 
-
 ## Pushing
 
 Packages can be uploaded to an S3 bucket using:
```
```diff
@@ -168,14 +168,16 @@ if [[ "$BUILD_ENVIRONMENT" == *xpu* ]]; then
   # shellcheck disable=SC1091
   source /opt/intel/oneapi/compiler/latest/env/vars.sh
   # shellcheck disable=SC1091
   source /opt/intel/oneapi/umf/latest/env/vars.sh
   # shellcheck disable=SC1091
   source /opt/intel/oneapi/ccl/latest/env/vars.sh
+  # shellcheck disable=SC1091
+  source /opt/intel/oneapi/mpi/latest/env/vars.sh
   # shellcheck disable=SC1091
   source /opt/intel/oneapi/pti/latest/env/vars.sh
   # Enable XCCL build
   export USE_XCCL=1
   export USE_MPI=0
   # XPU kineto feature dependencies are not fully ready, disable kineto build as temp WA
   export USE_KINETO=0
   export TORCH_XPU_ARCH_LIST=pvc
 fi
```
```diff
@@ -96,7 +96,6 @@ function pip_build_and_install() {
     python3 -m pip wheel \
       --no-build-isolation \
       --no-deps \
-      --no-use-pep517 \
       -w "${wheel_dir}" \
       "${build_target}"
   fi
```
```diff
@@ -208,6 +208,8 @@ if [[ "$BUILD_ENVIRONMENT" == *xpu* ]]; then
   source /opt/intel/oneapi/ccl/latest/env/vars.sh
+  # shellcheck disable=SC1091
+  source /opt/intel/oneapi/mpi/latest/env/vars.sh
   # shellcheck disable=SC1091
   source /opt/intel/oneapi/pti/latest/env/vars.sh
   # Check XPU status before testing
   timeout 30 xpu-smi discovery || true
 fi
```
**.github/actionlint.yaml** (vendored; 2 changed lines)

```diff
@@ -63,7 +63,7 @@ self-hosted-runner:
   - linux.rocm.gpu.gfx942.1
   - linux.rocm.gpu.gfx942.2
   - linux.rocm.gpu.gfx942.4
-  - rocm-docker
+  - linux.rocm.gfx942.docker-cache
   # Org wise AWS `mac2.metal` runners (2020 Mac mini hardware powered by Apple silicon M1 processors)
   - macos-m1-stable
   - macos-m1-14
```
**.github/ci_commit_pins/xla.txt** (vendored; 2 changed lines)

```diff
@@ -1 +1 @@
-c8b09f5f77d6bf6fb7ed7a9aa83e5d8156b3a5e9
+e4d25697f9dc5eedaf8f0a5bf085c62c5455a53a
```
**.github/labeler.yml** (vendored; 13 changed lines)

```diff
@@ -165,3 +165,16 @@
 - torch/_inductor/kernel/mm.py
 - test/inductor/test_max_autotune.py
 - third_party/fbgemm
+
+"ciflow/mps":
+- aten/src/ATen/mps/**
+- aten/src/ATen/native/mps/**
+- torch/_inductor/codegen/mps.py
+- test/test_mps.py
+- test/inductor/test_mps_basic.py
+
+"ciflow/h100-symm-mem":
+- torch/csrc/distributed/c10d/symm_mem/**
+- torch/distributed/_symmetric_memory/**
+- test/distributed/**/*mem*
+- test/distributed/**/*mem*/**
```
**.github/scripts/lintrunner.sh** (vendored; 3 changed lines)

```diff
@@ -34,6 +34,9 @@ python3 torch/utils/data/datapipes/gen_pyi.py
 # Also check generated pyi files
 find torch -name '*.pyi' -exec git add --force -- "{}" +
 
+# Print current environment
+python3 -m pip freeze
+
 RC=0
 # Run lintrunner on all files
 if ! lintrunner --force-color --tee-json=lint.json ${ADDITIONAL_LINTRUNNER_ARGS} 2> /dev/null; then
```
**.github/workflows/b200-distributed.yml** (vendored; 1 changed line)

```diff
@@ -37,7 +37,6 @@ jobs:
     needs: get-label-type
     with:
       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
-      runner: linux.12xlarge.memory
       build-environment: linux-jammy-cuda12.8-py3.10-gcc11-distributed-b200
       docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
       cuda-arch-list: '10.0'
```
**.github/workflows/b200-symm-mem.yml** (vendored; 1 changed line)

```diff
@@ -37,7 +37,6 @@ jobs:
     needs: get-label-type
     with:
       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
-      runner: linux.12xlarge.memory
       build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100-symm
       docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
       cuda-arch-list: '10.0'
```
**.github/workflows/docker-builds.yml** (vendored; 18 changed lines)

```diff
@@ -67,7 +67,7 @@ jobs:
           pytorch-linux-jammy-py3.10-gcc11,
           pytorch-linux-jammy-py3-gcc11-inductor-benchmarks,
           pytorch-linux-jammy-py3.12-halide,
-          pytorch-linux-jammy-cuda13.0-py3.12-pallas,
+          pytorch-linux-jammy-cuda12.8-py3.12-pallas,
           pytorch-linux-jammy-xpu-n-1-py3,
           pytorch-linux-noble-xpu-n-py3,
           pytorch-linux-noble-xpu-n-py3-inductor-benchmarks,
@@ -119,6 +119,22 @@ jobs:
         with:
           docker-image: ${{ steps.build-docker-image.outputs.docker-image }}
 
+      - name: Generate output
+        if: contains(matrix.docker-image-name, 'rocm')
+        id: generate_output
+        run: |
+          docker_image_name="${{ matrix.docker-image-name }}"
+          docker_image_tag="${{ steps.build-docker-image.outputs.docker-image }}"
+          echo "${docker_image_name}=${docker_image_tag}" >> docker-builds-output-${docker_image_name}.txt
+
+      - name: Upload artifacts
+        uses: actions/upload-artifact@v4.4.0
+        if: contains(matrix.docker-image-name, 'rocm')
+        with:
+          name: docker-builds-artifacts-${{ matrix.docker-image-name }}
+          retention-days: 14
+          path: ./docker-builds-output-${{ matrix.docker-image-name }}.txt
+
       - uses: nick-fields/retry@7152eba30c6575329ac0576536151aca5a72780e # v3.0.0
         name: Push to https://ghcr.io/
         id: push-to-ghcr-io
```
**.github/workflows/docker-cache-mi300.yml** (vendored; deleted, 55 lines)

```diff
@@ -1,55 +0,0 @@
-name: docker-cache-mi300
-
-on:
-  # run every 6 hours
-  schedule:
-    - cron: 0 0,6,12,18 * * *
-  workflow_dispatch:
-
-concurrency:
-  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name }}
-  cancel-in-progress: true
-
-permissions:
-  id-token: write
-  contents: read
-
-jobs:
-  docker-cache:
-    if: github.repository_owner == 'pytorch'
-    runs-on: rocm-docker
-    steps:
-      - name: Checkout PyTorch
-        uses: pytorch/pytorch/.github/actions/checkout-pytorch@main
-        with:
-          no-sudo: true
-
-      - name: configure aws credentials
-        id: aws_creds
-        uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0
-        with:
-          role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
-          aws-region: us-east-1
-          role-duration-seconds: 18000
-
-      - name: Login to Amazon ECR
-        id: login-ecr
-        continue-on-error: false
-        uses: aws-actions/amazon-ecr-login@062b18b96a7aff071d4dc91bc00c4c1a7945b076 # v2.0.1
-
-      - name: Calculate docker image
-        id: calculate-docker-image
-        uses: pytorch/test-infra/.github/actions/calculate-docker-image@main
-        with:
-          docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
-          push: false
-
-      - name: Pull docker image
-        uses: pytorch/test-infra/.github/actions/pull-docker-image@main
-        with:
-          docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }}
-
-      - name: Tar and upload to S3 bucket
-        run: |
-          sudo docker save -o ~/docker-data/pytorch/pytorch_docker_image.tar ${{ steps.calculate-docker-image.outputs.docker-image }}
-          sudo rclone copy -P --s3-upload-concurrency 64 --s3-chunk-size 200M --s3-upload-cutoff 300M ~/docker-data/pytorch/pytorch_docker_image.tar oci:pytorchbucket0002/pytorch_docker_image --progress
```
**.github/workflows/docker-cache-rocm.yml** (vendored; new file, 105 lines)

```diff
@@ -0,0 +1,105 @@
+name: docker-cache-rocm
+
+on:
+  workflow_run:
+    workflows: [docker-builds]
+    branches: [main, release]
+    types:
+      - completed
+  workflow_dispatch:
+
+concurrency:
+  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name }}
+  cancel-in-progress: true
+
+permissions:
+  id-token: write
+  contents: read
+  actions: read
+
+jobs:
+  download-docker-builds-artifacts:
+    if: github.repository_owner == 'pytorch'
+    name: download-docker-builds-artifacts
+    runs-on: ubuntu-latest
+    outputs:
+      pytorch-linux-jammy-rocm-n-py3: ${{ steps.process-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3 }}
+      pytorch-linux-noble-rocm-n-py3: ${{ steps.process-artifacts.outputs.pytorch-linux-noble-rocm-n-py3 }}
+      pytorch-linux-jammy-rocm-n-py3-benchmarks: ${{ steps.process-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3-benchmarks }}
+    steps:
+      - name: Download artifacts
+        uses: actions/download-artifact@v4.1.7
+        with:
+          run-id: ${{ github.event.workflow_run.id }}
+          path: ./docker-builds-artifacts
+          merge-multiple: true
+          github-token: ${{ secrets.GITHUB_TOKEN }}
+
+      - name: Process artifacts
+        id: process-artifacts
+        run: |
+          ls -R ./docker-builds-artifacts
+          cat ./docker-builds-artifacts/*txt >> "${GITHUB_OUTPUT}"
+          cat "${GITHUB_OUTPUT}"
+
+  docker-cache:
+    if: github.repository_owner == 'pytorch'
+    needs: download-docker-builds-artifacts
+    strategy:
+      fail-fast: false
+      matrix:
+        runner: [linux.rocm.gfx942.docker-cache]
+        docker-image: [
+          "${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3 }}",
+          "${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-noble-rocm-n-py3 }}",
+          "${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3-benchmarks }}"
+        ]
+    runs-on: "${{ matrix.runner }}"
+    steps:
+      - name: debug
+        run: |
+          JSON_STRINGIFIED="${{ toJSON(needs.download-docker-builds-artifacts.outputs) }}"
+          echo "Outputs of download-docker-builds-artifacts job: ${JSON_STRINGIFIED}"
+
+      - name: configure aws credentials
+        id: aws_creds
+        uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0
+        with:
+          role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
+          aws-region: us-east-1
+          role-duration-seconds: 18000
+
+      - name: Login to Amazon ECR
+        id: login-ecr
+        continue-on-error: false
+        uses: aws-actions/amazon-ecr-login@062b18b96a7aff071d4dc91bc00c4c1a7945b076 # v2.0.1
+
+      - name: Generate ghrc.io tag
+        id: ghcr-io-tag
+        run: |
+          ecr_image="${{ matrix.docker-image }}"
+          ghcr_image="ghcr.io/pytorch/ci-image:${ecr_image##*:}"
+          echo "ghcr_image=${ghcr_image}" >> "$GITHUB_OUTPUT"
+
+      - name: Pull docker image
+        uses: pytorch/test-infra/.github/actions/pull-docker-image@main
+        with:
+          docker-image: ${{ steps.ghcr-io-tag.outputs.ghcr_image }}
+
+      - name: Save as tarball
+        run: |
+          docker_image_tag=${{ matrix.docker-image }}
+          docker_image_tag="${docker_image_tag#*:}" # Remove everything before and including first ":"
+          docker_image_tag="${docker_image_tag%-*}" # Remove everything after and including last "-"
+          ref_name=${{ github.event.workflow_run.head_branch }}
+          if [[ $ref_name =~ "release/" ]]; then
+            ref_suffix="release"
+          elif [[ $ref_name == "main" ]]; then
+            ref_suffix="main"
+          else
+            echo "Unexpected branch in ref_name: ${ref_name}" && exit 1
+          fi
+          docker tag ${{ steps.ghcr-io-tag.outputs.ghcr_image }} ${{ matrix.docker-image }}
+          # mv is atomic operation, so we use intermediate tar.tmp file to prevent read-write contention
+          docker save -o ~/pytorch-data/docker/${docker_image_tag}.tar.tmp ${{ matrix.docker-image }}
+          mv ~/pytorch-data/docker/${docker_image_tag}.tar.tmp ~/pytorch-data/docker/${docker_image_tag}_${ref_suffix}.tar
```
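The "Save as tarball" step above derives a cache file stem with two shell parameter expansions. As a hedged illustration only (this helper is hypothetical, not part of the workflow), the same stem extraction looks like this in C++:

```cpp
// Illustrative only: mirrors "${docker_image_tag#*:}" (drop through the first
// ':') followed by "${docker_image_tag%-*}" (drop from the last '-'), the two
// expansions the workflow uses to turn an image reference into a file stem.
#include <cassert>
#include <string>

std::string image_tag_stem(std::string tag) {
  if (auto colon = tag.find(':'); colon != std::string::npos) {
    tag.erase(0, colon + 1);  // "${tag#*:}"
  }
  if (auto dash = tag.rfind('-'); dash != std::string::npos) {
    tag.erase(dash);          // "${tag%-*}"
  }
  return tag;
}

int main() {
  // hypothetical tag of the shape the workflow handles
  assert(image_tag_stem("ci-image:pytorch-linux-jammy-rocm-n-py3-abc123") ==
         "pytorch-linux-jammy-rocm-n-py3");
}
```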
**.github/workflows/h100-distributed.yml** (vendored; 1 changed line)

```diff
@@ -37,7 +37,6 @@ jobs:
     needs: get-label-type
     with:
       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
-      runner: "linux.c7i.12xlarge"
       build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90-dist
       docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
       cuda-arch-list: '9.0'
```
**.github/workflows/inductor-rocm-mi200.yml** (vendored; 2 changed lines)

```diff
@@ -1,4 +1,4 @@
-name: inductor-rocm
+name: inductor-rocm-mi200
 
 on:
   schedule:
```
**.github/workflows/inductor-unittest.yml** (vendored; 4 changed lines)

```diff
@@ -86,8 +86,8 @@ jobs:
     uses: ./.github/workflows/_linux-build.yml
     needs: get-label-type
     with:
-      build-environment: linux-jammy-py3.12-gcc11
-      docker-image-name: ci-image:pytorch-linux-jammy-cuda13.0-py3.12-pallas
+      build-environment: linux-jammy-cuda12.8-py3.12-gcc11
+      docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-py3.12-pallas
       cuda-arch-list: '8.9'
       runner: linux.8xlarge.memory
       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
```
**.github/workflows/nightly.yml** (vendored; 8 changed lines)

```diff
@@ -5,9 +5,11 @@ on:
     - cron: 0 0 * * *
   push:
     tags:
-      # NOTE: Doc build pipelines should only get triggered on release candidate builds
-      # Release candidate tags look like: v1.11.0-rc1
-      - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+
+      # NOTE: Doc build pipelines should only get triggered on:
+      # Major or minor release candidates builds
+      - v[0-9]+.[0-9]+.0+-rc[0-9]+
+      # Final RC for major, minor and patch releases
+      - v[0-9]+.[0-9]+.[0-9]+
       - ciflow/nightly/*
   workflow_dispatch:
```
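The tag filters above are GitHub Actions glob patterns, where `.` is literal and `+` repeats the preceding character class. A rough, hedged translation into regular expressions (the escaping below is mine, not from the workflow) can sanity-check which tags now trigger the doc build:

```cpp
// Sketch: approximate the new tag filters as regexes. The first accepts only
// ".0" release candidates; the second accepts final release tags.
#include <cassert>
#include <regex>
#include <string>

int main() {
  const std::regex rc(R"(v[0-9]+\.[0-9]+\.0+-rc[0-9]+)");
  const std::regex final_tag(R"(v[0-9]+\.[0-9]+\.[0-9]+)");
  assert(std::regex_match(std::string("v2.6.0-rc1"), rc));    // major/minor RC: builds
  assert(!std::regex_match(std::string("v2.6.1-rc1"), rc));   // patch RC: no doc build
  assert(std::regex_match(std::string("v2.6.1"), final_tag)); // final release: builds
}
```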
**.github/workflows/rocm-mi200.yml** (vendored; 2 changed lines)

```diff
@@ -1,4 +1,4 @@
-name: rocm
+name: rocm-mi200
 
 on:
   push:
```
**.github/workflows/test-b200.yml** (vendored; 3 changed lines)

```diff
@@ -52,7 +52,6 @@ jobs:
     needs: get-label-type
     with:
       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
-      runner: linux.12xlarge.memory
       build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100
       docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
       cuda-arch-list: '10.0'
@@ -73,4 +72,4 @@ jobs:
       docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm100-build.outputs.docker-image }}
       test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm100-build.outputs.test-matrix }}
-    secrets: inherit
+      aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
+    secrets: inherit
```
**.github/workflows/test-h100.yml** (vendored; 1 changed line)

```diff
@@ -41,7 +41,6 @@ jobs:
     needs: get-label-type
     with:
       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
-      runner: linux.12xlarge.memory
       build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90
       docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
       cuda-arch-list: '9.0'
```
```diff
@@ -186,6 +186,8 @@ include_patterns = [
     'aten/src/ATen/native/nested/cuda/*.h',
     'aten/src/ATen/native/nested/*.cpp',
     'aten/src/ATen/native/nested/*.h',
+    'aten/src/ATen/xpu/**/*.h',
+    'aten/src/ATen/xpu/**/*.cpp',
     'c10/**/*.cpp',
     'c10/**/*.h',
     'torch/*.h',
```
```diff
@@ -736,6 +736,44 @@ if(NOT DEFINED USE_BLAS)
   set(USE_BLAS ON)
 endif()
 
+# Prioritized Text Linker Optimization
+if(USE_PRIORITIZED_TEXT_FOR_LD)
+
+  set(LINKER_SCRIPT_FILE_IN "${CMAKE_SOURCE_DIR}/cmake/prioritized_text.txt")
+  set(LINKER_SCRIPT_FILE_OUT "${CMAKE_SOURCE_DIR}/cmake/linker_script.ld")
+
+  execute_process(
+    COMMAND ${Python_EXECUTABLE}
+            ${CMAKE_SOURCE_DIR}/tools/setup_helpers/generate_linker_script.py
+            --filein "${LINKER_SCRIPT_FILE_IN}"
+            --fout "${LINKER_SCRIPT_FILE_OUT}"
+    RESULT_VARIABLE _gen_result
+    OUTPUT_VARIABLE _gen_output
+    ERROR_VARIABLE _gen_error
+  )
+
+  if(NOT _gen_result EQUAL 0)
+    message(FATAL_ERROR
+      "Failed to generate linker script:\n${_gen_output}\n${_gen_error}")
+  endif()
+
+  append_cxx_flag_if_supported("-ffunction-sections" CMAKE_CXX_FLAGS)
+  append_cxx_flag_if_supported("-fdata-sections" CMAKE_CXX_FLAGS)
+  append_c_flag_if_supported("-ffunction-sections" CMAKE_C_FLAGS)
+  append_c_flag_if_supported("-fdata-sections" CMAKE_C_FLAGS)
+
+  set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -T${LINKER_SCRIPT_FILE_OUT}")
+  set(CMAKE_MODULE_LINKER_FLAGS "${CMAKE_MODULE_LINKER_FLAGS} -T${LINKER_SCRIPT_FILE_OUT}")
+
+else()
+  if(LINUX AND CPU_AARCH64)
+    message(WARNING [[
+    It is strongly recommend to enable linker script optimization for all AArch64 Linux builds.
+    To do so please export USE_PRIORITIZED_TEXT_FOR_LD=1
+    ]])
+  endif()
+endif()
+
 # Build libtorch mobile library, which contains ATen/TH ops and native support
 # for TorchScript model, but doesn't contain not-yet-unified caffe2 ops;
 if(INTERN_BUILD_MOBILE)
@@ -1402,9 +1440,6 @@ if(BUILD_JNI)
   add_subdirectory(android/pytorch_android)
 endif()
 
-include(cmake/Summary.cmake)
-caffe2_print_configuration_summary()
-
 # Parse custom debug info
 if(DEFINED USE_CUSTOM_DEBINFO)
   string(REPLACE ";" " " SOURCE_FILES "${USE_CUSTOM_DEBINFO}")
@@ -1444,56 +1479,5 @@ if(BUILD_BUNDLE_PTXAS AND USE_CUDA)
           DESTINATION "${CMAKE_INSTALL_BINDIR}")
 endif()
 
-if(USE_PRIORITIZED_TEXT_FOR_LD)
-  add_compile_options(
-    $<$<COMPILE_LANGUAGE:C,CXX>:-ffunction-sections>
-    $<$<COMPILE_LANGUAGE:C,CXX>:-fdata-sections>
-  )
-  set(LINKER_SCRIPT_FILE_OUT "${CMAKE_SOURCE_DIR}/cmake/linker_script.ld")
-  set(LINKER_SCRIPT_FILE_IN "${CMAKE_SOURCE_DIR}/cmake/prioritized_text.txt")
-
-  add_custom_command(
-    OUTPUT "${LINKER_SCRIPT_FILE_OUT}"
-    COMMAND ${Python_EXECUTABLE} ${CMAKE_SOURCE_DIR}/tools/setup_helpers/generate_linker_script.py --filein "${LINKER_SCRIPT_FILE_IN}" --fout "${LINKER_SCRIPT_FILE_OUT}"
-    DEPENDS ${CMAKE_SOURCE_DIR}/tools/setup_helpers/generate_linker_script.py "${LINKER_SCRIPT_FILE_IN}"
-    COMMENT "Generating prioritized text linker files"
-    VERBATIM
-  )
-
-  add_custom_target(generate_linker_script DEPENDS "${LINKER_SCRIPT_FILE_OUT}")
-
-  if(BUILD_PYTHON)
-    set(LINKER_OPT_TARGETS torch_python)
-  endif()
-
-  if(NOT BUILD_LIBTORCHLESS)
-    list(APPEND LINKER_OPT_TARGETS torch_cpu c10)
-    if(USE_CUDA)
-      list(APPEND LINKER_OPT_TARGETS torch_cuda c10_cuda)
-    endif()
-    if(USE_XPU)
-      list(APPEND LINKER_OPT_TARGETS torch_xpu c10_xpu)
-    endif()
-    if(USE_ROCM)
-      list(APPEND LINKER_OPT_TARGETS torch_hip c10_hip)
-    endif()
-  endif()
-
-  foreach(tgt IN LISTS LINKER_OPT_TARGETS)
-    if(TARGET ${tgt})
-      add_dependencies("${tgt}" generate_linker_script)
-      target_link_options_if_supported(${tgt} "-T,${LINKER_SCRIPT_FILE_OUT}")
-      set_property(TARGET ${tgt} APPEND PROPERTY LINK_DEPENDS "${LINKER_SCRIPT_FILE_OUT}")
-    else()
-      message(WARNING "Requested target '${tgt}' for linker script optimization was not found.")
-    endif()
-  endforeach()
-
-else()
-  if(LINUX AND CPU_AARCH64)
-    message(WARNING [[
-    It is strongly recommend to enable linker script optimization for all AArch64 Linux builds.
-    To do so please export USE_PRIORITIZED_TEXT_FOR_LD=1
-    ]])
-  endif()
-endif()
+include(cmake/Summary.cmake)
+caffe2_print_configuration_summary()
```
**LICENSE** (2 changed lines)

```diff
@@ -37,7 +37,7 @@ Copyright (c) 2024 Tri Dao.
 All rights reserved.
 
 All contributions by Arm:
-Copyright (c) 2021, 2023-2024 Arm Limited and/or its affiliates
+Copyright (c) 2021, 2023-2025 Arm Limited and/or its affiliates
 
 All contributions from Caffe:
 Copyright(c) 2013, 2014, 2015, the respective contributors
```
```diff
@@ -18,6 +18,8 @@ Please report security issues using https://github.com/pytorch/pytorch/security/
 
 All reports submitted through the security advisories mechanism would **either be made public or dismissed by the team within 90 days of the submission**. If advisory has been closed on the grounds that it is not a security issue, please do not hesitate to create an [new issue](https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml) as it is still likely a valid issue within the framework.
 
+**Note on crashes and out of bounds access**: PyTorch is a computational framework that performs operations on behalf of the caller. Like many low-level libraries, PyTorch generally does not validate all inputs to every function—the responsibility for providing valid arguments lies with the calling code. While crashes and out of bounds memory access should be reported as bugs, they are generally not considered security vulnerabilities in PyTorch's threat model.
+
 Please refer to the following page for our responsible disclosure policy, reward guidelines, and those things that should not be reported:
 
 https://www.facebook.com/whitehat
```
```diff
@@ -94,6 +94,11 @@ TORCH_API inline void resetPeakStats(c10::DeviceIndex device_index) {
   at::getDeviceAllocator(device_type)->resetPeakStats(device_index);
 }
 
+TORCH_API inline std::pair<size_t, size_t> getMemoryInfo(
+    c10::DeviceIndex device_index) {
+  const auto device_type = getAccelerator(true).value();
+  return at::getDeviceAllocator(device_type)->getMemoryInfo(device_index);
+}
 } // namespace at::accelerator
 
 namespace at {
```
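A hedged usage sketch for the `getMemoryInfo` helper added above. The meaning of the returned pair is not stated in this hunk; reading it as (free bytes, total bytes), by analogy with `cudaMemGetInfo`, is an assumption, as is the call site below:

```cpp
// Assumed usage of the new at::accelerator::getMemoryInfo(); the (free, total)
// interpretation of the pair is an assumption, not stated in the diff.
// Requires the ATen accelerator header (exact include path elided here).
#include <cstddef>
#include <iostream>

void print_accelerator_memory(c10::DeviceIndex device_index) {
  const auto [free_bytes, total_bytes] =
      at::accelerator::getMemoryInfo(device_index);  // added in this hunk
  std::cout << "device " << int(device_index) << ": " << free_bytes
            << " of " << total_bytes << " bytes free\n";
}
```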
```diff
@@ -55,14 +55,6 @@ struct numeric_limits<int8_t> {
   static inline __host__ __device__ int8_t upper_bound() { return INT8_MAX; }
 };
 
-template <>
-struct numeric_limits<uint16_t> {
-  static inline __host__ __device__ uint16_t lowest() { return 0; }
-  static inline __host__ __device__ uint16_t max() { return UINT16_MAX; }
-  static inline __host__ __device__ uint16_t lower_bound() { return 0; }
-  static inline __host__ __device__ uint16_t upper_bound() { return UINT16_MAX; }
-};
-
 template <>
 struct numeric_limits<int16_t> {
   static inline __host__ __device__ int16_t lowest() { return INT16_MIN; }
@@ -71,14 +63,6 @@ struct numeric_limits<int16_t> {
   static inline __host__ __device__ int16_t upper_bound() { return INT16_MAX; }
 };
 
-template <>
-struct numeric_limits<uint32_t> {
-  static inline __host__ __device__ uint32_t lowest() { return 0; }
-  static inline __host__ __device__ uint32_t max() { return UINT32_MAX; }
-  static inline __host__ __device__ uint32_t lower_bound() { return 0; }
-  static inline __host__ __device__ uint32_t upper_bound() { return UINT32_MAX; }
-};
-
 template <>
 struct numeric_limits<int32_t> {
   static inline __host__ __device__ int32_t lowest() { return INT32_MIN; }
@@ -87,21 +71,6 @@ struct numeric_limits<int32_t> {
   static inline __host__ __device__ int32_t upper_bound() { return INT32_MAX; }
 };
 
-template <>
-struct numeric_limits<uint64_t> {
-#ifdef _MSC_VER
-  static inline __host__ __device__ uint64_t lowest() { return 0; }
-  static inline __host__ __device__ uint64_t max() { return _UI64_MAX; }
-  static inline __host__ __device__ uint64_t lower_bound() { return 0; }
-  static inline __host__ __device__ uint64_t upper_bound() { return _UI64_MAX; }
-#else
-  static inline __host__ __device__ uint64_t lowest() { return 0; }
-  static inline __host__ __device__ uint64_t max() { return UINT64_MAX; }
-  static inline __host__ __device__ uint64_t lower_bound() { return 0; }
-  static inline __host__ __device__ uint64_t upper_bound() { return UINT64_MAX; }
-#endif
-};
-
 template <>
 struct numeric_limits<int64_t> {
 #ifdef _MSC_VER
```
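These `lower_bound()`/`upper_bound()` members exist so reductions have an identity element: a min-reduction seeds its accumulator with `upper_bound()` and a max-reduction with `lower_bound()`, exactly how the `MinOps`/`MaxOps` call sites later in this diff use them. A minimal standalone sketch of the same pattern, using `std::numeric_limits` in place of ATen's device-side `at::numeric_limits`:

```cpp
// Standalone illustration: the "upper bound" of a type acts as the identity
// value for a min-reduction, mirroring how the CUDA kernels below seed their
// accumulators. Not ATen code; std::numeric_limits stands in here.
#include <algorithm>
#include <cassert>
#include <limits>
#include <vector>

template <typename T>
T min_all(const std::vector<T>& xs) {
  T acc = std::numeric_limits<T>::max();  // plays the role of upper_bound()
  for (const T& x : xs) acc = std::min(acc, x);
  return acc;                             // equals max() when xs is empty
}

int main() {
  assert(min_all<int>({3, 1, 2}) == 1);
  assert(min_all<int>({}) == std::numeric_limits<int>::max());
}
```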
```diff
@@ -157,6 +157,8 @@ constexpr DispatchKeySet kKeysToPropagateToWrapper({
     DispatchKey::Negative,
     DispatchKey::Conjugate,
     DispatchKey::XLA,
+    DispatchKey::XPU,
+    DispatchKey::HPU,
     DispatchKey::CUDA,
     DispatchKey::CPU,
     DispatchKey::PrivateUse1,
```
```diff
@@ -440,7 +440,7 @@ bool MPSHeapAllocatorImpl::release_cached_buffers() {
   // we need to release the lock temporarily as synchronizing may cause deadlock with completion handlers.
   m_mutex.unlock();
   auto stream = getDefaultMPSStream();
-  dispatch_sync(stream->queue(), ^() {
+  dispatch_sync_with_rethrow(stream->queue(), ^() {
     stream->synchronize(SyncType::COMMIT_AND_WAIT);
   });
   m_mutex.lock();
```
```diff
@@ -110,6 +110,9 @@ class TORCH_API MPSStream {
     return _stream;
   }
 
+  MTLBuffer_t getErrorBuffer();
+  void checkLastError();
+
  private:
   Stream _stream;
   MTLCommandQueue_t _commandQueue = nil;
@@ -121,6 +124,8 @@ class TORCH_API MPSStream {
   dispatch_queue_t _serialQueue = nullptr;
   // CommitAndContinue is enabled by default
   bool _enableCommitAndContinue = true;
+  // Buffer that contains last raised error
+  MTLBuffer_t _errorBuffer = nil;
 
   // use synchronize() to access any of these commit functions outside MPSStream
   void commit();
@@ -155,4 +160,7 @@ class TORCH_API MPSStreamImpl {
   MPSStreamImpl();
 };
 
+#ifdef __OBJC__
+void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)());
+#endif
 } // namespace at::mps
```
```diff
@@ -3,13 +3,13 @@
 #include <ATen/mps/MPSAllocatorInterface.h>
 #include <ATen/mps/MPSProfiler.h>
 #include <ATen/mps/MPSStream.h>
+#include <c10/metal/error.h>
 
 @interface MPSGraphExecutionDescriptor ()
 @property(readwrite, atomic) BOOL enableCommitAndContinue;
 @end
 
 namespace at::mps {
 
 //-----------------------------------------------------------------
 // MPSStream
 //-----------------------------------------------------------------
@@ -30,6 +30,10 @@ MPSStream::MPSStream(Stream stream) : _stream(stream) {
   // Choose level which optimizes for GPU
   _compilationDescriptor.optimizationLevel = MPSGraphOptimizationLevel0;
   _executionDescriptor.compilationDescriptor = _compilationDescriptor;
+
+  _errorBuffer = [MPSDevice::getInstance()->device() newBufferWithLength:sizeof(c10::metal::ErrorMessages)
+                                                                 options:MTLResourceStorageModeShared];
+  std::memset([_errorBuffer contents], 0, 1024);
 }
 
 MPSStream::~MPSStream() {
@@ -38,6 +42,8 @@ MPSStream::~MPSStream() {
   [_executionDescriptor release];
   [_compilationDescriptor release];
   _executionDescriptor = nil;
+  [_errorBuffer release];
+  _errorBuffer = nil;
   _compilationDescriptor = nil;
 
   assert(_commandBuffer == nil);
@@ -104,6 +110,7 @@ void MPSStream::commitAndWait() {
     [_prevCommandBuffer waitUntilCompleted];
     [_prevCommandBuffer release];
     _prevCommandBuffer = nil;
+    checkLastError();
   }
 
   if (_commandBuffer) {
@@ -111,6 +118,7 @@ void MPSStream::commitAndWait() {
     [_commandBuffer waitUntilCompleted];
     [_commandBuffer release];
     _commandBuffer = nil;
+    checkLastError();
   }
 }
@@ -153,7 +161,7 @@ void MPSStream::fill(id<MTLBuffer> buffer, uint8_t value, size_t length, size_t
   if (length == 0) {
     return;
   }
-  dispatch_sync(_serialQueue, ^() {
+  dispatch_sync_with_rethrow(_serialQueue, ^() {
     @autoreleasepool {
       endKernelCoalescing();
       id<MTLBlitCommandEncoder> blitEncoder = [commandBuffer() blitCommandEncoder];
@@ -183,7 +191,7 @@ void MPSStream::copy(id<MTLBuffer> srcBuffer,
                      size_t dstOffset,
                      uint64_t profileId,
                      SyncType syncType) {
-  dispatch_sync(_serialQueue, ^() {
+  dispatch_sync_with_rethrow(_serialQueue, ^() {
     @autoreleasepool {
       endKernelCoalescing();
       id<MTLBlitCommandEncoder> blitEncoder = [commandBuffer() blitCommandEncoder];
@@ -236,7 +244,7 @@ void MPSStream::executeMPSGraph(MPSGraph* mpsGraph, NSDictionary* feeds, NSDicti
   auto& profiler = getMPSProfiler();
   const bool isGraphProfilingEnabled = profiler.isOperationProfilingEnabled();
 
-  dispatch_sync(_serialQueue, ^() {
+  dispatch_sync_with_rethrow(_serialQueue, ^() {
     endKernelCoalescing();
     if (isGraphProfilingEnabled) {
       // this function call is only relevant for interval-based Signposts
@@ -266,6 +274,24 @@ void MPSStream::executeMPSGraph(MPSGraph* mpsGraph, NSDictionary* feeds, NSDicti
   });
 }
 
+id<MTLBuffer> MPSStream::getErrorBuffer() {
+  return _errorBuffer;
+}
+
+void MPSStream::checkLastError() {
+  auto msgs = reinterpret_cast<c10::metal::ErrorMessages*>([_errorBuffer contents]);
+  const auto& msg = msgs->msg[0];
+  if (!msgs) {
+    return;
+  }
+  unsigned int count = 0;
+  std::swap(count, msgs->count);
+  if (!count) {
+    return;
+  }
+  throw c10::AcceleratorError({msg.func, msg.file, msg.line}, 1, msg.message);
+}
+
 //-----------------------------------------------------------------
 // MPSStreamImpl
 //-----------------------------------------------------------------
@@ -289,4 +315,19 @@ MPSStream* getDefaultMPSStream() {
   return MPSStreamImpl::getInstance();
 }
 
+// Helper methods
+void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)()) {
+  __block std::optional<std::exception_ptr> block_exception;
+  dispatch_sync(queue, ^() {
+    try {
+      block();
+    } catch (...) {
+      block_exception = std::current_exception();
+    }
+  });
+  if (block_exception) {
+    std::rethrow_exception(*block_exception);
+  }
+}
+
 } // namespace at::mps
```
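`dispatch_sync` runs its block in GCD frames that a C++ exception must not unwind through, so the helper above captures the exception as a `std::exception_ptr` inside the block and rethrows it on the calling thread. A portable C++ sketch of the same idea, with a plain function standing in for a dispatch queue:

```cpp
// Minimal sketch of the capture-and-rethrow pattern used by
// dispatch_sync_with_rethrow above; run_on_queue() is a stand-in for
// dispatch_sync, which must not let exceptions escape the submitted block.
#include <exception>
#include <functional>
#include <optional>

void run_on_queue(const std::function<void()>& block) {
  block();  // a real queue would execute this on its own serial context
}

void run_with_rethrow(const std::function<void()>& work) {
  std::optional<std::exception_ptr> captured;
  run_on_queue([&] {
    try {
      work();
    } catch (...) {
      captured = std::current_exception();  // hold it across the boundary
    }
  });
  if (captured) {
    std::rethrow_exception(*captured);  // resurface on the calling thread
  }
}
```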
```diff
@@ -142,6 +142,7 @@ Tensor _pack_padded_sequence_backward_symint(const Tensor& grad, c10::SymIntArra
 std::tuple<Tensor, Tensor> _pad_packed_sequence(const Tensor& data, const Tensor& _batch_sizes, bool batch_first, const Scalar& padding_value, int64_t total_length) {
   auto batch_sizes_t = _batch_sizes.contiguous();
   checkLongTensor(batch_sizes_t);
+  TORCH_CHECK(batch_sizes_t.numel() > 0, "batch_sizes can not be empty");
 
   int64_t * batch_sizes = batch_sizes_t.data_ptr<int64_t>();
   int64_t max_batch_size = batch_sizes[0];
```
```diff
@@ -23,6 +23,7 @@
 #include <ATen/ops/_aminmax_native.h>
 #include <ATen/ops/_assert_async_native.h>
 #include <ATen/ops/_assert_scalar_native.h>
+#include <ATen/ops/_async_error_native.h>
 #include <ATen/ops/_functional_assert_async_native.h>
 #include <ATen/ops/_functional_assert_scalar_native.h>
 #include <ATen/ops/_make_per_tensor_quantized_tensor.h>
@@ -479,6 +480,14 @@ Tensor isfinite(const Tensor& self) {
   });
 }
 
+void _async_error(std::string_view msg) {
+  TORCH_CHECK(0, msg);
+}
+
+void _async_error_meta(std::string_view msg) {
+  // Do NOT error, it's an async error!
+}
+
 void _assert_async_cpu(const Tensor& self) {
   TORCH_CHECK(
       native::is_nonzero(self),
```
```diff
@@ -5,7 +5,6 @@
 #include <ATen/native/ReduceOpsUtils.h>
 
 #include <ATen/Dispatch.h>
-#include <ATen/Dispatch_v2.h>
 #include <ATen/Parallel.h>
 #include <ATen/TensorIterator.h>
 #include <ATen/OpMathType.h>
@@ -79,12 +78,12 @@ void min_all_kernel_impl(Tensor& result, const Tensor& input) {
     reduce_all_impl<int64_t>(result, input, upper_bound<int64_t>(),
       [=](int64_t a, int64_t b) -> int64_t { return min_impl(a, b); });
   } else {
-    AT_DISPATCH_V2(input.scalar_type(), "min_all", AT_WRAP([&] {
+    AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "min_all", [&] {
       using Vec = Vectorized<opmath_type<scalar_t>>;
       reduce_all_impl_vec<scalar_t>(result, input, upper_bound<scalar_t>(),
         [=] (scalar_t a , scalar_t b) -> scalar_t { return min_impl(a, b); },
         [=](Vec a, Vec b) -> Vec { return minimum(a, b); });
-    }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kHalf, kBFloat16);
+    });
   }
 }
@@ -104,12 +103,12 @@ void max_all_kernel_impl(Tensor& result, const Tensor& input) {
     reduce_all_impl<int64_t>(result, input, lower_bound<int64_t>(),
       [=](int64_t a, int64_t b) -> int64_t { return max_impl(a, b); });
   } else {
-    AT_DISPATCH_V2(input.scalar_type(), "max_all", AT_WRAP([&] {
+    AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "max_all", [&] {
       using Vec = Vectorized<opmath_type<scalar_t>>;
       reduce_all_impl_vec<scalar_t>(result, input, lower_bound<scalar_t>(),
         [=] (scalar_t a , scalar_t b) -> scalar_t { return max_impl(a, b); },
         [=](Vec a, Vec b) -> Vec { return maximum(a, b); });
-    }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kHalf, kBFloat16);
+    });
   }
 }
@@ -200,7 +199,7 @@ void aminmax_allreduce_kernel(
       }
     );
   } else {
-    AT_DISPATCH_V2(input.scalar_type(), "aminmax_cpu", AT_WRAP([&] {
+    AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, input.scalar_type(), "aminmax_cpu", [&] {
       using Vec = Vectorized<opmath_type<scalar_t>>;
       using scalar_t_pair = std::pair<scalar_t, scalar_t>;
       reduce_all_impl_vec_two_outputs<scalar_t>(
@@ -215,7 +214,7 @@ void aminmax_allreduce_kernel(
         [=](Vec a, Vec b) -> Vec { return minimum(a, b); },
         [=](Vec a, Vec b) -> Vec { return maximum(a, b); }
       );
-    }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf);
+    });
   }
 }
```
```diff
@@ -3,7 +3,6 @@
 
 #include <ATen/core/Tensor.h>
 #include <ATen/Dispatch.h>
-#include <ATen/Dispatch_v2.h>
 #include <ATen/OpMathType.h>
 #include <ATen/cpu/vec/vec.h>
 #include <ATen/cpu/vec/functional.h>
@@ -348,35 +347,34 @@ struct MinValuesOps: public at::native::MinOps<scalar_t> {
 };
 
 void min_values_kernel_impl(TensorIterator& iter) {
-  // This case is special because of Vectorized<int64_t> does not
-  // handle upper_bound<int64_t>().
-  // See: https://github.com/pytorch/pytorch/issues/43254
-  if (iter.dtype() == kLong || iter.dtype() == kUInt64) {
-    AT_DISPATCH_V2(iter.dtype(), "min_values_cpu", AT_WRAP([&iter] {
-      binary_kernel_reduce(
-        iter,
-        MinValuesOps<scalar_t>{},
-        std::pair<scalar_t, int64_t>(upper_bound<scalar_t>(), -1));
-    }), kLong, kUInt64);
+  if (iter.dtype() == kLong) {
+    // This case is special because of Vectorized<int64_t> does not
+    // handle upper_bound<int64_t>().
+    // See: https://github.com/pytorch/pytorch/issues/43254
+    using scalar_t = int64_t;
+    binary_kernel_reduce(
+      iter,
+      MinValuesOps<scalar_t>{},
+      std::pair<scalar_t, int64_t>(upper_bound<scalar_t>(), -1));
     return;
   }
-  AT_DISPATCH_V2(iter.dtype(), "min_values_cpu", AT_WRAP([&iter] {
+  AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.dtype(), "min_values_cpu", [&iter] {
     binary_kernel_reduce_vec(
       iter,
       [](scalar_t a, scalar_t b) -> scalar_t { return min_impl(a, b); },
       [](Vectorized<scalar_t> a, Vectorized<scalar_t> b) { return minimum(a, b); },
       static_cast<double>(upper_bound<scalar_t>()));
-  }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
+  });
 }
 
 void max_values_kernel_impl(TensorIterator& iter) {
-  AT_DISPATCH_V2(iter.dtype(), "max_values_cpu", AT_WRAP([&iter] {
+  AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.dtype(), "max_values_cpu", [&iter] {
     binary_kernel_reduce_vec(
       iter,
       [](scalar_t a, scalar_t b) -> scalar_t { return max_impl(a, b); },
       [](Vectorized<scalar_t> a, Vectorized<scalar_t> b) { return maximum(a, b); },
       lower_bound<scalar_t>());
-  }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
+  });
 }
 
 void argmax_kernel_impl(TensorIterator &iter) {
```
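The `AT_DISPATCH_*` changes above and below swap one runtime type-dispatch macro family for another; both expand, roughly, to a switch over the scalar type that instantiates the lambda body for each concrete `scalar_t`. A rough standalone illustration of that mechanism (not ATen's actual macro expansion; the enum and function here are toys):

```cpp
// Toy model (C++20) of what an AT_DISPATCH_* macro does: select scalar_t at
// runtime and run a templated body with it. Not ATen's real implementation.
#include <cstdint>
#include <cstdio>

enum class ScalarType { Float, Long };

template <typename Body>
void dispatch(ScalarType t, Body&& body) {
  switch (t) {
    case ScalarType::Float: body.template operator()<float>(); break;
    case ScalarType::Long:  body.template operator()<int64_t>(); break;
  }
}

int main() {
  dispatch(ScalarType::Float, []<typename scalar_t>() {
    // the kernel body sees a concrete scalar_t here, as in the kernels above
    std::printf("sizeof(scalar_t) = %zu\n", sizeof(scalar_t));
  });
}
```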
```diff
@@ -11,7 +11,6 @@
 #include <vector>
 
 #include <ATen/Dispatch.h>
-#include <ATen/Dispatch_v2.h>
 #include <ATen/Parallel.h>
 #include <ATen/NumericUtils.h>
 #include <ATen/TensorIterator.h>
@@ -107,7 +106,7 @@ void min_kernel_impl(
     bool keepdim) {
   int64_t self_dim_size = ensure_nonempty_size(self, dim);
 
-  AT_DISPATCH_V2(self.scalar_type(), "min_cpu", AT_WRAP([&] {
+  AT_DISPATCH_ALL_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Bool, self.scalar_type(), "min_cpu", [&] {
     compare_base_kernel<scalar_t>(result, indice, self, dim, keepdim, [&] (
       scalar_t* result_data, int64_t* indice_data,
       const scalar_t* self_data, auto self_dim_stride) {
@@ -129,7 +128,7 @@ void min_kernel_impl(
         *indice_data = index;
       }
     );
-  }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), ScalarType::Half, ScalarType::BFloat16, ScalarType::Bool);
+  });
 }
 
 void max_kernel_impl(
@@ -140,7 +139,7 @@ void max_kernel_impl(
     bool keepdim) {
   int64_t self_dim_size = ensure_nonempty_size(self, dim);
 
-  AT_DISPATCH_V2(self.scalar_type(), "max_cpu", AT_WRAP([&] {
+  AT_DISPATCH_ALL_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Bool, self.scalar_type(), "max_cpu", [&] {
     compare_base_kernel<scalar_t>(result, indice, self, dim, keepdim, [&] (
       scalar_t* result_data, int64_t* indice_data,
       const scalar_t* self_data, auto self_dim_stride) {
@@ -162,7 +161,7 @@ void max_kernel_impl(
         *indice_data = index;
       }
     );
-  }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), ScalarType::Half, ScalarType::BFloat16, ScalarType::Bool);
+  });
 }
 
 void aminmax_kernel(
@@ -187,7 +186,7 @@ void aminmax_kernel(
     return;
   }
 
-  AT_DISPATCH_V2(self.scalar_type(), "aminmax_cpu", AT_WRAP([&] {
+  AT_DISPATCH_ALL_TYPES_AND3(ScalarType::Bool, ScalarType::BFloat16, ScalarType::Half, self.scalar_type(), "aminmax_cpu", [&] {
     compare_base_kernel<scalar_t, scalar_t>(min_result, max_result, self, wrap_dim, keepdim, [&] (
       scalar_t* min_result_data, scalar_t* max_result_data,
       const scalar_t* self_data, auto self_dim_stride) {
@@ -210,7 +209,7 @@ void aminmax_kernel(
         *max_result_data = max_number;
       }
     );
-  }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), ScalarType::Bool, ScalarType::BFloat16, ScalarType::Half);
+  });
 }
 
 void where_kernel_impl(TensorIterator &iter) {
```
```diff
@@ -669,9 +669,12 @@ std::optional<c10::ScalarType> out_dtype) {
   // _scaled_mm_allowed_device is used here within _grouped_mm_cuda which seems incorrect since scale is not used.
   // the _grouped_mm_fallback should be safe for any ROCm GPU since it's just calling typical mm/bmm
   bool use_fast_path = false;
+  // On non CK system(w/ ROCm), make sure use_fast_path is false
+#if defined(USE_ROCM_CK_GEMM)
   if (at::detail::getCUDAHooks().isGPUArch({"gfx942", "gfx950"})) {
     use_fast_path = true;
   }
+#endif //USE_ROCM_CK_GEMM
 #endif
   const auto out_dtype_ = _resolve_grouped_mm_out_dtype(mat_a, mat_b, out_dtype);
   Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
@@ -680,7 +683,11 @@ std::optional<c10::ScalarType> out_dtype) {
 #ifndef USE_ROCM
     at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out);
 #else
+#if defined(USE_ROCM_CK_GEMM)
     at::hip::detail::group_gemm_ck(mat_a, mat_b, offs, bias, out);
+#else
+    TORCH_WARN("ROCm: Group Gemm through CK not selected.");
+#endif //USE_ROCM_CK_GEMM
 #endif
   } else {
     _grouped_mm_fallback(mat_a, mat_b, offs, bias, out_dtype, out);
```
```diff
@@ -1,6 +1,5 @@
 #define TORCH_ASSERT_NO_OPERATORS
 #include <ATen/Dispatch.h>
-#include <ATen/Dispatch_v2.h>
 #include <ATen/NumericUtils.h>
 #include <ATen/native/DispatchStub.h>
 #include <ATen/native/ReduceAllOps.h>
@@ -29,22 +28,22 @@ void _min_max_values_kernel_cuda_impl(TensorIterator& iter) {
 }
 
 void aminmax_allreduce_launch_kernel(TensorIterator& iter) {
-  AT_DISPATCH_V2(
-      iter.input_dtype(), "aminmax_all_cuda", AT_WRAP([&] {
+  AT_DISPATCH_ALL_TYPES_AND3(
+      kBFloat16, kHalf, kBool, iter.input_dtype(), "aminmax_all_cuda", [&] {
         _min_max_values_kernel_cuda_impl<scalar_t>(iter);
-      }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
+      });
 }
 
 void aminmax_launch_kernel(TensorIterator& iter) {
-  AT_DISPATCH_V2(
-      iter.input_dtype(), "aminmax_cuda", AT_WRAP([&]() {
+  AT_DISPATCH_ALL_TYPES_AND3(
+      kBFloat16, kHalf, kBool, iter.input_dtype(), "aminmax_cuda", [&]() {
        gpu_reduce_kernel<scalar_t, scalar_t>(
            iter,
            MinMaxOps<scalar_t, scalar_t, int32_t>{},
            thrust::pair<scalar_t, scalar_t>(
                at::numeric_limits<scalar_t>::upper_bound(),
                at::numeric_limits<scalar_t>::lower_bound()));
-      }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
+      });
 }
 
 } // namespace at::native
```
```diff
@@ -1,6 +1,5 @@
 #define TORCH_ASSERT_NO_OPERATORS
 #include <ATen/Dispatch.h>
-#include <ATen/Dispatch_v2.h>
 #include <ATen/NumericUtils.h>
 #include <ATen/native/DispatchStub.h>
 #include <ATen/native/ReduceAllOps.h>
@@ -34,27 +33,27 @@ void max_values_kernel_cuda_impl(TensorIterator& iter) {
 }
 
 void max_values_kernel_cuda(TensorIterator& iter) {
-  AT_DISPATCH_V2(
-      iter.dtype(), "max_values_cuda", AT_WRAP([&]() {
+  AT_DISPATCH_ALL_TYPES_AND3(
+      kBFloat16, kHalf, kBool, iter.dtype(), "max_values_cuda", [&]() {
         max_values_kernel_cuda_impl<scalar_t>(iter);
-      }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
+      });
 }
 
 void max_launch_kernel(TensorIterator& iter) {
-  AT_DISPATCH_V2(
-      iter.input_dtype(), "max_cuda", AT_WRAP([&]() {
+  AT_DISPATCH_ALL_TYPES_AND3(
+      kBFloat16, kHalf, kBool, iter.input_dtype(), "max_cuda", [&]() {
        gpu_reduce_kernel<scalar_t, scalar_t>(
            iter,
            MaxOps<scalar_t>{},
            thrust::pair<scalar_t, int64_t>(
                at::numeric_limits<scalar_t>::lower_bound(), 0));
-      }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
+      });
 }
 
 void max_all_launch_kernel(TensorIterator &iter) {
-  AT_DISPATCH_V2(iter.input_dtype(), "max_all_cuda", AT_WRAP([&] {
+  AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.input_dtype(), "max_all_cuda", [&] {
     max_values_kernel_cuda_impl<scalar_t>(iter);
-  }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
+  });
 }
 
 REGISTER_DISPATCH(max_values_stub, &max_values_kernel_cuda)
```
```diff
@@ -12,7 +12,6 @@
 #include <ATen/NumericUtils.h>
 
 #include <ATen/Dispatch.h>
-#include <ATen/Dispatch_v2.h>
 #include <ATen/NumericUtils.h>
 #include <ATen/cuda/NumericLimits.cuh>
 
@@ -34,24 +33,24 @@ void min_values_kernel_cuda_impl(TensorIterator& iter) {
 }
 
 void min_values_kernel_cuda(TensorIterator& iter) {
-  AT_DISPATCH_V2(iter.dtype(), "min_values_cuda", AT_WRAP([&]() {
+  AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.dtype(), "min_values_cuda", [&]() {
     min_values_kernel_cuda_impl<scalar_t>(iter);
-  }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
+  });
 }
 
 void min_launch_kernel(TensorIterator &iter) {
-  AT_DISPATCH_V2(iter.input_dtype(), "min_cuda", AT_WRAP([&]() {
+  AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.input_dtype(), "min_cuda", [&]() {
     gpu_reduce_kernel<scalar_t, scalar_t>(
         iter,
         MinOps<scalar_t>{},
         thrust::pair<scalar_t, int64_t>(at::numeric_limits<scalar_t>::upper_bound(), 0));
-  }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
+  });
 }
 
 void min_all_launch_kernel(TensorIterator &iter) {
-  AT_DISPATCH_V2(iter.input_dtype(), "min_all_cuda", AT_WRAP([&] {
+  AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.input_dtype(), "min_all_cuda", [&] {
    min_values_kernel_cuda_impl<scalar_t>(iter);
-  }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
+  });
 }
 
 REGISTER_DISPATCH(min_values_stub, &min_values_kernel_cuda)
```
|
||||
@ -40,8 +40,6 @@ using namespace at::mps;
|
||||
|
||||
namespace at::native::mps {
|
||||
|
||||
void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)());
|
||||
|
||||
struct MPSScalar {
|
||||
id<MTLBuffer> getMTLBuffer() const {
|
||||
return __builtin_bit_cast(id<MTLBuffer>, buffer.get());
|
||||
|
||||
```diff
@@ -53,21 +53,6 @@
 @end
 
 namespace at::native::mps {
 
-void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)()) {
-  __block std::optional<std::exception_ptr> block_exception;
-  dispatch_sync(queue, ^() {
-    try {
-      block();
-    } catch (...) {
-      block_exception = std::current_exception();
-    }
-  });
-  if (block_exception) {
-    std::rethrow_exception(*block_exception);
-  }
-}
-
 /**
  * Computes distance from lowest to highest element offset in given tensor.
  */
```
```diff
@@ -1,4 +1,5 @@
 #include <c10/metal/atomic.h>
+#include <c10/metal/error.h>
 #include <c10/metal/indexing.h>
 #include <metal_stdlib>
 
@@ -31,10 +32,24 @@ OffsetT index_apply_indices(
     constant IndexAB* indices,
     constant int64_t* sizes,
     constant int64_t* strides,
-    uint num_indices) {
+    uint num_indices,
+    thread bool& error,
+    device ErrorMessages* error_buf) {
   OffsetT rc = offs.x;
   for (uint i = 0; i < num_indices; i++) {
     auto idx = indices[i].indexArray[offs.y];
+    if (idx < -sizes[i] || idx >= sizes[i]) {
+      TORCH_REPORT_ERROR(
+          error_buf,
+          "index ",
+          idx,
+          " is out of bounds for dimension ",
+          i,
+          " with size ",
+          sizes[i]);
+      error = true;
+      break;
+    }
     if (idx < 0) {
       idx += sizes[i];
     }
@@ -55,6 +70,7 @@ kernel void index_select(
     constant int64_t* index_sizes,
     constant int64_t* index_strides,
     constant uint4& ndim_nindices_numel,
+    device ErrorMessages* error_buffer,
     uint thread_index [[thread_position_in_grid]]) {
   const auto ndim = ndim_nindices_numel.x;
   const auto num_indices = ndim_nindices_numel.y;
@@ -65,8 +81,19 @@ kernel void index_select(
       indices_strides,
       ndim,
       thread_index);
+  bool error = false;
   auto input_offs = index_apply_indices<OffsetT>(
-      offs.yz, indices, index_sizes, index_strides, num_indices);
+      offs.yz,
+      indices,
+      index_sizes,
+      index_strides,
+      num_indices,
+      error,
+      error_buffer);
+  if (error) {
+    output[offs.x / sizeof(T)] = 0;
+    return;
+  }
   output[offs.x / sizeof(T)] = input[input_offs / sizeof(T)];
 }
 
@@ -82,7 +109,9 @@ inline void index_put_impl(
     constant int64_t* index_sizes,
     constant int64_t* index_strides,
     constant uint4& ndim_nindices_numel,
+    device ErrorMessages* error_buffer,
     uint thread_index) {
+  bool error = false;
   const auto ndim = ndim_nindices_numel.x;
   const auto num_indices = ndim_nindices_numel.y;
   const auto offs = index_get_offsets(
@@ -93,7 +122,16 @@ inline void index_put_impl(
       ndim,
       thread_index);
   auto output_offs = index_apply_indices<OffsetT>(
-      offs.xz, indices, index_sizes, index_strides, num_indices);
+      offs.xz,
+      indices,
+      index_sizes,
+      index_strides,
+      num_indices,
+      error,
+      error_buffer);
+  if (error) {
+    return;
+  }
   output[output_offs / sizeof(T)] = input[offs.y / sizeof(T)];
 }
 
@@ -109,6 +147,7 @@ kernel void index_put(
     constant int64_t* index_sizes,
     constant int64_t* index_strides,
     constant uint4& ndim_nindices_numel,
+    device ErrorMessages* error_buffer,
     uint thread_index [[thread_position_in_grid]]) {
   index_put_impl(
       output,
@@ -121,6 +160,7 @@ kernel void index_put(
       index_sizes,
       index_strides,
       ndim_nindices_numel,
+      error_buffer,
      thread_index);
 }
 
@@ -136,6 +176,7 @@ kernel void index_put_serial(
     constant int64_t* index_sizes,
     constant int64_t* index_strides,
     constant uint4& ndim_nindices_numel,
+    device ErrorMessages* error_buffer,
     uint thread_index [[thread_position_in_grid]]) {
   (void)thread_index; // Suppress unused vairable varning
   for (uint idx = 0; idx < ndim_nindices_numel.z; ++idx) {
@@ -150,6 +191,7 @@ kernel void index_put_serial(
         index_sizes,
         index_strides,
         ndim_nindices_numel,
+        error_buffer,
        idx);
   }
 }
@@ -166,6 +208,7 @@ kernel void index_put_accumulate(
     constant int64_t* index_sizes,
     constant int64_t* index_strides,
     constant uint4& ndim_nindices_numel,
+    device ErrorMessages* error_buffer,
     uint thread_index [[thread_position_in_grid]]) {
   const auto ndim = ndim_nindices_numel.x;
   const auto num_indices = ndim_nindices_numel.y;
@@ -176,8 +219,18 @@ kernel void index_put_accumulate(
       indices_strides,
       ndim,
       thread_index);
+  bool error = false;
   auto output_offs = index_apply_indices<OffsetT>(
-      offs.xz, indices, index_sizes, index_strides, num_indices);
+      offs.xz,
+      indices,
+      index_sizes,
+      index_strides,
+      num_indices,
+      error,
+      error_buffer);
+  if (error) {
+    return;
+  }
   AtomicType<T>::atomic_add(
       reinterpret_cast<device AtomicType_t<T>*>(output),
       output_offs / sizeof(T),
@@ -197,6 +250,7 @@ kernel void index_put_accumulate(
     constant int64_t* index_sizes, \
     constant int64_t* index_strides, \
     constant uint4& ndim_nindices_numel, \
+    device ErrorMessages* error_buffer, \
     uint thread_index [[thread_position_in_grid]])
 
 #define REGISTER_INDEX_OP_ALL_DTYPES(OP_NAME) \
```
@@ -220,7 +220,7 @@ Tensor _embedding_bag_dense_backward_mps(const Tensor& output_grad,
  auto num_threads = (params.mode == EmbeddingBagMode::MAX) ? output_grad.numel() : num_indices * params.feature_size;
  MPSStream* stream = getCurrentMPSStream();

  mps::dispatch_sync_with_rethrow(stream->queue(), ^() {
  dispatch_sync_with_rethrow(stream->queue(), ^() {
    @autoreleasepool {
      id<MTLComputeCommandEncoder> computeEncoder = stream->commandEncoder();
      auto pipeline_state = lib.getPipelineStateForFunc(fmt::format("embedding_bag_backward_{}_{}",

@@ -273,7 +273,7 @@ Tensor _embedding_bag_per_sample_weights_backward_mps(const Tensor& output_grad,
  auto num_threads = num_indices * feature_size;
  MPSStream* stream = getCurrentMPSStream();

  mps::dispatch_sync_with_rethrow(stream->queue(), ^() {
  dispatch_sync_with_rethrow(stream->queue(), ^() {
    @autoreleasepool {
      id<MTLComputeCommandEncoder> computeEncoder = stream->commandEncoder();
      auto pipeline_state = lib.getPipelineStateForFunc(fmt::format("embedding_bag_per_sample_weights_backward_{}_{}",
@@ -179,7 +179,8 @@ static void dispatch_index_kernel(TensorIteratorBase& iter,
        iter.strides(2),
        index_size,
        index_stride,
        ndim_nindiees);
        ndim_nindiees,
        mpsStream->getErrorBuffer());
    mtl_dispatch1DJob(computeEncoder, indexSelectPSO, serial ? 1 : iter.numel());
  });
}

@@ -299,7 +300,7 @@ static Tensor& nonzero_out_native_mps(const Tensor& self, Tensor& out_) {
  MPSStream* stream = getCurrentMPSStream();
  using CachedGraph = MPSUnaryCachedGraph;

  dispatch_sync(stream->queue(), ^() {
  dispatch_sync_with_rethrow(stream->queue(), ^() {
    stream->synchronize(SyncType::COMMIT_AND_WAIT);
  });
  int64_t total_nonzero = at::count_nonzero(self).item<int64_t>();

@@ -384,7 +385,7 @@ Tensor& nonzero_out_mps(const Tensor& self, Tensor& out_) {
  MPSStream* stream = getCurrentMPSStream();
  using CachedGraph = MPSUnaryCachedGraph;

  dispatch_sync(stream->queue(), ^() {
  dispatch_sync_with_rethrow(stream->queue(), ^() {
    stream->synchronize(SyncType::COMMIT_AND_WAIT);
  });
  int64_t total_nonzero = at::count_nonzero(self).item<int64_t>();

@@ -923,7 +923,7 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_mps(const Tensor& input,
  MPSStream* stream = getCurrentMPSStream();
  TORCH_CHECK_NOT_IMPLEMENTED(input.scalar_type() != kLong, "Not implemented for long on MPS");
  @autoreleasepool {
    mps::dispatch_sync_with_rethrow(stream->queue(), ^() {
    dispatch_sync_with_rethrow(stream->queue(), ^() {
      // which kernel variant to use based on the normalized axis N size
      const int N_READS = 4;
      auto metalType = mps::scalarToMetalTypeString(input);
@@ -192,6 +192,11 @@
    CompositeExplicitAutograd: _assert_tensor_metadata
    Meta: _assert_tensor_metadata_meta_symint

- func: _async_error(str msg) -> ()
  dispatch:
    CompositeExplicitAutograd: _async_error
    Meta: _async_error_meta

- func: _print(str s) -> ()
  dispatch:
    CompositeExplicitAutograd: _print

@@ -4292,6 +4297,7 @@
  dispatch:
    SparseCPU: sparse_sparse_matmul_cpu
    SparseCUDA: sparse_sparse_matmul_cuda
    SparseMPS: sparse_sparse_matmul_mps
  autogen: _sparse_sparse_matmul.out

- func: mode(Tensor self, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices)

@@ -9832,7 +9838,7 @@
  structured_delegate: erfinv.out
  variants: method, function
  dispatch:
    SparseCPU, SparseCUDA: erfinv_sparse
    SparseCPU, SparseCUDA, SparseMPS: erfinv_sparse
    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: erfinv_sparse_csr
  tags: pointwise

@@ -9841,7 +9847,7 @@
  structured_delegate: erfinv.out
  variants: method
  dispatch:
    SparseCPU, SparseCUDA: erfinv_sparse_
    SparseCPU, SparseCUDA, SparseMPS: erfinv_sparse_
    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: erfinv_sparse_csr_
  tags: pointwise

@@ -9851,7 +9857,7 @@
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA, MPS: erfinv_out
    SparseCPU, SparseCUDA: erfinv_sparse_out
    SparseCPU, SparseCUDA, SparseMPS: erfinv_sparse_out
    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: erfinv_sparse_csr_out
  tags: pointwise
@@ -47,6 +47,7 @@
#include <c10/macros/Macros.h>
#include <thrust/copy.h>
#include <thrust/device_ptr.h>
#include <thrust/distance.h>
#include <thrust/for_each.h>
#include <thrust/functional.h>
#include <thrust/gather.h>

@@ -10,6 +10,10 @@
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_coalesce_native.h>
#include <ATen/ops/repeat_interleave_native.h>
#include <ATen/ops/cumsum.h>
#include <ATen/ops/_sparse_sparse_matmul_native.h>
#include <ATen/ops/_sparse_coo_tensor_unsafe.h>
#include <ATen/ops/_sparse_coo_tensor_unsafe_native.h>
#include <ATen/ops/cat.h>
#include <ATen/ops/add_native.h>
@@ -888,5 +892,114 @@ static void sparse_mask_intersection_out_mps_kernel(
      /*coalesce_mask=*/false);
}

Tensor sparse_sparse_matmul_mps(const Tensor& mat1_, const Tensor& mat2_) {
  TORCH_CHECK(mat1_.is_sparse() && mat2_.is_sparse(),
              "sparse_sparse_matmul_mps: both inputs must be sparse COO tensors");
  TORCH_CHECK(mat1_.is_mps() && mat2_.is_mps(),
              "sparse_sparse_matmul_mps: both inputs must be on MPS device");
  TORCH_CHECK(mat1_.dim() == 2 && mat2_.dim() == 2,
              "sparse_sparse_matmul_mps: both inputs must be 2D matrices");
  TORCH_CHECK(mat1_.dense_dim() == 0 && mat2_.dense_dim() == 0,
              "sparse_sparse_matmul_mps: only scalar values supported (dense_dim == 0)");
  TORCH_CHECK(mat1_.size(1) == mat2_.size(0),
              "mat1 and mat2 shapes cannot be multiplied (", mat1_.size(0), "x", mat1_.size(1), " and ", mat2_.size(0), "x", mat2_.size(1), ")");
  TORCH_CHECK(mat1_.scalar_type() == mat2_.scalar_type(),
              "sparse_sparse_matmul_mps: mat1 dtype ", mat1_.scalar_type(),
              " does not match mat2 dtype ", mat2_.scalar_type());

  const auto device = mat1_.device();

  auto A = mat1_.coalesce();
  auto B = mat2_.coalesce();

  const auto I = A.size(0);
  const auto K = A.size(1);
  const auto N = B.size(1);

  const auto nnzA = A._nnz();
  const auto nnzB = B._nnz();

  // Early empty result, return an empty, coalesced tensor
  if (I == 0 || N == 0 || K == 0 || nnzA == 0 || nnzB == 0) {
    auto empty_idx = at::empty({2, 0}, at::device(device).dtype(at::kLong));
    auto empty_val = at::empty({0}, at::device(device).dtype(mat1_.scalar_type()));
    auto out = _sparse_coo_tensor_unsafe(empty_idx, empty_val, {I, N}, mat1_.options());
    out._coalesced_(true);
    return out;
  }

  const auto computeDtype = at::result_type(mat1_, mat2_);

  auto A_idx = A._indices().contiguous();
  auto A_val = A._values().to(computeDtype).contiguous();
  auto A_i = A_idx.select(0, 0).contiguous();
  auto A_k = A_idx.select(0, 1).contiguous();

  auto B_idx = B._indices().contiguous();
  auto B_val = B._values().to(computeDtype).contiguous();
  auto B_k = B_idx.select(0, 0).contiguous();
  auto B_j = B_idx.select(0, 1).contiguous();

  // csr-style row pointers for B by k (the shared dimension)
  Tensor row_ptr_B;
  {
    auto batch_ptr = at::tensor({0LL, nnzB}, at::device(device).dtype(at::kLong));
    row_ptr_B = at::empty({K + 1}, at::device(device).dtype(at::kLong));
    build_row_ptr_per_batch_mps(B_k, batch_ptr, /*B=*/1, /*I=*/K, row_ptr_B);
  }

  auto row_ptr_B_lo = row_ptr_B.narrow(0, 0, K);
  auto row_ptr_B_hi = row_ptr_B.narrow(0, 1, K);
  auto deg_B = row_ptr_B_hi.sub(row_ptr_B_lo);

  auto counts = deg_B.index_select(0, A_k);

  const int64_t P = counts.sum().item<int64_t>();
  if (P == 0) {
    auto empty_idx = at::empty({2, 0}, at::device(device).dtype(at::kLong));
    auto empty_val = at::empty({0}, at::device(device).dtype(mat1_.scalar_type()));
    auto out = _sparse_coo_tensor_unsafe(empty_idx, empty_val, {I, N}, mat1_.options());
    out._coalesced_(true);
    return out;
  }

  auto group_ids = repeat_interleave_mps(counts);

  // exclusive cumsum of counts
  auto offsets = cumsum(counts, /*dim=*/0).sub(counts);
  auto offsets_gather = offsets.index_select(0, group_ids);
  auto within = at::arange(P, at::device(device).dtype(at::kLong)).sub(offsets_gather);

  // Map each output element to its source B row and position
  auto k_per_out = A_k.index_select(0, group_ids);
  auto start_in_B = row_ptr_B.index_select(0, k_per_out);
  auto seg_index = start_in_B.add(within);

  // Assemble candidate coo pairs and values
  auto i_out = A_i.index_select(0, group_ids).contiguous();
  auto j_out = B_j.index_select(0, seg_index).contiguous();
  auto vA_out = A_val.index_select(0, group_ids).contiguous();
  auto vB_out = B_val.index_select(0, seg_index).contiguous();
  auto v_out = vA_out.mul(vB_out);

  // build (2, P) indices
  auto out_indices = at::empty({2, P}, at::device(device).dtype(at::kLong)).contiguous();
  out_indices.select(0, 0).copy_(i_out);
  out_indices.select(0, 1).copy_(j_out);

  auto result = _sparse_coo_tensor_unsafe(
      out_indices, v_out, {I, N}, mat1_.options().dtype(computeDtype));

  result = result.coalesce();

  if (result.scalar_type() != mat1_.scalar_type()) {
    auto cast_vals = result._values().to(mat1_.scalar_type());
    auto out = _sparse_coo_tensor_unsafe(result._indices(), cast_vals, {I, N}, mat1_.options());
    out._coalesced_(true);
    return out;
  }
  return result;
}

REGISTER_MPS_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_mps_kernel);
} // namespace at::native
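The new MPS path avoids a hand-written SpGEMM kernel: every nonzero A[i, k] is paired with every nonzero in row k of B, the P candidate products are materialized as a (2, P) COO tensor, and coalesce() sums the duplicate (i, j) pairs. A rough CPU-side PyTorch sketch of the same expansion, assuming 2-D coalesced COO inputs (variable names mirror the function above; this is an illustration, not the shipped implementation):

    import torch

    def sparse_sparse_matmul_sketch(A, B):
        A, B = A.coalesce(), B.coalesce()
        A_i, A_k = A.indices()  # rows/cols of A's nonzeros
        B_k, B_j = B.indices()  # rows/cols of B's nonzeros
        K = A.size(1)
        deg_B = torch.bincount(B_k, minlength=K)  # nnz of B per row k
        row_ptr_B = torch.cat([deg_B.new_zeros(1), deg_B.cumsum(0)])
        counts = deg_B[A_k]  # partial products generated per A nonzero
        P = int(counts.sum())
        group_ids = torch.repeat_interleave(counts)  # source A nonzero of each product
        offsets = counts.cumsum(0) - counts  # exclusive prefix sum
        within = torch.arange(P) - offsets[group_ids]
        seg = row_ptr_B[A_k[group_ids]] + within  # matching B nonzero of each product
        idx = torch.stack([A_i[group_ids], B_j[seg]])
        val = A.values()[group_ids] * B.values()[seg]
        out = torch.sparse_coo_tensor(idx, val, (A.size(0), B.size(1)))
        return out.coalesce()  # coalesce sums duplicate (i, j) pairs

For 2-D sparse COO tensors on CPU, sparse_sparse_matmul_sketch(A, B) should agree with torch.sparse.mm(A, B) up to floating-point summation order.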
@@ -50,7 +50,7 @@ nfnet_l0,pass,7
repvgg_a2,fail_accuracy,7
repvgg_a2,pass,7
@@ -952,7 +952,7 @@ def latency_experiment_summary(suite_name, args, model, timings, **kwargs):
        first_fields.append(kwargs["tag"])
    headers = first_headers + ["speedup", "abs_latency"]
    row = first_fields + [float(speedup), median[1] * 1000]
    msg = f"{speedup:.3f}x"
    msg = f"{median[0] * 1000} ms, {median[1] * 1000} ms, {speedup:.3f}x"
    if args.baseline:
        headers.extend(
            [

@@ -1010,7 +1010,7 @@ def latency_experiment_summary(suite_name, args, model, timings, **kwargs):
    # Hypothetically you can use this from other places, but it's currently
    # inaccessible, and when this assert fails you need to update the
    # event_name here to account for the other cases you are using this
    assert args.quantization is not None
    assert any([args.quantization, args.optimus])
    output_signpost(
        dict(zip(headers, row)),
        args,

@@ -2587,6 +2587,9 @@ class BenchmarkRunner:
                **experiment_kwargs,
            )

        # reset dynamo
        torch._dynamo.reset()

        if self.args.export_aot_inductor:
            optimized_model_iter_fn = optimize_ctx
        else:

@@ -2950,7 +2953,7 @@ class BenchmarkRunner:
            status = self.check_tolerance(name, model, example_inputs, optimize_ctx)
            print(status)
        elif self.args.performance:
            if self.args.backend == "torchao":
            if self.args.backend in ["torchao", "optimus"]:
                status = self.run_performance_test_non_alternate(
                    name, model, example_inputs, optimize_ctx, experiment, tag
                )

@@ -3526,6 +3529,12 @@ def parse_args(args=None):
        action="store_true",
        help="Measure speedup with TorchInductor",
    )
    group.add_argument(
        "--optimus",
        choices=["vertical_opt", "horizontal_opt", "all"],
        default=None,
        help="Measure speedup of Optimus with TorchInductor baseline",
    )
    group.add_argument(
        "--quantization",
        choices=[

@@ -3783,6 +3792,9 @@ def run(runner, args, original_dir=None):
    if args.inductor:
        assert args.backend is None
        args.backend = "inductor"
    if args.optimus:
        assert args.backend is None
        args.backend = "optimus"
    if args.quantization:
        assert args.backend is None
        args.backend = "torchao"

@@ -4067,10 +4079,22 @@ def run(runner, args, original_dir=None):

            runner.model_iter_fn = model_iter_fn_and_mark_step
            optimize_ctx = torchao_optimize_ctx(args.quantization)
        elif args.backend == "optimus":
            from .optimus import get_baseline_ctx, get_optimus_optimize_ctx

            baseline_ctx = get_baseline_ctx(
                nopython=args.nopython, inductor_compile_mode=args.inductor_compile_mode
            )
            runner.model_iter_fn = baseline_ctx(runner.model_iter_fn)
            optimize_ctx = get_optimus_optimize_ctx(
                args.optimus, args.nopython, args.inductor_compile_mode
            )
        else:
            optimize_ctx = torch._dynamo.optimize(args.backend, nopython=args.nopython)
        experiment = (
            speedup_experiment if args.backend != "torchao" else latency_experiment
            speedup_experiment
            if args.backend not in ["torchao", "optimus"]
            else latency_experiment
        )
        if args.accuracy:
            output_filename = f"accuracy_{args.backend}.csv"

@@ -4091,7 +4115,12 @@ def run(runner, args, original_dir=None):
        if args.only in runner.disable_cudagraph_models:
            args.disable_cudagraphs = True

    if args.inductor or args.backend == "inductor" or args.export_aot_inductor:
    if (
        args.inductor
        or args.backend == "inductor"
        or args.export_aot_inductor
        or args.backend == "optimus"
    ):
        inductor_config.triton.cudagraphs = not args.disable_cudagraphs
        inductor_config.triton.persistent_reductions = (
            not args.disable_persistent_reductions
benchmarks/dynamo/optimus.py (new file, 62 lines)
@@ -0,0 +1,62 @@
import functools

import torch


def get_baseline_ctx(nopython, inductor_compile_mode):
    return functools.partial(
        torch.compile,
        backend="inductor",
        fullgraph=nopython,
        mode=inductor_compile_mode,
    )


def get_optimus_optimize_ctx(config, nopython, inductor_compile_mode):
    if config == "vertical_opt":
        optimus_inductor_config = {
            "pre_grad_fusion_options": {
                "normalization_pass": {},
                "merge_splits_pass": {},
                "split_cat_pass": {},
                "unbind_stack_pass": {},
                "unbind_cat_to_view_pass": {},
            }
        }
    elif config == "horizontal_opt":
        optimus_inductor_config = {
            "pre_grad_fusion_options": {
                "normalization_pass": {},
                "batch_linear": {},
                "batch_layernorm": {},
            },
        }
    elif config == "all":
        optimus_inductor_config = {
            "pre_grad_fusion_options": {
                "normalization_pass": {},
                "batch_linear": {},
                "batch_layernorm": {},
                "merge_splits_pass": {},
                "split_cat_pass": {},
                "unbind_stack_pass": {},
                "unbind_cat_to_view_pass": {},
            },
        }
    else:
        raise RuntimeError(f"Unknown optimus config: {config}")

    def _inner(fn):
        if "pre_grad_fusion_options" in optimus_inductor_config:
            torch._inductor.config.pre_grad_fusion_options = optimus_inductor_config[
                "pre_grad_fusion_options"
            ]
        if "post_grad_fusion_options" in optimus_inductor_config:
            torch._inductor.config.post_grad_fusion_options = optimus_inductor_config[
                "post_grad_fusion_options"
            ]
        return torch.compile(
            fn, backend="inductor", fullgraph=nopython, mode=inductor_compile_mode
        )

    return _inner
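In use, get_optimus_optimize_ctx returns a wrapper that installs the pre-grad fusion config and then runs the function through torch.compile. A hedged example; the absolute import path is an assumption (the runner above only imports it relatively as .optimus):

    import torch
    from optimus import get_baseline_ctx, get_optimus_optimize_ctx  # import path is an assumption

    baseline = get_baseline_ctx(nopython=False, inductor_compile_mode=None)
    optimized = get_optimus_optimize_ctx("vertical_opt", False, None)

    fn = lambda x: torch.relu(x) + 1
    baseline_fn = baseline(fn)   # plain inductor backend
    optimus_fn = optimized(fn)   # inductor with the vertical_opt pre-grad fusions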
@@ -2,6 +2,7 @@ import csv
import os
import re
import sys
from pathlib import Path


# This script takes the logs produced by the benchmark scripts (e.g.,

@@ -15,8 +16,7 @@ import sys
# This script is not very well written, feel free to rewrite it as necessary

assert len(sys.argv) == 2

full_log = open(sys.argv[1]).read()
full_log = Path(sys.argv[1]).read_text()

# If the log contains a gist URL, extract it so we can include it in the CSV
gist_url = ""
@@ -484,24 +484,106 @@ PyTorch,sum,sum_R256_V512_dim0_contiguousTrue_cpu,short,False,50.954394,0.000000
PyTorch,sum,sum_R256_V512_dim0_contiguousFalse_cpu,short,False,57.957757,0.000000
PyTorch,sum,sum_R256_V512_dim1_contiguousTrue_cpu,short,False,53.592068,0.000000
PyTorch,sum,sum_R256_V512_dim1_contiguousFalse_cpu,short,False,51.339726,0.000000
PyTorch,FloatToHalfTensorConversionBenchmark,FloatToHalfTensorConversionBenchmark_M8_N16_cpu,short,False,7.040985,0.000000
PyTorch,FloatToHalfTensorConversionBenchmark,FloatToHalfTensorConversionBenchmark_M8_N64_cpu,short,False,7.168604,0.000000
PyTorch,FloatToHalfTensorConversionBenchmark,FloatToHalfTensorConversionBenchmark_M8_N128_cpu,short,False,7.434442,0.000000
PyTorch,FloatToHalfTensorConversionBenchmark,FloatToHalfTensorConversionBenchmark_M16_N16_cpu,short,False,7.078318,0.000000
PyTorch,FloatToHalfTensorConversionBenchmark,FloatToHalfTensorConversionBenchmark_M16_N64_cpu,short,False,7.426670,0.000000
PyTorch,FloatToHalfTensorConversionBenchmark,FloatToHalfTensorConversionBenchmark_M16_N128_cpu,short,False,7.679027,0.000000
PyTorch,FloatToHalfTensorConversionBenchmark,FloatToHalfTensorConversionBenchmark_M32_N16_cpu,short,False,7.281365,0.000000
PyTorch,FloatToHalfTensorConversionBenchmark,FloatToHalfTensorConversionBenchmark_M32_N64_cpu,short,False,7.682783,0.000000
PyTorch,FloatToHalfTensorConversionBenchmark,FloatToHalfTensorConversionBenchmark_M32_N128_cpu,short,False,8.381938,0.000000
PyTorch,HalfToFloatTensorConversionBenchmark,HalfToFloatTensorConversionBenchmark_M8_N16_cpu,short,False,7.039854,0.000000
PyTorch,HalfToFloatTensorConversionBenchmark,HalfToFloatTensorConversionBenchmark_M8_N64_cpu,short,False,7.399855,0.000000
PyTorch,HalfToFloatTensorConversionBenchmark,HalfToFloatTensorConversionBenchmark_M8_N128_cpu,short,False,7.715193,0.000000
PyTorch,HalfToFloatTensorConversionBenchmark,HalfToFloatTensorConversionBenchmark_M16_N16_cpu,short,False,7.255140,0.000000
PyTorch,HalfToFloatTensorConversionBenchmark,HalfToFloatTensorConversionBenchmark_M16_N64_cpu,short,False,7.753522,0.000000
PyTorch,HalfToFloatTensorConversionBenchmark,HalfToFloatTensorConversionBenchmark_M16_N128_cpu,short,False,8.364281,0.000000
PyTorch,HalfToFloatTensorConversionBenchmark,HalfToFloatTensorConversionBenchmark_M32_N16_cpu,short,False,7.476377,0.000000
PyTorch,HalfToFloatTensorConversionBenchmark,HalfToFloatTensorConversionBenchmark_M32_N64_cpu,short,False,8.458564,0.000000
PyTorch,HalfToFloatTensorConversionBenchmark,HalfToFloatTensorConversionBenchmark_M32_N128_cpu,short,False,9.391939,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.bool,short,False,0.927,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.uint8,short,False,6.261,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.int8,short,False,6.351,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.int16,short,False,6.177,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.int32,short,False,6.333,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.int64,short,False,6.588,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.float16,short,False,8.117,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.bfloat16,short,False,9.358,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.float32,short,False,7.844,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.float64,short,False,8.097,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.bool,short,False,6.159,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,0.926,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.int8,short,False,6.192,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.int16,short,False,6.276,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,6.461,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.int64,short,False,6.524,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.float16,short,False,8.136,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.bfloat16,short,False,6.854,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.float32,short,False,6.446,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.float64,short,False,6.829,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.bool,short,False,6.088,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.uint8,short,False,6.059,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.int8,short,False,0.922,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.int16,short,False,6.263,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.int32,short,False,6.330,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.int64,short,False,6.688,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.float16,short,False,8.176,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.bfloat16,short,False,6.959,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.float32,short,False,6.430,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.float64,short,False,6.818,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.bool,short,False,6.350,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.uint8,short,False,6.221,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.int8,short,False,6.193,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.int16,short,False,0.922,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.int32,short,False,6.263,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.int64,short,False,6.525,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.float16,short,False,7.960,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.bfloat16,short,False,6.801,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.float32,short,False,6.594,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.float64,short,False,7.089,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.bool,short,False,6.498,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,6.358,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.int8,short,False,6.390,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.int16,short,False,6.415,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,0.925,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.int64,short,False,6.657,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.float16,short,False,7.954,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.bfloat16,short,False,6.930,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.float32,short,False,6.737,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.float64,short,False,6.948,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.bool,short,False,6.757,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.uint8,short,False,6.402,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.int8,short,False,6.550,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.int16,short,False,6.518,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.int32,short,False,6.766,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.int64,short,False,0.929,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.float16,short,False,8.557,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.bfloat16,short,False,9.045,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.float32,short,False,7.672,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.float64,short,False,7.276,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.bool,short,False,6.414,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.uint8,short,False,7.736,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.int8,short,False,7.889,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.int16,short,False,8.170,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.int32,short,False,7.783,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.int64,short,False,7.743,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.float16,short,False,0.927,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.bfloat16,short,False,7.018,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.float32,short,False,8.428,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.float64,short,False,6.767,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.bool,short,False,6.479,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.uint8,short,False,7.827,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.int8,short,False,6.450,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.int16,short,False,6.320,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.int32,short,False,6.385,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.int64,short,False,8.119,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.float16,short,False,8.063,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.bfloat16,short,False,0.925,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.float32,short,False,8.629,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.float64,short,False,6.638,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.bool,short,False,6.425,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.uint8,short,False,7.803,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.int8,short,False,6.502,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.int16,short,False,6.429,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.int32,short,False,6.549,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.int64,short,False,7.749,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.float16,short,False,7.301,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.bfloat16,short,False,7.682,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.float32,short,False,0.930,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.float64,short,False,6.738,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.bool,short,False,6.798,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.uint8,short,False,6.506,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.int8,short,False,6.494,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.int16,short,False,6.668,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.int32,short,False,6.696,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.int64,short,False,7.115,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.float16,short,False,7.910,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.bfloat16,short,False,7.410,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.float32,short,False,6.868,0.000000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.float64,short,False,0.924,0.000000
PyTorch,addcmul,addcmul_M1_N2_cpu_dtypetorch.float32,short,False,4.461410,0.000000
PyTorch,addcmul,addcmul_M1_N2_cpu_dtypetorch.bfloat16,short,False,4.560082,0.000000
PyTorch,addcmul,addcmul_M32_N64_cpu_dtypetorch.float32,short,False,5.141248,0.000000
@@ -4,74 +4,84 @@ import torch


tensor_conversion_short_configs = op_bench.cross_product_configs(
    M=(
        8,
        16,
        32,
    ),
    N=(
        16,
        64,
        128,
    ),
    M=[32],
    N=[128],
    device=["cpu", "cuda"],
    dtype_one=[
        torch.bool,
        torch.uint8,
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
        torch.half,
        torch.bfloat16,
        torch.float,
        torch.double,
    ],
    dtype_two=[
        torch.bool,
        torch.uint8,
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
        torch.half,
        torch.bfloat16,
        torch.float,
        torch.double,
    ],
    tags=["short"],
)

tensor_conversion_long_configs = op_bench.cross_product_configs(
    M=(
        64,
        128,
        256,
        512,
    ),
    N=(
        256,
        512,
        1024,
        2048,
    ),
    M=[1024],
    N=[1024],
    device=["cpu", "cuda"],
    dtype_one=[
        torch.bool,
        torch.uint8,
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
        torch.half,
        torch.bfloat16,
        torch.float,
        torch.double,
    ],
    dtype_two=[
        torch.bool,
        torch.uint8,
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
        torch.half,
        torch.bfloat16,
        torch.float,
        torch.double,
    ],
    tags=["long"],
)


class FloatToHalfTensorConversionBenchmark(op_bench.TorchBenchmarkBase):
    def init(self, M, N, device):
class TensorConversionBenchmark(op_bench.TorchBenchmarkBase):
    def init(self, M, N, dtype_one, dtype_two, device):
        self.inputs = {
            "input": torch.rand(
                M, N, device=device, requires_grad=False, dtype=torch.float
            )
            ).to(dtype=dtype_one)
        }
        self.dtype_one = dtype_one
        self.dtype_two = dtype_two

    def forward(self, input):
        return input.to(torch.half)
        return input.to(dtype=self.dtype_two)


class HalfToFloatTensorConversionBenchmark(op_bench.TorchBenchmarkBase):
    def init(self, M, N, device):
        self.inputs = {
            "input": torch.rand(
                M, N, device=device, requires_grad=False, dtype=torch.half
            )
        }

    def forward(self, input):
        return input.to(torch.float)


op_bench.generate_pt_test(
    tensor_conversion_short_configs, FloatToHalfTensorConversionBenchmark
)
op_bench.generate_pt_test(
    tensor_conversion_long_configs, FloatToHalfTensorConversionBenchmark
)
op_bench.generate_pt_test(
    tensor_conversion_short_configs, HalfToFloatTensorConversionBenchmark
)
op_bench.generate_pt_test(
    tensor_conversion_long_configs, HalfToFloatTensorConversionBenchmark
)
op_bench.generate_pt_test(tensor_conversion_short_configs, TensorConversionBenchmark)
op_bench.generate_pt_test(tensor_conversion_long_configs, TensorConversionBenchmark)

if __name__ == "__main__":
    op_bench.benchmark_runner.main()
@@ -349,24 +349,106 @@ PyTorch,sum,sum_R256_V512_dim0_contiguousTrue_cpu,short,FALSE,12.5841
PyTorch,sum,sum_R256_V512_dim0_contiguousFALSE_cpu,short,FALSE,20.8765
PyTorch,sum,sum_R256_V512_dim1_contiguousTrue_cpu,short,FALSE,15.4414
PyTorch,sum,sum_R256_V512_dim1_contiguousFALSE_cpu,short,FALSE,15.3287
PyTorch,FloatToHalfTensorConversionBenchmark,FloatToHalfTensorConversionBenchmark_M8_N16_cpu,short,FALSE,5.0499
PyTorch,FloatToHalfTensorConversionBenchmark,FloatToHalfTensorConversionBenchmark_M8_N64_cpu,short,FALSE,5.3229
PyTorch,FloatToHalfTensorConversionBenchmark,FloatToHalfTensorConversionBenchmark_M8_N128_cpu,short,FALSE,5.4418
PyTorch,FloatToHalfTensorConversionBenchmark,FloatToHalfTensorConversionBenchmark_M16_N16_cpu,short,FALSE,5.0868
PyTorch,FloatToHalfTensorConversionBenchmark,FloatToHalfTensorConversionBenchmark_M16_N64_cpu,short,FALSE,5.4495
PyTorch,FloatToHalfTensorConversionBenchmark,FloatToHalfTensorConversionBenchmark_M16_N128_cpu,short,FALSE,5.5578
PyTorch,FloatToHalfTensorConversionBenchmark,FloatToHalfTensorConversionBenchmark_M32_N16_cpu,short,FALSE,5.2631
PyTorch,FloatToHalfTensorConversionBenchmark,FloatToHalfTensorConversionBenchmark_M32_N64_cpu,short,FALSE,5.5646
PyTorch,FloatToHalfTensorConversionBenchmark,FloatToHalfTensorConversionBenchmark_M32_N128_cpu,short,FALSE,5.7898
PyTorch,HalfToFloatTensorConversionBenchmark,HalfToFloatTensorConversionBenchmark_M8_N16_cpu,short,FALSE,5.0228
PyTorch,HalfToFloatTensorConversionBenchmark,HalfToFloatTensorConversionBenchmark_M8_N64_cpu,short,FALSE,5.3692
PyTorch,HalfToFloatTensorConversionBenchmark,HalfToFloatTensorConversionBenchmark_M8_N128_cpu,short,FALSE,5.4006
PyTorch,HalfToFloatTensorConversionBenchmark,HalfToFloatTensorConversionBenchmark_M16_N16_cpu,short,FALSE,5.1107
PyTorch,HalfToFloatTensorConversionBenchmark,HalfToFloatTensorConversionBenchmark_M16_N64_cpu,short,FALSE,5.4119
PyTorch,HalfToFloatTensorConversionBenchmark,HalfToFloatTensorConversionBenchmark_M16_N128_cpu,short,FALSE,5.5583
PyTorch,HalfToFloatTensorConversionBenchmark,HalfToFloatTensorConversionBenchmark_M32_N16_cpu,short,FALSE,5.3818
PyTorch,HalfToFloatTensorConversionBenchmark,HalfToFloatTensorConversionBenchmark_M32_N64_cpu,short,FALSE,5.5742
PyTorch,HalfToFloatTensorConversionBenchmark,HalfToFloatTensorConversionBenchmark_M32_N128_cpu,short,FALSE,6.8414
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.bool,short,False,0.797
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.uint8,short,False,6.071
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.int8,short,False,6.031
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.int16,short,False,6.243
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.int32,short,False,7.231
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.int64,short,False,7.791
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.float16,short,False,12.661
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.bfloat16,short,False,11.225
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.float32,short,False,9.772
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.float64,short,False,9.872
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.bool,short,False,6.033
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,0.781
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.int8,short,False,6.060
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.int16,short,False,6.180
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.258
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.int64,short,False,7.758
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.float16,short,False,10.504
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.bfloat16,short,False,6.749
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.float32,short,False,7.679
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.float64,short,False,7.797
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.bool,short,False,6.019
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.uint8,short,False,6.079
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.int8,short,False,0.785
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.int16,short,False,6.188
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.int32,short,False,7.288
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.int64,short,False,7.770
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.float16,short,False,10.466
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.bfloat16,short,False,6.676
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.float32,short,False,7.736
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.float64,short,False,7.780
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.bool,short,False,6.130
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.uint8,short,False,6.221
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.int8,short,False,6.101
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.int16,short,False,0.791
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.int32,short,False,6.254
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.int64,short,False,7.733
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.float16,short,False,10.562
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.bfloat16,short,False,6.704
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.float32,short,False,7.819
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.float64,short,False,8.276
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.bool,short,False,6.361
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,6.364
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.int8,short,False,6.309
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.int16,short,False,6.362
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,0.791
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.int64,short,False,7.746
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.float16,short,False,9.462
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.bfloat16,short,False,6.678
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.float32,short,False,7.827
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.float64,short,False,8.200
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.bool,short,False,6.925
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.uint8,short,False,6.947
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.int8,short,False,6.962
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.int16,short,False,6.906
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.int32,short,False,7.664
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.int64,short,False,0.782
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.float16,short,False,10.528
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.bfloat16,short,False,10.123
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.float32,short,False,9.234
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.float64,short,False,8.694
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.bool,short,False,12.653
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.uint8,short,False,9.348
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.int8,short,False,8.774
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.int16,short,False,9.063
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.int32,short,False,10.012
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.int64,short,False,13.641
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.float16,short,False,0.788
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.bfloat16,short,False,13.757
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.float32,short,False,7.170
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.float64,short,False,12.511
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.bool,short,False,6.516
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.uint8,short,False,8.539
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.int8,short,False,6.483
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.int16,short,False,6.468
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.int32,short,False,7.752
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.int64,short,False,9.868
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.float16,short,False,10.556
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.bfloat16,short,False,0.792
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.float32,short,False,7.577
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.float64,short,False,8.267
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.bool,short,False,6.819
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.uint8,short,False,7.715
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.int8,short,False,6.754
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.int16,short,False,6.825
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.int32,short,False,7.790
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.int64,short,False,9.219
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.float16,short,False,5.977
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.bfloat16,short,False,7.069
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.float32,short,False,0.794
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.float64,short,False,8.301
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.bool,short,False,7.401
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.uint8,short,False,7.843
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.int8,short,False,7.117
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.int16,short,False,7.170
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.int32,short,False,8.000
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.int64,short,False,9.284
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.float16,short,False,7.179
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.bfloat16,short,False,7.645
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.float32,short,False,7.988
PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.float64,short,False,0.792
PyTorch,relu,"relu_dims(3,4,5)_contigFALSE_inplaceFALSE_dtypetorch.quint8",short,FALSE,9.4657
PyTorch,relu,"relu_dims(3,4,5)_contigFALSE_inplaceFALSE_dtypetorch.qint8",short,FALSE,9.4625
PyTorch,relu,"relu_dims(3,4,5)_contigFALSE_inplaceFALSE_dtypetorch.qint32",short,FALSE,9.4165
@@ -83,10 +83,13 @@ if __name__ == "__main__":

    if args.outfile == "stdout":
        outfile = sys.stdout
        need_close = False
    elif args.outfile == "stderr":
        outfile = sys.stderr
        need_close = False
    else:
        outfile = open(args.outfile, "a")
        need_close = True

    test_count = args.test_count
    m = args.m
@@ -147,3 +150,5 @@ if __name__ == "__main__":
                time,
                file=outfile,
            )
    if need_close:
        outfile.close()

@@ -82,10 +82,13 @@ if __name__ == "__main__":

    if args.outfile == "stdout":
        outfile = sys.stdout
        need_close = False
    elif args.outfile == "stderr":
        outfile = sys.stderr
        need_close = False
    else:
        outfile = open(args.outfile, "a")
        need_close = True

    test_count = args.test_count
    m = args.m
@@ -132,3 +135,5 @@ if __name__ == "__main__":
                time_csr,
                file=outfile,
            )
    if need_close:
        outfile.close()

@@ -179,10 +179,13 @@ if __name__ == "__main__":

    if args.outfile == "stdout":
        outfile = sys.stdout
        need_close = False
    elif args.outfile == "stderr":
        outfile = sys.stderr
        need_close = False
    else:
        outfile = open(args.outfile, "a")
        need_close = True

    ops = args.ops.split(",")

@@ -434,3 +437,5 @@ if __name__ == "__main__":
                if op not in {"bsr_scatter_mm6", "bsr_dense_mm_with_meta"}:
                    # Break on operations that do not consume parameters
                    break
    if need_close:
        outfile.close()
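The three hunks above apply the same stream-selection change to three different benchmark drivers. A C++ analogue of that pattern (illustrative only, not code from this PR):

```cpp
// Write to stdout/stderr without closing them, or append to a named file
// and close it when done -- mirrors the Python need_close flag above.
#include <fstream>
#include <iostream>
#include <ostream>
#include <string>

int main(int argc, char** argv) {
  const std::string outname = argc > 1 ? argv[1] : "stdout";
  std::ofstream file;
  std::ostream* out = nullptr;
  bool need_close = false;  // mirrors the Python need_close flag
  if (outname == "stdout") {
    out = &std::cout;
  } else if (outname == "stderr") {
    out = &std::cerr;
  } else {
    file.open(outname, std::ios::app);  // the scripts open with mode "a"
    out = &file;
    need_close = true;
  }
  *out << "benchmark output goes here\n";
  if (need_close) {
    file.close();  // never close stdout/stderr
  }
  return 0;
}
```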
@@ -96,6 +96,10 @@ struct C10_API DeviceAllocator : public c10::Allocator {

  // Resets peak memory usage statistics for the specified device
  virtual void resetPeakStats(c10::DeviceIndex device) = 0;

  // Return the free memory size and total memory size in bytes for the
  // specified device.
  virtual std::pair<size_t, size_t> getMemoryInfo(c10::DeviceIndex device) = 0;
};

// This function is used to get the DeviceAllocator for a specific device type

@@ -345,6 +345,13 @@ class CUDAAllocator : public DeviceAllocator {
      c10::DeviceIndex device,
      std::shared_ptr<AllocatorState> pps) = 0;
  virtual std::string name() = 0;
  std::pair<size_t, size_t> getMemoryInfo(c10::DeviceIndex device) override {
    c10::DeviceGuard device_guard({at::kCUDA, device});
    size_t free = 0;
    size_t total = 0;
    C10_CUDA_CHECK(cudaMemGetInfo(&free, &total));
    return {free, total};
  }
};

// Allocator object, statically initialized
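The new `CUDAAllocator::getMemoryInfo` override above is a thin wrapper over the CUDA runtime query. A minimal standalone sketch (assumes a CUDA toolkit is available; not code from the diff):

```cpp
// cudaMemGetInfo reports free and total device memory in bytes for the
// current device.
#include <cuda_runtime.h>
#include <cstdio>

int main() {
  size_t free = 0;
  size_t total = 0;
  cudaSetDevice(0);  // the override uses c10::DeviceGuard for this step
  if (cudaMemGetInfo(&free, &total) == cudaSuccess) {
    std::printf("free=%zu total=%zu\n", free, total);
  }
  return 0;
}
```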
c10/metal/error.h (new file, 111 lines)
@@ -0,0 +1,111 @@
#pragma once
#include <c10/metal/common.h>

namespace c10 {
namespace metal {
C10_METAL_CONSTEXPR unsigned error_message_count = 30;
struct ErrorMessage {
  char file[128];
  char func[128];
  char message[250];
  unsigned int line;
};

struct ErrorMessages {
#ifdef __METAL__
  ::metal::atomic<unsigned int> count;
#else
  unsigned int count;
#endif
  ErrorMessage msg[error_message_count];
};

#ifdef __METAL__
namespace detail {
static uint strncpy(device char* dst, constant const char* src, unsigned len) {
  uint i = 0;
  while (src[i] != 0 && i < len - 1) {
    dst[i] = src[i];
    i++;
  }
  dst[i] = 0;
  return i;
}

inline uint print_arg(
    device char* ptr,
    unsigned len,
    constant const char* arg) {
  return strncpy(ptr, arg, len);
}

// Returns number length as string in base10
static inline uint base10_length(long num) {
  uint rc = 1;
  if (num < 0) {
    num = -num;
    rc += 1;
  }
  while (num > 9) {
    num /= 10;
    rc++;
  }
  return rc;
}

// Converts signed integer to string
inline uint print_arg(device char* ptr, unsigned len, long arg) {
  const auto arg_len = base10_length(arg);
  if (arg_len >= len)
    return 0;
  if (arg < 0) {
    ptr[0] = '-';
    arg = -arg;
  }
  uint idx = 1;
  do {
    ptr[arg_len - idx] = '0' + (arg % 10);
    arg /= 10;
    idx++;
  } while (arg > 0);
  ptr[arg_len] = 0;
  return arg_len;
}

template <typename T>
inline void print_args(device char* ptr, unsigned len, T arg) {
  print_arg(ptr, len, arg);
}

template <typename T, typename... Args>
inline void print_args(device char* ptr, unsigned len, T arg, Args... args) {
  const auto rc = print_arg(ptr, len, arg);
  print_args(ptr + rc, len - rc, args...);
}

} // namespace detail

template <typename... Args>
static void report_error(
    device ErrorMessages* msgs,
    constant const char* file,
    int line,
    constant const char* func,
    Args... args) {
  const auto idx =
      atomic_fetch_add_explicit(&msgs->count, 1, ::metal::memory_order_relaxed);
  if (idx >= error_message_count) {
    return;
  }
  device auto* msg = &msgs->msg[idx];
  detail::strncpy(msg->file, file, 128);
  detail::strncpy(msg->func, func, 128);
  detail::print_args(msg->message, 250, args...);
  msg->line = line;
}

#define TORCH_REPORT_ERROR(buf, ...) \
  ::c10::metal::report_error(buf, __FILE__, __LINE__, __func__, __VA_ARGS__)
#endif
} // namespace metal
} // namespace c10
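The header above only defines the device-side reporting path. A hedged host-side sketch (my addition, not in the diff) of how a buffer of `ErrorMessages` shared with a kernel could be drained after the kernel finishes; on the host, `count` is a plain `unsigned int`, and it can exceed `error_message_count` when more errors were reported than stored, hence the clamp:

```cpp
#include <c10/metal/error.h>  // the new header above
#include <algorithm>
#include <cstdio>

// Print every message that was actually recorded in the ring buffer.
void dump_errors(const c10::metal::ErrorMessages& buf) {
  const unsigned recorded =
      std::min(buf.count, c10::metal::error_message_count);
  for (unsigned i = 0; i < recorded; ++i) {
    const auto& m = buf.msg[i];
    std::printf("%s:%u: %s: %s\n", m.file, m.line, m.func, m.message);
  }
}
```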
@@ -1 +0,0 @@
#include <c10/util/Metaprogramming.h>
@@ -1,224 +1 @@
#pragma once

#include <c10/util/TypeList.h>
#include <type_traits>

namespace c10::guts {

/**
 * Access information about result type or arguments from a function type.
 * Example:
 * using A = function_traits<int (float, double)>::return_type // A == int
 * using A = function_traits<int (float, double)>::parameter_types::tuple_type
 * // A == tuple<float, double>
 */
template <class Func>
struct function_traits {
  static_assert(
      !std::is_same_v<Func, Func>,
      "In function_traits<Func>, Func must be a plain function type.");
};
template <class Result, class... Args>
struct function_traits<Result(Args...)> {
  using func_type = Result(Args...);
  using return_type = Result;
  using parameter_types = typelist::typelist<Args...>;
  static constexpr auto number_of_parameters = sizeof...(Args);
};

/**
 * infer_function_traits: creates a `function_traits` type for a simple
 * function (pointer) or functor (lambda/struct). Currently does not support
 * class methods.
 */

template <typename Functor>
struct infer_function_traits {
  using type = function_traits<
      c10::guts::detail::strip_class_t<decltype(&Functor::operator())>>;
};

template <typename Result, typename... Args>
struct infer_function_traits<Result (*)(Args...)> {
  using type = function_traits<Result(Args...)>;
};

template <typename Result, typename... Args>
struct infer_function_traits<Result(Args...)> {
  using type = function_traits<Result(Args...)>;
};

template <typename T>
using infer_function_traits_t = typename infer_function_traits<T>::type;

/**
 * make_function_traits: creates a `function_traits` type given a Return type
 * and a typelist of Argument types
 *
 * Example:
 * bool f(int, int);
 *
 * infer_function_traits_t<f> == make_function_traits_t<bool,
 * typelist::typelist<int, int>>
 */
template <typename Result, typename ArgList>
struct make_function_traits {
  static_assert(
      false_t<ArgList>::value,
      "In guts::make_function_traits<Result, TypeList>, the ArgList argument must be typelist<...>.");
};

template <typename Result, typename... Args>
struct make_function_traits<Result, typelist::typelist<Args...>> {
  using type = function_traits<Result(Args...)>;
};

template <typename Result, typename ArgList>
using make_function_traits_t =
    typename make_function_traits<Result, ArgList>::type;

/**
 * make_offset_index_sequence<Start, N>
 * Like make_index_sequence<N>, but starting from Start instead of 0.
 *
 * Example:
 *  make_offset_index_sequence<10, 3> == std::index_sequence<10, 11, 12>
 */
template <size_t Start, size_t N, size_t... Is>
struct make_offset_index_sequence_impl
    : make_offset_index_sequence_impl<Start, N - 1, Start + N - 1, Is...> {
  static_assert(
      static_cast<int>(Start) >= 0,
      "make_offset_index_sequence: Start < 0");
  static_assert(static_cast<int>(N) >= 0, "make_offset_index_sequence: N < 0");
};

template <size_t Start, size_t... Is>
struct make_offset_index_sequence_impl<Start, 0, Is...> {
  typedef std::index_sequence<Is...> type;
};

template <size_t Start, size_t N>
using make_offset_index_sequence =
    typename make_offset_index_sequence_impl<Start, N>::type;

/**
 * Use tuple_elements to extract a position-indexed subset of elements
 * from the argument tuple into a result tuple.
 *
 * Example:
 *  std::tuple<int, const char*, double> t = std::make_tuple(0, "HEY", 2.0);
 *  std::tuple<int, double> result = tuple_elements(t, std::index_sequence<0,
 *  2>());
 */
template <class Tuple, size_t... Is>
constexpr auto tuple_elements(Tuple t, std::index_sequence<Is...> /*unused*/) {
  return std::tuple<std::tuple_element_t<Is, Tuple>...>(std::get<Is>(t)...);
}

/**
 * Use tuple_take to extract the first or last n elements from the argument
 * tuple into a result tuple.
 *
 * Example:
 *  std::tuple<int, const char*, double> t = std::make_tuple(0, "HEY", 2.0);
 *  std::tuple<int, const char*> first_two = tuple_take<decltype(t), 2>(t);
 *  std::tuple<const char*, double> last_two = tuple_take<decltype(t), -2>(t);
 */
template <class Tuple, int N, class Enable = void>
struct TupleTake {};

template <class Tuple, int N>
struct TupleTake<Tuple, N, std::enable_if_t<(N >= 0), void>> {
  static auto call(Tuple t) {
    constexpr size_t size = std::tuple_size<Tuple>();
    static_assert(N <= size, "tuple_take: N > size");
    return tuple_elements(t, std::make_index_sequence<N>{});
  }
};

template <class Tuple, int N>
struct TupleTake<Tuple, N, std::enable_if_t<(N < 0), void>> {
  static auto call(Tuple t) {
    constexpr size_t size = std::tuple_size<Tuple>();
    static_assert(-N <= size, "tuple_take: -N > size");
    return tuple_elements(t, make_offset_index_sequence<size + N, -N>{});
  }
};

template <class Tuple, int N>
auto tuple_take(Tuple t) {
  return TupleTake<Tuple, N>::call(t);
}

/**
 * Use tuple_slice to extract a contiguous subtuple from the argument.
 *
 * Example:
 *  std::tuple<int, const char*, double, bool> t = std::make_tuple(0,
 *  "HEY", 2.0, false); std::tuple<int, const char*> middle_two =
 *  tuple_slice<decltype(t), 1, 2>(t);
 */
template <class Tuple, size_t Start, size_t N>
constexpr auto tuple_slice(Tuple t) {
  constexpr size_t size = std::tuple_size<Tuple>();
  static_assert(Start + N <= size, "tuple_slice: Start + N > size");
  return tuple_elements(t, make_offset_index_sequence<Start, N>{});
}

/**
 * Use tuple_map to run a mapping function over a tuple to get a new tuple.
 *
 * Example 1:
 *   auto result = tuple_map(std::tuple<int32_t, int32_t, int32_t>(3, 4, 5), []
 *   (int32_t a) -> int16_t {return a+1;});
 *   // result == std::tuple<int16_t, int16_t, int16_t>(4, 5, 6)
 *
 * Example 2:
 *   struct Mapper {
 *     std::string operator()(int32_t a) const {
 *       return std::to_string(a);
 *     }
 *     int64_t operator()(const std::string& a) const {
 *       return atoi(a.c_str());
 *     }
 *   };
 *   auto result = tuple_map(std::tuple<int32_t, std::string>(3, "4"),
 *   Mapper());
 *   // result == std::tuple<std::string, int64_t>("3", 4)
 *
 * Example 3:
 *   struct A final {
 *     int32_t func() {
 *       return 5;
 *     }
 *   };
 *   struct B final {
 *     std::string func() {
 *       return "5";
 *     }
 *   };
 *   auto result = tuple_map(std::make_tuple(A(), B()), [] (auto a) { return
 *   a.func(); });
 *   // result == std::tuple<int32_t, std::string>(5, "5");
 */
namespace detail {
template <class Mapper, class... Args, size_t... Indices>
auto tuple_map(
    // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved)
    std::tuple<Args...>&& tuple,
    const Mapper& mapper,
    std::index_sequence<Indices...> /*unused*/) {
  return std::tuple<decltype(mapper(std::forward<Args>(std::get<Indices>(
      tuple))))...>(mapper(std::forward<Args>(std::get<Indices>(tuple)))...);
}
} // namespace detail

template <class Mapper, class... Args>
auto tuple_map(std::tuple<Args...>&& tuple, const Mapper& mapper) {
  return detail::tuple_map(
      std::move(tuple), mapper, std::index_sequence_for<Args...>());
}

} // namespace c10::guts
#include <torch/headeronly/util/Metaprogramming.h>
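Since this header's body is replaced by the torch/headeronly include, the traits above stay usable through the new path. A small compile-time sketch (mine; it assumes the headeronly header exposes the same symbols under `torch::headeronly::guts`, as the updated tests later in this diff suggest):

```cpp
#include <torch/headeronly/util/Metaprogramming.h>
#include <string>
#include <tuple>
#include <type_traits>

using namespace torch::headeronly::guts;  // namespace per the updated tests

using F = function_traits<int(float, double)>;
static_assert(std::is_same_v<int, F::return_type>);
static_assert(F::number_of_parameters == 2);
static_assert(std::is_same_v<
              std::tuple<float, double>,
              typelist::to_tuple_t<F::parameter_types>>);

int main() {
  auto t = std::make_tuple(0, std::string("HEY"), 2.0);
  auto first_two = tuple_take<decltype(t), 2>(t);  // (0, "HEY")
  return std::get<0>(first_two);
}
```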
@@ -1,515 +1 @@
#pragma once

#include <c10/util/TypeTraits.h>
#include <algorithm>
#include <cstddef>
#include <tuple>
#include <type_traits>
#include <utility>

namespace c10::guts {

template <class... T>
struct false_t : std::false_type {};
template <template <class> class... T>
struct false_higher_t : std::false_type {};

namespace typelist {

/**
 * Type holding a list of types for compile time type computations
 */
template <class... Items>
struct typelist final {
 public:
  typelist() = delete; // not for instantiation
};

/**
 * Returns the number of types in a typelist
 * Example:
 *   3  ==  size<typelist<int, int, double>>::value
 */
template <class TypeList>
struct size final {
  static_assert(
      false_t<TypeList>::value,
      "In typelist::size<T>, T must be typelist<...>.");
};
template <class... Types>
struct size<typelist<Types...>> final {
  static constexpr size_t value = sizeof...(Types);
};

/**
 * Transforms a list of types into a tuple holding these types.
 * Example:
 *   std::tuple<int, string>  ==  to_tuple_t<typelist<int, string>>
 */
template <class TypeList>
struct to_tuple final {
  static_assert(
      false_t<TypeList>::value,
      "In typelist::to_tuple<T>, T must be typelist<...>.");
};
template <class... Types>
struct to_tuple<typelist<Types...>> final {
  using type = std::tuple<Types...>;
};
template <class TypeList>
using to_tuple_t = typename to_tuple<TypeList>::type;

/**
 * Creates a typelist containing the types of a given tuple.
 * Example:
 *   typelist<int, string>  ==  from_tuple_t<std::tuple<int, string>>
 */
template <class Tuple>
struct from_tuple final {
  static_assert(
      false_t<Tuple>::value,
      "In typelist::from_tuple<T>, T must be std::tuple<...>.");
};
template <class... Types>
struct from_tuple<std::tuple<Types...>> final {
  using type = typelist<Types...>;
};
template <class Tuple>
using from_tuple_t = typename from_tuple<Tuple>::type;

/**
 * Concatenates multiple type lists.
 * Example:
 *   typelist<int, string, int>  ==  concat_t<typelist<int, string>,
 *   typelist<int>>
 */
template <class... TypeLists>
struct concat final {
  static_assert(
      false_t<TypeLists...>::value,
      "In typelist::concat<T1, ...>, the T arguments each must be typelist<...>.");
};
template <class... Head1Types, class... Head2Types, class... TailLists>
struct concat<typelist<Head1Types...>, typelist<Head2Types...>, TailLists...>
    final {
  using type =
      typename concat<typelist<Head1Types..., Head2Types...>, TailLists...>::
          type;
};
template <class... HeadTypes>
struct concat<typelist<HeadTypes...>> final {
  using type = typelist<HeadTypes...>;
};
template <>
struct concat<> final {
  using type = typelist<>;
};
template <class... TypeLists>
using concat_t = typename concat<TypeLists...>::type;

/**
 * Filters the types in a type list by a type trait.
 * Examples:
 *   typelist<int&, const string&&>  ==  filter_t<std::is_reference,
 *   typelist<void, string, int&, bool, const string&&, int>>
 */
template <template <class> class Condition, class TypeList>
struct filter final {
  static_assert(
      false_t<TypeList>::value,
      "In typelist::filter<Condition, TypeList>, the TypeList argument must be typelist<...>.");
};
template <template <class> class Condition, class Head, class... Tail>
struct filter<Condition, typelist<Head, Tail...>> final {
  static_assert(
      is_type_condition<Condition>::value,
      "In typelist::filter<Condition, TypeList>, the Condition argument must be a condition type trait, i.e. have a static constexpr bool ::value member.");
  using type = std::conditional_t<
      Condition<Head>::value,
      concat_t<
          typelist<Head>,
          typename filter<Condition, typelist<Tail...>>::type>,
      typename filter<Condition, typelist<Tail...>>::type>;
};
template <template <class> class Condition>
struct filter<Condition, typelist<>> final {
  static_assert(
      is_type_condition<Condition>::value,
      "In typelist::filter<Condition, TypeList>, the Condition argument must be a condition type trait, i.e. have a static constexpr bool ::value member.");
  using type = typelist<>;
};
template <template <class> class Condition, class TypeList>
using filter_t = typename filter<Condition, TypeList>::type;

/**
 * Counts how many types in the list fulfill a type trait
 * Examples:
 *   2  ==  count_if<std::is_reference, typelist<void, string, int&, bool, const
 *   string&&, int>>
 */
template <template <class> class Condition, class TypeList>
struct count_if final {
  static_assert(
      is_type_condition<Condition>::value,
      "In typelist::count_if<Condition, TypeList>, the Condition argument must be a condition type trait, i.e. have a static constexpr bool ::value member.");
  static_assert(
      is_instantiation_of<typelist, TypeList>::value,
      "In typelist::count_if<Condition, TypeList>, the TypeList argument must be typelist<...>.");
  // TODO Direct implementation might be faster
  static constexpr size_t value = size<filter_t<Condition, TypeList>>::value;
};

/**
 * Checks if a typelist contains a certain type.
 * Examples:
 *   contains<typelist<int, string>, string> == true_type
 *   contains<typelist<int, string>, double> == false_type
 */
namespace detail {
template <class TypeList, class Type, class Enable = void>
struct contains {};
template <class Type>
struct contains<typelist<>, Type, void> : std::false_type {};
template <class Type, class Head, class... Tail>
struct contains<
    typelist<Head, Tail...>,
    Type,
    std::enable_if_t<std::is_same_v<Head, Type>>> : std::true_type {};
template <class Type, class Head, class... Tail>
struct contains<
    typelist<Head, Tail...>,
    Type,
    std::enable_if_t<!std::is_same_v<Head, Type>>>
    : contains<typelist<Tail...>, Type> {};
} // namespace detail
template <class TypeList, class Type>
using contains = typename detail::contains<TypeList, Type>::type;

/**
 * Returns true iff the type trait is true for all types in the type list
 * Examples:
 *   true   ==  all<std::is_reference, typelist<int&, const float&&, const
 *   MyClass&>>::value false  ==  all<std::is_reference, typelist<int&, const
 *   float&&, MyClass>>::value
 */
template <template <class> class Condition, class TypeList>
struct all {
  static_assert(
      false_t<TypeList>::value,
      "In typelist::all<Condition, TypeList>, the TypeList argument must be typelist<...>.");
};
template <template <class> class Condition, class... Types>
struct all<Condition, typelist<Types...>>
    : std::conjunction<Condition<Types>...> {
  static_assert(
      is_type_condition<Condition>::value,
      "In typelist::all<Condition, TypeList>, the Condition argument must be a condition type trait, i.e. have a static constexpr bool ::value member.");
};

/**
 * Returns true iff the type trait is true for any type in the type list
 * Examples:
 *   true   ==  true_for_any_type<std::is_reference, typelist<int, const
 *   float&&, const MyClass>>::value false  ==
 *   true_for_any_type<std::is_reference, typelist<int, const float,
 *   MyClass>>::value
 */
template <template <class> class Condition, class TypeList>
struct true_for_any_type final {
  static_assert(
      false_t<TypeList>::value,
      "In typelist::true_for_any_type<Condition, TypeList>, the TypeList argument must be typelist<...>.");
};
template <template <class> class Condition, class... Types>
struct true_for_any_type<Condition, typelist<Types...>> final
    : std::disjunction<Condition<Types>...> {
  static_assert(
      is_type_condition<Condition>::value,
      "In typelist::true_for_any_type<Condition, TypeList>, the Condition argument must be a condition type trait, i.e. have a static constexpr bool ::value member.");
};

/**
 * Maps types of a type list using a type trait
 * Example:
 *  typelist<int&, double&, string&>  ==  map_t<std::add_lvalue_reference_t,
 *  typelist<int, double, string>>
 */
template <template <class> class Mapper, class TypeList>
struct map final {
  static_assert(
      false_t<TypeList>::value,
      "In typelist::map<Mapper, TypeList>, the TypeList argument must be typelist<...>.");
};
template <template <class> class Mapper, class... Types>
struct map<Mapper, typelist<Types...>> final {
  using type = typelist<Mapper<Types>...>;
};
template <template <class> class Mapper, class TypeList>
using map_t = typename map<Mapper, TypeList>::type;

/**
 * Returns the first element of a type list.
 * Example:
 *   int  ==  head_t<typelist<int, string>>
 */
template <class TypeList>
struct head final {
  static_assert(
      false_t<TypeList>::value,
      "In typelist::head<T>, the T argument must be typelist<...>.");
};
template <class Head, class... Tail>
struct head<typelist<Head, Tail...>> final {
  using type = Head;
};
template <class TypeList>
using head_t = typename head<TypeList>::type;

/**
 * Returns the first element of a type list, or the specified default if the
 * type list is empty. Example: int  ==  head_t<bool, typelist<int, string>>
 *   bool  ==  head_t<bool, typelist<>>
 */
template <class Default, class TypeList>
struct head_with_default final {
  using type = Default;
};
template <class Default, class Head, class... Tail>
struct head_with_default<Default, typelist<Head, Tail...>> final {
  using type = Head;
};
template <class Default, class TypeList>
using head_with_default_t = typename head_with_default<Default, TypeList>::type;

/**
 * Returns the N-th element of a type list.
 * Example:
 * int == element_t<1, typelist<float, int, char>>
 */

/// Base template.
template <size_t Index, class TypeList>
struct element final {
  static_assert(
      false_t<TypeList>::value,
      "In typelist::element<T>, the T argument must be typelist<...>.");
};

/// Successful case, we have reached the zero index and can "return" the head
/// type.
template <class Head, class... Tail>
struct element<0, typelist<Head, Tail...>> {
  using type = Head;
};

/// Error case, we have an index but ran out of types! It will only be selected
/// if `Ts...` is actually empty!
template <size_t Index, class... Ts>
struct element<Index, typelist<Ts...>> {
  static_assert(
      Index < sizeof...(Ts),
      "Index is out of bounds in typelist::element");
};

/// Shave off types until we hit the <0, Head, Tail...> or <Index> case.
template <size_t Index, class Head, class... Tail>
struct element<Index, typelist<Head, Tail...>>
    : element<Index - 1, typelist<Tail...>> {};

/// Convenience alias.
template <size_t Index, class TypeList>
using element_t = typename element<Index, TypeList>::type;

/**
 * Returns the last element of a type list.
 * Example:
 *   int  ==  last_t<typelist<int, string>>
 */
template <class TypeList>
struct last final {
  static_assert(
      false_t<TypeList>::value,
      "In typelist::last<T>, the T argument must be typelist<...>.");
};
template <class Head, class... Tail>
struct last<typelist<Head, Tail...>> final {
  using type = typename last<typelist<Tail...>>::type;
};
template <class Head>
struct last<typelist<Head>> final {
  using type = Head;
};
template <class TypeList>
using last_t = typename last<TypeList>::type;
static_assert(std::is_same_v<int, last_t<typelist<double, float, int>>>);

/**
 * Take/drop a number of arguments from a typelist.
 * Example:
 *   typelist<int, string> == take_t<typelist<int, string, bool>, 2>
 *   typelist<bool> == drop_t<typelist<int, string, bool>, 2>
 */
namespace detail {
template <class TypeList, size_t offset, class IndexSequence>
struct take_elements final {};

template <class TypeList, size_t offset, size_t... Indices>
struct take_elements<TypeList, offset, std::index_sequence<Indices...>> final {
  using type = typelist<typename element<offset + Indices, TypeList>::type...>;
};
} // namespace detail

template <class TypeList, size_t num>
struct take final {
  static_assert(
      is_instantiation_of<typelist, TypeList>::value,
      "In typelist::take<T, num>, the T argument must be typelist<...>.");
  static_assert(
      num <= size<TypeList>::value,
      "Tried to typelist::take more elements than there are in the list");
  using type = typename detail::
      take_elements<TypeList, 0, std::make_index_sequence<num>>::type;
};
template <class TypeList, size_t num>
using take_t = typename take<TypeList, num>::type;

template <class TypeList, size_t num>
struct drop final {
  static_assert(
      is_instantiation_of<typelist, TypeList>::value,
      "In typelist::drop<T, num>, the T argument must be typelist<...>.");
  static_assert(
      num <= size<TypeList>::value,
      "Tried to typelist::drop more elements than there are in the list");
  using type = typename detail::take_elements<
      TypeList,
      num,
      std::make_index_sequence<size<TypeList>::value - num>>::type;
};
template <class TypeList, size_t num>
using drop_t = typename drop<TypeList, num>::type;

/**
 * Like drop, but returns an empty list rather than an assertion error if `num`
 * is larger than the size of the TypeList.
 * Example:
 *   typelist<> == drop_if_nonempty_t<typelist<string, bool>, 2>
 *   typelist<> == drop_if_nonempty_t<typelist<int, string, bool>, 3>
 */
template <class TypeList, size_t num>
struct drop_if_nonempty final {
  static_assert(
      is_instantiation_of<typelist, TypeList>::value,
      "In typelist::drop<T, num>, the T argument must be typelist<...>.");
  using type = typename detail::take_elements<
      TypeList,
      std::min(num, size<TypeList>::value),
      std::make_index_sequence<
          size<TypeList>::value - std::min(num, size<TypeList>::value)>>::type;
};
template <class TypeList, size_t num>
using drop_if_nonempty_t = typename drop_if_nonempty<TypeList, num>::type;

/**
 * Reverses a typelist.
 * Example:
 *   typelist<int, string>  == reverse_t<typelist<string, int>>
 */
template <class TypeList>
struct reverse final {
  static_assert(
      false_t<TypeList>::value,
      "In typelist::reverse<T>, the T argument must be typelist<...>.");
};
template <class Head, class... Tail>
struct reverse<typelist<Head, Tail...>> final {
  using type =
      concat_t<typename reverse<typelist<Tail...>>::type, typelist<Head>>;
};
template <>
struct reverse<typelist<>> final {
  using type = typelist<>;
};
template <class TypeList>
using reverse_t = typename reverse<TypeList>::type;

/**
 * Find the index of the first type in a typelist fulfilling a type trait
 * condition. Example:
 *
 * 2 == find_if<typelist<char, int, char&, int&>, std::is_reference>::value
 */
template <class TypeList, template <class> class Condition, class Enable = void>
struct find_if final {
  static_assert(
      false_t<TypeList>::value,
      "In typelist::find_if<TypeList, Condition>, the TypeList argument must be typelist<...>.");
};
template <template <class> class Condition>
struct find_if<typelist<>, Condition, void> final {
  static_assert(
      false_higher_t<Condition>::value,
      "In typelist::find_if<Type/List, Condition>, didn't find any type fulfilling the Condition.");
};
template <class Head, class... Tail, template <class> class Condition>
struct find_if<
    typelist<Head, Tail...>,
    Condition,
    std::enable_if_t<Condition<Head>::value>>
    final {
  static constexpr size_t value = 0;
};
template <class Head, class... Tail, template <class> class Condition>
struct find_if<
    typelist<Head, Tail...>,
    Condition,
    std::enable_if_t<!Condition<Head>::value>>
    final {
  static constexpr size_t value =
      1 + find_if<typelist<Tail...>, Condition>::value;
};

/**
 * Maps a list of types into a list of values.
 * Examples:
 *   // Example 1
 *   auto sizes =
 *     map_types_to_values<typelist<int64_t, bool, uint32_t>>(
 *       [] (auto t) { return sizeof(decltype(t)::type); }
 *     );
 *   // sizes  ==  std::tuple<size_t, size_t, size_t>{8, 1, 4}
 *
 *   // Example 2
 *   auto shared_ptrs =
 *     map_types_to_values<typelist<int, double>>(
 *       [] (auto t) { return make_shared<typename decltype(t)::type>(); }
 *     );
 *   // shared_ptrs == std::tuple<shared_ptr<int>, shared_ptr<double>>()
 */
namespace detail {
template <class T>
struct type_ final {
  using type = T;
};
template <class TypeList>
struct map_types_to_values final {
  static_assert(
      false_t<TypeList>::value,
      "In typelist::map_types_to_values<T>, the T argument must be typelist<...>.");
};
template <class... Types>
struct map_types_to_values<typelist<Types...>> final {
  template <class Func>
  static auto call(Func&& func) {
    return std::tuple{std::forward<Func>(func)(type_<Types>())...};
  }
};
} // namespace detail

template <class TypeList, class Func>
auto map_types_to_values(Func&& func) {
  return detail::map_types_to_values<TypeList>::call(std::forward<Func>(func));
}

} // namespace typelist
} // namespace c10::guts
#include <torch/headeronly/util/TypeList.h>
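A compile-time sketch (mine, under the same namespace assumption as the updated test_typelist.cpp later in this diff) exercising the typelist utilities that moved behind the headeronly include:

```cpp
#include <torch/headeronly/util/TypeList.h>
#include <string>
#include <tuple>
#include <type_traits>

using namespace torch::headeronly::guts::typelist;

using L = typelist<int, std::string, double>;
static_assert(size<L>::value == 3);
static_assert(
    std::is_same_v<to_tuple_t<L>, std::tuple<int, std::string, double>>);
static_assert(std::is_same_v<take_t<L, 2>, typelist<int, std::string>>);
static_assert(std::is_same_v<drop_t<L, 2>, typelist<double>>);
static_assert(
    std::is_same_v<reverse_t<L>, typelist<double, std::string, int>>);
static_assert(find_if<L, std::is_floating_point>::value == 2);

int main() {
  return 0;
}
```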
@@ -1,151 +1 @@
#pragma once

#include <functional>
#include <type_traits>

namespace c10::guts {

/**
 * is_equality_comparable<T> is true_type iff the equality operator is defined
 * for T.
 */
template <class T, class Enable = void>
struct is_equality_comparable : std::false_type {};
template <class T>
struct is_equality_comparable<
    T,
    std::void_t<decltype(std::declval<T&>() == std::declval<T&>())>>
    : std::true_type {};
template <class T>
using is_equality_comparable_t = typename is_equality_comparable<T>::type;

/**
 * is_hashable<T> is true_type iff std::hash is defined for T
 */
template <class T, class Enable = void>
struct is_hashable : std::false_type {};
template <class T>
struct is_hashable<T, std::void_t<decltype(std::hash<T>()(std::declval<T&>()))>>
    : std::true_type {};
template <class T>
using is_hashable_t = typename is_hashable<T>::type;

/**
 * is_function_type<T> is true_type iff T is a plain function type (i.e.
 * "Result(Args...)")
 */
template <class T>
struct is_function_type : std::false_type {};
template <class Result, class... Args>
struct is_function_type<Result(Args...)> : std::true_type {};
template <class T>
using is_function_type_t = typename is_function_type<T>::type;

/**
 * is_instantiation_of<T, I> is true_type iff I is a template instantiation of T
 * (e.g. vector<int> is an instantiation of vector) Example:
 *   is_instantiation_of_t<vector, vector<int>> // true
 *   is_instantiation_of_t<pair, pair<int, string>> // true
 *   is_instantiation_of_t<vector, pair<int, string>> // false
 */
template <template <class...> class Template, class T>
struct is_instantiation_of : std::false_type {};
template <template <class...> class Template, class... Args>
struct is_instantiation_of<Template, Template<Args...>> : std::true_type {};
template <template <class...> class Template, class T>
using is_instantiation_of_t = typename is_instantiation_of<Template, T>::type;

namespace detail {
/**
 * strip_class: helper to remove the class type from pointers to `operator()`.
 */

template <typename T>
struct strip_class {};
template <typename Class, typename Result, typename... Args>
struct strip_class<Result (Class::*)(Args...)> {
  using type = Result(Args...);
};
template <typename Class, typename Result, typename... Args>
struct strip_class<Result (Class::*)(Args...) const> {
  using type = Result(Args...);
};
template <typename T>
using strip_class_t = typename strip_class<T>::type;
} // namespace detail

/**
 * Evaluates to true_type, iff the given class is a Functor
 * (i.e. has a call operator with some set of arguments)
 */

template <class Functor, class Enable = void>
struct is_functor : std::false_type {};
template <class Functor>
struct is_functor<
    Functor,
    std::enable_if_t<is_function_type<
        detail::strip_class_t<decltype(&Functor::operator())>>::value>>
    : std::true_type {};

/**
 * lambda_is_stateless<T> is true iff the lambda type T is stateless
 * (i.e. does not have a closure).
 * Example:
 *  auto stateless_lambda = [] (int a) {return a;};
 *  lambda_is_stateless<decltype(stateless_lambda)> // true
 *  auto stateful_lambda = [&] (int a) {return a;};
 *  lambda_is_stateless<decltype(stateful_lambda)> // false
 */
namespace detail {
template <class LambdaType, class FuncType>
struct is_stateless_lambda__ final {
  static_assert(
      !std::is_same_v<LambdaType, LambdaType>,
      "Base case shouldn't be hit");
};
// implementation idea: According to the C++ standard, stateless lambdas are
// convertible to function pointers
template <class LambdaType, class C, class Result, class... Args>
struct is_stateless_lambda__<LambdaType, Result (C::*)(Args...) const>
    : std::is_convertible<LambdaType, Result (*)(Args...)> {};
template <class LambdaType, class C, class Result, class... Args>
struct is_stateless_lambda__<LambdaType, Result (C::*)(Args...)>
    : std::is_convertible<LambdaType, Result (*)(Args...)> {};

// case where LambdaType is not even a functor
template <class LambdaType, class Enable = void>
struct is_stateless_lambda_ final : std::false_type {};
// case where LambdaType is a functor
template <class LambdaType>
struct is_stateless_lambda_<
    LambdaType,
    std::enable_if_t<is_functor<LambdaType>::value>>
    : is_stateless_lambda__<LambdaType, decltype(&LambdaType::operator())> {};
} // namespace detail
template <class T>
using is_stateless_lambda = detail::is_stateless_lambda_<std::decay_t<T>>;

/**
 * is_type_condition<C> is true_type iff C<...> is a type trait representing a
 * condition (i.e. has a constexpr static bool ::value member) Example:
 *   is_type_condition<std::is_reference>  // true
 */
template <template <class> class C, class Enable = void>
struct is_type_condition : std::false_type {};
template <template <class> class C>
struct is_type_condition<
    C,
    std::enable_if_t<
        std::is_same_v<bool, std::remove_cv_t<decltype(C<int>::value)>>>>
    : std::true_type {};

/**
 * is_fundamental<T> is true_type iff the lambda type T is a fundamental type
 * (that is, arithmetic type, void, or nullptr_t). Example: is_fundamental<int>
 * // true We define it here to resolve a MSVC bug. See
 * https://github.com/pytorch/pytorch/issues/30932 for details.
 */
template <class T>
struct is_fundamental : std::is_fundamental<T> {};
} // namespace c10::guts
#include <torch/headeronly/util/TypeTraits.h>
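A compile-time sketch (mine, same namespace assumption as the updated test_typetraits.cpp later in this diff) of the traits that moved behind the headeronly include:

```cpp
#include <torch/headeronly/util/TypeTraits.h>
#include <vector>

using namespace torch::headeronly::guts;

static_assert(is_equality_comparable<int>::value);
static_assert(is_function_type<int(float)>::value);
static_assert(!is_function_type<int>::value);
static_assert(is_instantiation_of<std::vector, std::vector<int>>::value);

int main() {
  auto stateless = [](int a) { return a; };  // no captures, hence stateless
  static_assert(is_stateless_lambda<decltype(stateless)>::value);
  return 0;
}
```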
@@ -926,15 +926,14 @@ class DeviceCachingAllocator {
          (release_cached_blocks() && alloc_block(params, true));
    }
    if (!block_found) {
      c10::xpu::DeviceProp device_prop;
      c10::xpu::get_device_properties(&device_prop, device);
      auto device_total = device_prop.global_mem_size;
      const auto& raw_device = c10::xpu::get_raw_device(device);
      const auto device_total =
          raw_device.get_info<sycl::info::device::global_mem_size>();
      // Estimate the available device memory when the SYCL runtime does not
      // support the corresponding aspect (ext_intel_free_memory).
      size_t device_free = device_prop.global_mem_size -
      size_t device_free = device_total -
          stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)]
              .current;
      auto& raw_device = c10::xpu::get_raw_device(device);
      // TODO: Remove the aspect check once the SYCL runtime bug is fixed on
      // affected devices.
      if (raw_device.has(sycl::aspect::ext_intel_free_memory)) {
@@ -1052,21 +1051,37 @@ class DeviceCachingAllocator {
    }
  }

  std::pair<size_t, size_t> getMemoryInfo() {
    const auto& device = c10::xpu::get_raw_device(device_index);
    const size_t total = device.get_info<sycl::info::device::global_mem_size>();
    TORCH_CHECK(
        device.has(sycl::aspect::ext_intel_free_memory),
        "The device (",
        device.get_info<sycl::info::device::name>(),
        ") doesn't support querying the available free memory. ",
        "You can file an issue at https://github.com/pytorch/pytorch/issues ",
        "to help us prioritize its implementation.");
    const size_t free =
        device.get_info<sycl::ext::intel::info::device::free_memory>();
    return {free, total};
  }

  double getMemoryFraction() {
    if (!set_fraction) {
      return 1.0;
    }

    c10::xpu::DeviceProp device_prop;
    c10::xpu::get_device_properties(&device_prop, device_index);
    const auto device_total =
        xpu::get_raw_device(device_index)
            .get_info<sycl::info::device::global_mem_size>();
    return static_cast<double>(allowed_memory_maximum) /
        static_cast<double>(device_prop.global_mem_size);
        static_cast<double>(device_total);
  }

  void setMemoryFraction(double fraction) {
    c10::xpu::DeviceProp device_prop;
    c10::xpu::get_device_properties(&device_prop, device_index);
    auto device_total = device_prop.global_mem_size;
    const auto device_total =
        xpu::get_raw_device(device_index)
            .get_info<sycl::info::device::global_mem_size>();
    allowed_memory_maximum = static_cast<size_t>(fraction * device_total);
    set_fraction = true;
  }
@@ -1240,6 +1255,11 @@ class XPUAllocator : public DeviceAllocator {
        c10::xpu::get_raw_device(dev_to_access));
  }

  std::pair<size_t, size_t> getMemoryInfo(DeviceIndex device) override {
    assertValidDevice(device);
    return device_allocators[device]->getMemoryInfo();
  }

  double getMemoryFraction(DeviceIndex device) {
    assertValidDevice(device);
    return device_allocators[device]->getMemoryFraction();
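The SYCL device queries these XPU hunks rely on can be exercised standalone. A hedged sketch (assumes a SYCL 2020 toolchain; not the PR's code):

```cpp
#include <sycl/sycl.hpp>
#include <cstdio>

int main() {
  sycl::device dev{sycl::default_selector_v};
  const size_t total = dev.get_info<sycl::info::device::global_mem_size>();
  size_t free = total;  // fallback when ext_intel_free_memory is unsupported
  if (dev.has(sycl::aspect::ext_intel_free_memory)) {
    free = dev.get_info<sycl::ext::intel::info::device::free_memory>();
  }
  std::printf("free=%zu total=%zu\n", free, total);
  return 0;
}
```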
@@ -478,6 +478,7 @@ function(torch_update_find_cuda_flags)
endfunction()

include(CheckCXXCompilerFlag)
include(CheckCCompilerFlag)
include(CheckLinkerFlag)

##############################################################################
@@ -501,6 +502,24 @@ function(append_cxx_flag_if_supported flag outputvar)
  endif()
endfunction()

function(append_c_flag_if_supported flag outputvar)
  string(TOUPPER "HAS${flag}" _FLAG_NAME)
  string(REGEX REPLACE "[=-]" "_" _FLAG_NAME "${_FLAG_NAME}")

  # GCC silences unknown -Wno-XXX flags, so test the corresponding -WXXX.
  if(CMAKE_C_COMPILER_ID STREQUAL "GNU")
    string(REGEX REPLACE "^Wno-" "W" new_flag "${flag}")
  else()
    set(new_flag "${flag}")
  endif()

  check_c_compiler_flag("${new_flag}" ${_FLAG_NAME})
  if(${_FLAG_NAME})
    string(APPEND ${outputvar} " ${flag}")
    set(${outputvar} "${${outputvar}}" PARENT_SCOPE)
  endif()
endfunction()

function(target_compile_options_if_supported target flag)
  set(_compile_options "")
  append_cxx_flag_if_supported("${flag}" _compile_options)
@@ -40,6 +40,7 @@
    :nosignatures:

    empty_cache
    get_memory_info
    max_memory_allocated
    max_memory_reserved
    memory_allocated
@@ -382,20 +382,6 @@ coverage_ignore_functions = [
    # torch.ao.quantization.backend_config.tensorrt
    "get_tensorrt_backend_config",
    "get_tensorrt_backend_config_dict",
    # torch.ao.quantization.backend_config.utils
    "entry_to_pretty_str",
    "get_fused_module_classes",
    "get_fuser_method_mapping",
    "get_fusion_pattern_to_extra_inputs_getter",
    "get_fusion_pattern_to_root_node_getter",
    "get_module_to_qat_module",
    "get_pattern_to_dtype_configs",
    "get_pattern_to_input_type_to_index",
    "get_qat_module_classes",
    "get_root_module_to_quantized_reference_module",
    "pattern_to_human_readable",
    "remove_boolean_dispatch_from_name",
    # torch.ao.quantization.backend_config.x86
    "get_x86_backend_config",
    # torch.ao.quantization.fuse_modules
    "fuse_known_modules",
@@ -426,25 +412,6 @@ coverage_ignore_functions = [
    "insert_observers_for_model",
    "prepare",
    "propagate_dtypes_for_known_nodes",
    # torch.ao.quantization.fx.utils
    "all_node_args_except_first",
    "all_node_args_have_no_tensors",
    "assert_and_get_unique_device",
    "collect_producer_nodes",
    "create_getattr_from_value",
    "create_node_from_old_node_preserve_meta",
    "get_custom_module_class_keys",
    "get_linear_prepack_op_for_dtype",
    "get_new_attr_name_with_prefix",
    "get_non_observable_arg_indexes_and_types",
    "get_qconv_prepack_op",
    "get_skipped_module_name_and_classes",
    "graph_module_from_producer_nodes",
    "maybe_get_next_module",
    "node_arg_is_bias",
    "node_arg_is_weight",
    "return_arg_list",
    # torch.ao.quantization.pt2e.graph_utils
    "bfs_trace_with_node_process",
    "find_sequential_partitions",
    "get_equivalent_types",
@@ -860,80 +827,10 @@ coverage_ignore_functions = [
    "get_latency_of_one_partition",
    "get_latency_of_partitioned_graph",
    "get_partition_to_latency_mapping",
    # torch.fx.experimental.proxy_tensor
    "decompose",
    "disable_autocast_cache",
    "disable_proxy_modes_tracing",
    "dispatch_trace",
    "extract_val",
    "fake_signature",
    "fetch_sym_proxy",
    "fetch_object_proxy",
    "get_innermost_proxy_mode",
    "get_isolated_graphmodule",
    "get_proxy_slot",
    "get_torch_dispatch_modes",
    "has_proxy_slot",
    "is_sym_node",
    "maybe_handle_decomp",
    "proxy_call",
    "set_meta",
    "set_original_aten_op",
    "set_proxy_slot",
    "snapshot_fake",
    "thunkify",
    "track_tensor",
    "track_tensor_tree",
    "wrap_key",
    "wrapper_and_args_for_make_fx",
    # torch.fx.experimental.recording
    "record_shapeenv_event",
    "replay_shape_env_events",
    "shape_env_check_state_equal",
    # torch.fx.experimental.sym_node
    "ceil_impl",
    "floor_ceil_helper",
    "floor_impl",
    "method_to_operator",
    "sympy_is_channels_last_contiguous_2d",
    "sympy_is_channels_last_contiguous_3d",
    "sympy_is_channels_last_strides_2d",
    "sympy_is_channels_last_strides_3d",
    "sympy_is_channels_last_strides_generic",
    "sympy_is_contiguous",
    "sympy_is_contiguous_generic",
    "to_node",
    "wrap_node",
    "sym_sqrt",
    # torch.fx.experimental.symbolic_shapes
    "bind_symbols",
    "cast_symbool_to_symint_guardless",
    "create_contiguous",
    "error",
    "eval_guards",
    "eval_is_non_overlapping_and_dense",
    "expect_true",
    "find_symbol_binding_fx_nodes",
    "free_symbols",
    "free_unbacked_symbols",
    "fx_placeholder_targets",
    "fx_placeholder_vals",
    "guard_bool",
    "guard_float",
    "guard_int",
    "guard_scalar",
    "has_hint",
    "has_symbolic_sizes_strides",
    "is_channels_last_contiguous_2d",
    "is_channels_last_contiguous_3d",
    "is_channels_last_strides_2d",
    "is_channels_last_strides_3d",
    "is_contiguous",
    "is_non_overlapping_and_dense_indicator",
    "is_nested_int",
    "is_symbol_binding_fx_node",
    "is_symbolic",
    # torch.fx.experimental.unification.core
    "reify",
    # torch.fx.experimental.unification.match
    "edge",
@@ -971,24 +868,6 @@ coverage_ignore_functions = [
    "reverse_dict",
    # torch.fx.experimental.unification.multipledispatch.variadic
    "isvariadic",
    # torch.fx.experimental.unification.unification_tools
    "assoc",
    "assoc_in",
    "dissoc",
    "first",
    "get_in",
    "getter",
    "groupby",
    "itemfilter",
    "itemmap",
    "keyfilter",
    "keymap",
    "merge",
    "merge_with",
    "update_in",
    "valfilter",
    "valmap",
    # torch.fx.experimental.unification.utils
    "freeze",
    "hashable",
    "raises",
@@ -12,6 +12,37 @@ These APIs are experimental and subject to change without notice.
.. autoclass:: torch.fx.experimental.sym_node.DynamicInt
```

## torch.fx.experimental.sym_node

```{eval-rst}
.. currentmodule:: torch.fx.experimental.sym_node
```

```{eval-rst}
.. automodule:: torch.fx.experimental.sym_node
```

```{eval-rst}
.. autosummary::
    :toctree: generated
    :nosignatures:

    is_channels_last_contiguous_2d
    is_channels_last_contiguous_3d
    is_channels_last_strides_2d
    is_channels_last_strides_3d
    is_contiguous
    is_non_overlapping_and_dense_indicator
    method_to_operator
    sympy_is_channels_last_contiguous_2d
    sympy_is_channels_last_contiguous_3d
    sympy_is_channels_last_strides_2d
    sympy_is_channels_last_strides_3d
    sympy_is_channels_last_strides_generic
    sympy_is_contiguous
    sympy_is_contiguous_generic
```

## torch.fx.experimental.symbolic_shapes

```{eval-rst}
@@ -69,6 +100,25 @@ These APIs are experimental and subject to change without notice.
    rebind_unbacked
    resolve_unbacked_bindings
    is_accessor_node
    cast_symbool_to_symint_guardless
    create_contiguous
    error
    eval_guards
    eval_is_non_overlapping_and_dense
    find_symbol_binding_fx_nodes
    free_symbols
    free_unbacked_symbols
    fx_placeholder_targets
    fx_placeholder_vals
    guard_bool
    guard_float
    guard_int
    guard_scalar
    has_hint
    has_symbolic_sizes_strides
    is_nested_int
    is_symbol_binding_fx_node
    is_symbolic
```

## torch.fx.experimental.proxy_tensor
@@ -91,4 +141,46 @@ These APIs are experimental and subject to change without notice.
    get_proxy_mode
    maybe_enable_thunkify
    maybe_disable_thunkify
    decompose
    disable_autocast_cache
    disable_proxy_modes_tracing
    extract_val
    fake_signature
    fetch_object_proxy
    fetch_sym_proxy
    has_proxy_slot
    is_sym_node
    maybe_handle_decomp
    proxy_call
    set_meta
    set_original_aten_op
    set_proxy_slot
    snapshot_fake
```

## torch.fx.experimental.unification.unification_tools

```{eval-rst}
.. currentmodule:: torch.fx.experimental.unification.unification_tools
```

```{eval-rst}
.. automodule:: torch.fx.experimental.unification.unification_tools
```

```{eval-rst}
.. autosummary::
    :toctree: generated
    :nosignatures:

    assoc
    assoc_in
    dissoc
    first
    keyfilter
    keymap
    merge
    merge_with
    update_in
    valfilter
    valmap
@@ -1134,7 +1134,6 @@ The set of leaf modules can be customized by overriding
.. py:module:: torch.fx.experimental.refinement_types
.. py:module:: torch.fx.experimental.rewriter
.. py:module:: torch.fx.experimental.schema_type_annotation
.. py:module:: torch.fx.experimental.sym_node
.. py:module:: torch.fx.experimental.unification.core
.. py:module:: torch.fx.experimental.unification.dispatch
.. py:module:: torch.fx.experimental.unification.match
@@ -1144,7 +1143,6 @@ The set of leaf modules can be customized by overriding
.. py:module:: torch.fx.experimental.unification.multipledispatch.dispatcher
.. py:module:: torch.fx.experimental.unification.multipledispatch.utils
.. py:module:: torch.fx.experimental.unification.multipledispatch.variadic
.. py:module:: torch.fx.experimental.unification.unification_tools
.. py:module:: torch.fx.experimental.unification.utils
.. py:module:: torch.fx.experimental.unification.variable
.. py:module:: torch.fx.experimental.unify_refinements
docs/source/mtia.mtia_graph.md (new file, 21 lines)
@@ -0,0 +1,21 @@
# torch.mtia.mtia_graph

The MTIA backend is implemented out of tree; only the interfaces are defined here.

```{eval-rst}
.. automodule:: torch.mtia.mtia_graph
```

```{eval-rst}
.. currentmodule:: torch.mtia.mtia_graph
```

```{eval-rst}
.. autoclass:: MTIAGraph
    :members:
```

```{eval-rst}
.. autoclass:: graph
    :members:
```
@@ -29,6 +29,7 @@ mps
xpu
mtia
mtia.memory
mtia.mtia_graph
meta
torch.backends <backends>
torch.export <export>
@@ -134,6 +134,23 @@ Quantization to work with this as well.
   ObservationType
```

## torch.ao.quantization.backend_config.utils

```{eval-rst}
.. currentmodule:: torch.ao.quantization.backend_config.utils
```

```{eval-rst}
.. autosummary::
    :toctree: generated
    :nosignatures:
    :template: classtemplate.rst

    entry_to_pretty_str
    pattern_to_human_readable
    remove_boolean_dispatch_from_name

```

## torch.ao.quantization.fx.custom_config

This module contains a few CustomConfig classes that are used in both eager mode and FX graph mode quantization
@@ -154,6 +171,30 @@ This module contains a few CustomConfig classes that are used in both eager mode a
   StandaloneModuleConfigEntry
```

## torch.ao.quantization.fx.utils

```{eval-rst}
.. currentmodule:: torch.ao.quantization.fx.utils
```

```{eval-rst}
.. autosummary::
    :toctree: generated
    :nosignatures:
    :template: classtemplate.rst

    all_node_args_except_first
    all_node_args_have_no_tensors
    collect_producer_nodes
    create_getattr_from_value
    create_node_from_old_node_preserve_meta
    graph_module_from_producer_nodes
    maybe_get_next_module
    node_arg_is_bias
    node_arg_is_weight
    return_arg_list
```

## torch.ao.quantization.quantizer

```{eval-rst}
@@ -260,6 +260,7 @@ select = [
    "TRY401", # verbose-log-message
    "UP",
    "YTT",
    "S101",
]

[tool.ruff.lint.pyupgrade]
@@ -339,6 +340,39 @@ keep-runtime-typing = true
"tools/linter/**" = [
    "LOG015" # please fix
]
"benchmarks/**" = [
    "S101"
]
"test/**" = [
    "S101"
]
"torchgen/**" = [
    "S101"
]
"torch/**" = [
    "S101"
]
"tools/**" = [
    "S101"
]
"setup.py" = [
    "S101"
]
"functorch/**" = [
    "S101"
]
"docs/**" = [
    "S101"
]
"android/**" = [
    "S101"
]
".github/**" = [
    "S101"
]
".ci/**" = [
    "S101"
]

[tool.codespell]
ignore-words = "tools/linter/dictionary.txt"
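For context, `S101` is the flake8-bandit rule that flags bare `assert` statements; the per-file ignores above exempt the directories where asserts are idiomatic. A minimal illustration of what the rule matches:

```python
# S101 fires on statements like the one below, which is why test/**,
# benchmarks/**, and the other paths above opt out of the rule.
def test_add():
    assert 1 + 1 == 2  # flagged by S101 outside the ignored paths
```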
@@ -10,7 +10,7 @@ tp2_dir="$top_dir/third_party"
pip install ninja

# Install onnx
pip install --no-use-pep517 -e "$tp2_dir/onnx"
pip install -e "$tp2_dir/onnx"

# Install caffe2 and pytorch
pip install -r "$top_dir/caffe2/requirements.txt"
@@ -17,8 +17,11 @@ set(AOTI_ABI_CHECK_TEST_SRCS
  ${AOTI_ABI_CHECK_TEST_ROOT}/test_headeronlyarrayref.cpp
  ${AOTI_ABI_CHECK_TEST_ROOT}/test_macros.cpp
  ${AOTI_ABI_CHECK_TEST_ROOT}/test_math.cpp
  ${AOTI_ABI_CHECK_TEST_ROOT}/test_metaprogramming.cpp
  ${AOTI_ABI_CHECK_TEST_ROOT}/test_rand.cpp
  ${AOTI_ABI_CHECK_TEST_ROOT}/test_scalartype.cpp
  ${AOTI_ABI_CHECK_TEST_ROOT}/test_typelist.cpp
  ${AOTI_ABI_CHECK_TEST_ROOT}/test_typetraits.cpp
  ${AOTI_ABI_CHECK_TEST_ROOT}/test_vec.cpp
  ${AOTI_ABI_CHECK_TEST_ROOT}/test_vec_half.cpp
)
@@ -1,9 +1,8 @@
#include <c10/test/util/Macros.h>
#include <c10/util/Metaprogramming.h>
#include <gtest/gtest.h>
#include <torch/headeronly/util/Metaprogramming.h>
#include <cstdlib>

using namespace c10::guts;
using namespace torch::headeronly::guts;

// NOLINTBEGIN(modernize*, cppcoreguidelines-special-member-functions)
namespace {
@@ -65,6 +64,15 @@ static_assert(
        typename make_function_traits_t<void, typelist::typelist<int, float>>::
            func_type>::value,
    "");

struct Functor final {
  std::string operator()(int64_t a, float b) const;
};
static_assert(
    std::is_same<
        std::string(int64_t, float),
        typename infer_function_traits_t<Functor>::func_type>::value,
    "");
} // namespace test_function_traits

struct MovableOnly {
@@ -1,8 +1,8 @@
#include <c10/util/TypeList.h>
#include <gtest/gtest.h>
#include <torch/headeronly/util/TypeList.h>
#include <memory>

using namespace c10::guts::typelist;
using namespace torch::headeronly::guts::typelist;
// NOLINTBEGIN(modernize-unary-static-assert)
namespace test_size {
class MyClass {};
@@ -1,7 +1,7 @@
#include <c10/util/TypeTraits.h>
#include <gtest/gtest.h>
#include <torch/headeronly/util/TypeTraits.h>

using namespace c10::guts;
using namespace torch::headeronly::guts;

// NOLINTBEGIN(modernize-unary-static-assert)
namespace {
@@ -1,5 +1,6 @@
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/stable/accelerator.h>
#include <torch/csrc/stable/device.h>
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/ops.h>
@@ -528,6 +529,149 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
  m.impl("make_tensor_clones_and_call_foreach", &boxed_make_tensor_clones_and_call_foreach);
}

// Test functions for the torch::stable::Tensor device method

torch::stable::Device test_tensor_device(torch::stable::Tensor tensor) {
  return tensor.device();
}

void boxed_test_tensor_device(
    StableIValue* stack,
    uint64_t num_args,
    uint64_t num_outputs) {
  torch::stable::Device res = test_tensor_device(
      torch::stable::detail::to<torch::stable::Tensor>(stack[0]));
  stack[0] = torch::stable::detail::from(res);
}

// Test functions for torch::stable::Device

torch::stable::Device test_device_constructor(
    bool is_cuda,
    torch::stable::DeviceIndex index,
    bool use_str) {
  using torch::stable::Device;
  using torch::stable::DeviceType;

  if (use_str) {
    std::string device_str;
    if (is_cuda) {
      device_str = "cuda:" + std::to_string(index);
    } else {
      device_str = "cpu";
    }
    return Device(device_str);
  } else {
    if (is_cuda) {
      return Device(DeviceType::CUDA, index);
    } else {
      return Device(DeviceType::CPU);
    }
  }
}

void boxed_test_device_constructor(
    StableIValue* stack,
    uint64_t num_args,
    uint64_t num_outputs) {
  torch::stable::Device res = test_device_constructor(
      torch::stable::detail::to<bool>(stack[0]),
      torch::stable::detail::to<torch::stable::DeviceIndex>(stack[1]),
      torch::stable::detail::to<bool>(stack[2]));
  stack[0] = torch::stable::detail::from(res);
}

bool test_device_equality(torch::stable::Device d1, torch::stable::Device d2) {
  return d1 == d2;
}

void boxed_test_device_equality(
    StableIValue* stack,
    uint64_t num_args,
    uint64_t num_outputs) {
  bool res = test_device_equality(
      torch::stable::detail::to<torch::stable::Device>(stack[0]),
      torch::stable::detail::to<torch::stable::Device>(stack[1]));
  stack[0] = torch::stable::detail::from(res);
}

torch::stable::Device test_device_set_index(
    torch::stable::Device device,
    torch::stable::DeviceIndex index) {
  device.set_index(index);
  return device;
}

void boxed_test_device_set_index(
    StableIValue* stack,
    uint64_t num_args,
    uint64_t num_outputs) {
  torch::stable::Device res = test_device_set_index(
      torch::stable::detail::to<torch::stable::Device>(stack[0]),
      torch::stable::detail::to<torch::stable::DeviceIndex>(stack[1]));
  stack[0] = torch::stable::detail::from(res);
}

torch::stable::DeviceIndex test_device_index(torch::stable::Device device) {
  return device.index();
}

void boxed_test_device_index(
    StableIValue* stack,
    uint64_t num_args,
    uint64_t num_outputs) {
  torch::stable::DeviceIndex res = test_device_index(
      torch::stable::detail::to<torch::stable::Device>(stack[0]));
  stack[0] = torch::stable::detail::from(res);
}

bool test_device_is_cuda(torch::stable::Device device) {
  return device.is_cuda();
}

void boxed_test_device_is_cuda(
    StableIValue* stack,
    uint64_t num_args,
    uint64_t num_outputs) {
  bool res = test_device_is_cuda(
      torch::stable::detail::to<torch::stable::Device>(stack[0]));
  stack[0] = torch::stable::detail::from(res);
}

bool test_device_is_cpu(torch::stable::Device device) {
  return device.is_cpu();
}

void boxed_test_device_is_cpu(
    StableIValue* stack,
    uint64_t num_args,
    uint64_t num_outputs) {
  bool res = test_device_is_cpu(
      torch::stable::detail::to<torch::stable::Device>(stack[0]));
  stack[0] = torch::stable::detail::from(res);
}

STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
  m.def("test_tensor_device(Tensor t) -> Device");
  m.def(
      "test_device_constructor(bool is_cuda, DeviceIndex index, bool use_str) -> Device");
  m.def("test_device_equality(Device d1, Device d2) -> bool");
  m.def("test_device_set_index(Device device, DeviceIndex index) -> Device");
  m.def("test_device_index(Device device) -> DeviceIndex");
  m.def("test_device_is_cuda(Device device) -> bool");
  m.def("test_device_is_cpu(Device device) -> bool");
}

STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
  m.impl("test_tensor_device", &boxed_test_tensor_device);
  m.impl("test_device_constructor", &boxed_test_device_constructor);
  m.impl("test_device_equality", &boxed_test_device_equality);
  m.impl("test_device_set_index", &boxed_test_device_set_index);
  m.impl("test_device_index", &boxed_test_device_index);
  m.impl("test_device_is_cuda", &boxed_test_device_is_cuda);
  m.impl("test_device_is_cpu", &boxed_test_device_is_cpu);
}

// Test functions for torch::stable::accelerator APIs

#ifdef LAE_USE_CUDA
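The schemas registered in the fragment above surface in Python under `torch.ops`; a short sketch of a round trip, mirroring the wrappers and tests that appear later in this diff:

```python
import torch

# Calls the C++ kernels registered above through the dispatcher.
dev = torch.ops.libtorch_agnostic.test_device_constructor.default(True, 1, False)
assert dev == torch.device("cuda:1")
assert torch.ops.libtorch_agnostic.test_device_is_cuda.default(dev)
```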
@@ -617,3 +761,66 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
}

#endif // LAE_USE_CUDA

Tensor test_parallel_for(int64_t size, int64_t grain_size) {
  AtenTensorHandle tensor_handle;
  int64_t stride = 1;

  aoti_torch_empty_strided(
      1,
      &size,
      &stride,
      aoti_torch_dtype_int64(),
      aoti_torch_device_type_cpu(),
      0,
      &tensor_handle);

  Tensor tensor(tensor_handle);
  int64_t* data_ptr = reinterpret_cast<int64_t*>(tensor.data_ptr());

  torch::stable::zero_(tensor);

  // Use parallel_for to fill each element with its index.
  // If a parallel path is taken, the thread id is encoded in the upper 32 bits.
  torch::stable::parallel_for(
      0, size, grain_size, [data_ptr](int64_t begin, int64_t end) {
        for (auto i = begin; i < end; i++) {
          STD_TORCH_CHECK(i <= UINT32_MAX);
          uint32_t thread_id;
          torch_get_thread_idx(&thread_id);
          data_ptr[i] = i | (static_cast<int64_t>(thread_id) << 32);
        }
      });

  return tensor;
}

void boxed_test_parallel_for(
    StableIValue* stack,
    uint64_t num_args,
    uint64_t num_outputs) {
  Tensor res = test_parallel_for(to<int64_t>(stack[0]), to<int64_t>(stack[1]));
  stack[0] = from(res);
}

uint32_t test_get_num_threads() {
  return torch::stable::get_num_threads();
}

void boxed_test_get_num_threads(
    StableIValue* stack,
    uint64_t num_args,
    uint64_t num_outputs) {
  uint32_t res = test_get_num_threads();
  stack[0] = from(res);
}

STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
  m.def("test_parallel_for(int size, int grain_size) -> Tensor");
  m.def("test_get_num_threads() -> int");
}

STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
  m.impl("test_parallel_for", &boxed_test_parallel_for);
  m.impl("test_get_num_threads", &boxed_test_get_num_threads);
}
@@ -215,6 +215,18 @@ def test_default_constructor(defined) -> bool:
    return torch.ops.libtorch_agnostic.test_default_constructor.default(defined)


def test_tensor_device(t):
    """
    Tests the Tensor.device() method.

    Args:
        t: Tensor - tensor to get the device from

    Returns: Device - device of the tensor
    """
    return torch.ops.libtorch_agnostic.test_tensor_device.default(t)


def my_pad(t) -> Tensor:
    """
    Pads the input tensor with hardcoded padding parameters.
@@ -375,3 +387,103 @@ def make_tensor_clones_and_call_foreach(t1, t2) -> list[Tensor]:
    return torch.ops.libtorch_agnostic.make_tensor_clones_and_call_foreach.default(
        t1, t2
    )


def test_device_constructor(is_cuda, index, use_str):
    """
    Tests creating a Device from a DeviceType and index, or from a string.

    Args:
        is_cuda: bool - if True, creates a CUDA device; if False, creates a CPU device
        index: int - device index
        use_str: bool - if True, constructs from a string; if False, constructs from a DeviceType

    Returns: Device - a device with the specified type and index
    """
    return torch.ops.libtorch_agnostic.test_device_constructor.default(
        is_cuda, index, use_str
    )


def test_device_equality(d1, d2) -> bool:
    """
    Tests the Device equality operator.

    Args:
        d1: Device - first device
        d2: Device - second device

    Returns: bool - True if the devices are equal
    """
    return torch.ops.libtorch_agnostic.test_device_equality.default(d1, d2)


def test_device_set_index(device, index):
    """
    Tests the Device.set_index() method.

    Args:
        device: Device - device to modify
        index: int - new device index

    Returns: Device - device with the updated index
    """
    return torch.ops.libtorch_agnostic.test_device_set_index.default(device, index)


def test_device_index(device) -> int:
    """
    Tests the Device.index() method.

    Args:
        device: Device - device to query

    Returns: int - device index
    """
    return torch.ops.libtorch_agnostic.test_device_index.default(device)


def test_device_is_cuda(device) -> bool:
    """
    Tests the Device.is_cuda() method.

    Args:
        device: Device - device to check

    Returns: bool - True if the device is CUDA
    """
    return torch.ops.libtorch_agnostic.test_device_is_cuda.default(device)


def test_device_is_cpu(device) -> bool:
    """
    Tests the Device.is_cpu() method.

    Args:
        device: Device - device to check

    Returns: bool - True if the device is CPU
    """
    return torch.ops.libtorch_agnostic.test_device_is_cpu.default(device)


def test_parallel_for(size, grain_size) -> Tensor:
    """
    Tests the parallel_for functionality by using it to fill a tensor with indices.

    Args:
        size: int - size of the tensor to create
        grain_size: int - grain size for parallel_for

    Returns: Tensor - a 1D int64 tensor where each element contains its index
        (if multiple threads are used, the thread id is encoded in the upper 32 bits)
    """
    return torch.ops.libtorch_agnostic.test_parallel_for.default(size, grain_size)


def test_get_num_threads() -> int:
    """
    Tests the get_num_threads functionality by returning the number of threads
    for the parallel backend.

    Returns: int - the number of threads for the parallel backend
    """
    return torch.ops.libtorch_agnostic.test_get_num_threads.default()
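Since each element packs the thread id into the upper 32 bits and the index into the lower 32, the result can be unpacked as below (the test later in this diff does exactly this):

```python
import torch

result = torch.ops.libtorch_agnostic.test_parallel_for.default(100, 10)
thread_ids = torch.bitwise_right_shift(result, 32)  # upper 32 bits: thread id
values = torch.bitwise_and(result, 0xFFFFFFFF)      # lower 32 bits: index
assert torch.equal(values, torch.arange(100, dtype=torch.int64))
```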
@@ -418,6 +418,113 @@ if not IS_WINDOWS:
            self.assertEqual(result[0], t1 * t1)
            self.assertEqual(result[1], t2 * t2)

        @onlyCUDA
        def test_device(self, device):
            import libtorch_agnostic

            cuda_device = libtorch_agnostic.ops.test_device_constructor(
                is_cuda=True, index=1, use_str=False
            )
            self.assertEqual(cuda_device, torch.device("cuda:1"))
            cuda_device = libtorch_agnostic.ops.test_device_constructor(
                is_cuda=True, index=1, use_str=True
            )
            self.assertEqual(cuda_device, torch.device("cuda:1"))

            self.assertEqual(libtorch_agnostic.ops.test_device_index(cuda_device), 1)
            self.assertTrue(
                libtorch_agnostic.ops.test_device_equality(
                    cuda_device, torch.device("cuda:1")
                )
            )
            self.assertFalse(
                libtorch_agnostic.ops.test_device_equality(
                    cuda_device, torch.device("cuda:0")
                )
            )
            self.assertFalse(libtorch_agnostic.ops.test_device_is_cpu(cuda_device))
            self.assertTrue(libtorch_agnostic.ops.test_device_is_cuda(cuda_device))

            cuda_0_device = libtorch_agnostic.ops.test_device_set_index(cuda_device, 0)
            self.assertEqual(cuda_0_device, torch.device("cuda:0"))

            cpu_device = libtorch_agnostic.ops.test_device_constructor(False, 0, False)
            self.assertEqual(cpu_device, torch.device("cpu"))
            self.assertTrue(
                libtorch_agnostic.ops.test_device_equality(
                    cpu_device, torch.device("cpu")
                )
            )
            self.assertTrue(libtorch_agnostic.ops.test_device_is_cpu(cpu_device))
            self.assertFalse(libtorch_agnostic.ops.test_device_is_cuda(cpu_device))
            self.assertFalse(
                libtorch_agnostic.ops.test_device_equality(cpu_device, cuda_device)
            )

            with self.assertRaisesRegex(
                RuntimeError, "Device index 129 is out of range for int8_t"
            ):
                libtorch_agnostic.ops.test_device_constructor(
                    is_cuda=True, index=129, use_str=False
                )

            with self.assertRaisesRegex(
                RuntimeError, "Device index 129 is out of range for int8_t"
            ):
                libtorch_agnostic.ops.test_device_set_index(cuda_device, 129)

        @onlyCUDA
        @deviceCountAtLeast(2)
        def test_tensor_device(self, device):
            import libtorch_agnostic

            t = torch.randn(2, 3)
            self.assertEqual(libtorch_agnostic.ops.test_tensor_device(t), t.device)

            t_cuda = torch.randn(2, 3, device="cuda")
            self.assertEqual(
                libtorch_agnostic.ops.test_tensor_device(t_cuda), t_cuda.device
            )

            t_cuda_1 = torch.randn(2, 3, device="cuda:1")
            self.assertEqual(
                libtorch_agnostic.ops.test_tensor_device(t_cuda_1), t_cuda_1.device
            )

        @onlyCPU
        # TODO: Debug this:
        # Dynamo failed to run FX node with fake tensors:
        # call_function libtorch_agnostic.test_parallel_for.default(*(100, 10), **{}):
        # got RuntimeError('libtorch_agnostic::test_parallel_for() expected at most
        # 2 argument(s) but received 3 argument(s).
        # Declaration: libtorch_agnostic::test_parallel_for(int size, int grain_size) -> Tensor')
        @xfailIfTorchDynamo
        def test_parallel_for(self, device):
            import libtorch_agnostic

            num_threads = torch.get_num_threads()
            size = 100
            grain_size = 10
            expected_num_threads_used = min(
                (size + grain_size - 1) // grain_size, num_threads
            )

            result = libtorch_agnostic.ops.test_parallel_for(size, grain_size)
            result_thread_ids = torch.unique(torch.bitwise_right_shift(result, 32))
            result_values = torch.bitwise_and(result, 0xFFFFFFFF)
            expected = torch.arange(size, dtype=torch.int64)

            self.assertEqual(result_values, expected)
            self.assertEqual(result_thread_ids, torch.arange(expected_num_threads_used))

        @onlyCPU
        def test_get_num_threads(self, device):
            import libtorch_agnostic

            num_threads = libtorch_agnostic.ops.test_get_num_threads()
            expected_num_threads = torch.get_num_threads()
            self.assertEqual(num_threads, expected_num_threads)

    instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None)

if __name__ == "__main__":
@@ -140,6 +140,11 @@ static void initDeviceStreamState(DeviceIndex device_index) {
static void initOpenRegStreamsOnce() {
  c10::call_once(init_flag, initGlobalStreamState);

  for (const auto i : c10::irange(num_devices)) {
    c10::call_once(
        device_flags[i], initDeviceStreamState, static_cast<DeviceIndex>(i));
  }

  if (current_streams) {
    return;
  }
@@ -202,8 +207,6 @@ OpenRegStream getStreamFromPool(const int priority, DeviceIndex device_index) {
  if (device_index == -1) {
    device_index = current_device();
  }
  c10::call_once(
      device_flags[device_index], initDeviceStreamState, device_index);
  auto pri_idx =
      std::clamp(priority, 0, max_compile_time_stream_priorities - 1);
  const auto idx = get_idx(priority_counters[device_index][pri_idx]);
@@ -180,6 +180,47 @@ class TestTrackerFullyShard1DTrainingCore(FSDPTest):
        del model
        del optim

    def _test_tracker_multihandler_hook(self):
        """Should run without KeyError."""

        class TestModule(nn.Module):
            def __init__(self, dim: int):
                super().__init__()
                self.norm1 = nn.RMSNorm(dim)
                self.output1 = nn.Linear(dim, dim)
                self.norm2 = nn.RMSNorm(dim)
                self.output2 = nn.Linear(dim, dim)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                x = self.norm1(x)
                x = self.output1(x)
                x = self.norm2(x)
                x = self.output2(x)
                return x

        gc.collect()
        torch.manual_seed(42)
        dev = torch.device(torch.accelerator.current_device_index())

        with torch.device(dev):
            model = TestModule(128)

        mesh = init_device_mesh(dev.type, (self.world_size,))
        fully_shard([model.norm1, model.output1], mesh=mesh)
        fully_shard([model.norm2, model.output2], mesh=mesh)
        fully_shard(model, mesh=mesh)

        fmt = FSDPMemTracker(model)

        with fmt:
            inp = torch.randn(16, 128, device=dev)
            y = model(inp)
            loss = y.sum()
            loss.backward()

        del inp
        del model


class TestTrackerFullyShard1DTrainingCompose(FSDPTest):
    @property
test/distributed/launcher/script_deviceid.py (new file, 44 lines)
@@ -0,0 +1,44 @@
|
||||
# Owner(s): ["oncall: r2p"]
|
||||
|
||||
# This is a helper script for
|
||||
# test_run.py::ElasticLaunchTest::test_virtual_local_rank. It prints out the
|
||||
# generated inductor output for a simple function.
|
||||
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch._inductor import codecache
|
||||
|
||||
|
||||
@torch.compile
|
||||
def myfn(x: torch.Tensor) -> torch.Tensor:
|
||||
return x + x
|
||||
|
||||
|
||||
dist.init_process_group(backend="nccl")
|
||||
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", "cuda:0"))
|
||||
torch.cuda.set_device(local_rank)
|
||||
|
||||
|
||||
def print_output_code(original_fn):
|
||||
def wrapper(msg, *args, **kwargs):
|
||||
# Check if this is the "Output code:" message
|
||||
if args and "Output code:" in msg:
|
||||
print(args[0])
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
x = torch.rand(2, 2, device="cuda")
|
||||
|
||||
with patch.object(
|
||||
codecache.output_code_log,
|
||||
"debug",
|
||||
side_effect=print_output_code(codecache.output_code_log.debug),
|
||||
):
|
||||
y = myfn(x)
|
||||
|
||||
dist.destroy_process_group()
|
||||
@@ -16,7 +16,7 @@ import sys
import tempfile
import uuid
from contextlib import closing, redirect_stderr, redirect_stdout
from unittest import mock
from unittest import mock, skipIf
from unittest.mock import MagicMock, Mock, patch

import torch.distributed.run as launch
@@ -28,6 +28,7 @@ from torch.distributed.elastic.utils.distributed import get_free_port
from torch.testing._internal.common_utils import (
    run_tests,
    skip_but_pass_in_sandcastle_if,
    TEST_CUDA,
    TEST_WITH_DEV_DBG_ASAN,
    TestCase,
)
@@ -677,6 +678,96 @@ class ElasticLaunchTest(TestCase):
        for i in range(nproc_per_node):
            self.assertTrue(f"[rank{i}]: creating " in captured_out.getvalue())

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    @skipIf(not TEST_CUDA, "requires CUDA")
    def test_virtual_local_rank(self):
        """
        Test that --virtual-local-rank ensures consistent device IDs across ranks.
        Without it, ranks may compile for different devices, leading to different code.
        """
        run_id = str(uuid.uuid4().int)
        nnodes = 1
        nproc_per_node = 2

        # Helper function to run and capture output
        def run_test(use_virtual_local_rank):
            args = [
                f"--nnodes={nnodes}",
                f"--nproc-per-node={nproc_per_node}",
                f"--rdzv-id={run_id}",
                "--monitor-interval=1",
                "--start-method=spawn",
                "--redirects=3",
                "--tee=3",
            ]
            if use_virtual_local_rank:
                args.append("--virtual-local-rank")

            args.append(path("script_deviceid.py"))

            captured_out = io.StringIO()
            captured_err = io.StringIO()
            with redirect_stdout(captured_out), redirect_stderr(captured_err):
                launch.main(args)

            return captured_out.getvalue()

        def split_ranks(output):
            default0 = []
            default1 = []
            for line in output.splitlines():
                if "cuda:" not in line:
                    continue
                if line.startswith("[default0]:"):
                    default0.append(line[11:])
                elif line.startswith("[default1]:"):
                    default1.append(line[11:])
            return default0, default1

        # First, run WITHOUT virtual-local-rank - outputs should differ
        output = run_test(use_virtual_local_rank=False)
        rank0, rank1 = split_ranks(output)

        # Verify we actually captured compiled code from both ranks
        self.assertGreater(
            len(rank0), 0, "Expected to capture compiled code from rank 0"
        )
        self.assertGreater(
            len(rank1), 0, "Expected to capture compiled code from rank 1"
        )

        # Without virtual-local-rank, the ranks should have DIFFERENT compiled code
        # because they see different device IDs (cuda:0 vs cuda:1)
        self.assertNotEqual(
            rank0,
            rank1,
            "Expected different compiled code without --virtual-local-rank",
        )

        # Now run WITH virtual-local-rank - outputs should be identical
        output = run_test(use_virtual_local_rank=True)
        rank0, rank1 = split_ranks(output)

        # Verify we actually captured compiled code from both ranks
        self.assertGreater(
            len(rank0),
            0,
            "Expected to capture compiled code from rank 0 with --virtual-local-rank",
        )
        self.assertGreater(
            len(rank1),
            0,
            "Expected to capture compiled code from rank 1 with --virtual-local-rank",
        )

        # With virtual-local-rank, both ranks should have IDENTICAL compiled code
        # because they both see cuda:0 during compilation
        self.assertEqual(
            rank0, rank1, "Expected identical compiled code with --virtual-local-rank"
        )


if __name__ == "__main__":
    run_tests()
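Outside the test harness, the same launch can be expressed directly; a sketch using `torch.distributed.run` with the flag exercised above (the argument values and script path are illustrative):

```python
import torch.distributed.run as launch

# Launch two workers with --virtual-local-rank so both compile against a
# virtual cuda:0; mirrors the args assembled in run_test() above.
launch.main(
    [
        "--nnodes=1",
        "--nproc-per-node=2",
        "--virtual-local-rank",
        "script_deviceid.py",
    ]
)
```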
@@ -204,14 +204,16 @@ class DistConvolutionOpsTest(DTensorTestBase):
        self.assertTrue(b_dt.grad is not None)
        self.assertTrue(x_dt.grad is None)

    def _run_single_arg_fwd(self, model, arg) -> tuple[torch.Tensor, torch.Tensor]:
    def _run_single_arg_fwd(
        self, model, arg, placements=None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Given a model and an arg, runs the model forward both locally and distributed over the device_mesh"""
        device_mesh = self.build_device_mesh()
        model_copy = copy.deepcopy(model).to(device=self.device_type)
        dist_model = distribute_module(model, device_mesh, _conv_fn)
        arg_dt = DTensor.from_local(arg, device_mesh, [Replicate()])
        arg_dt = DTensor.from_local(arg, device_mesh, placements)
        out_dt = dist_model(arg_dt.to(device=self.device_type))
        out = model_copy(arg)
        out = model_copy(arg_dt.full_tensor())
        return (out_dt.full_tensor(), out)

    @with_comms
@@ -219,22 +221,20 @@ class DistConvolutionOpsTest(DTensorTestBase):
        model = nn.Conv1d(64, 64, 3, padding=1)
        x = torch.randn(1, 64, 8, device=self.device_type)
        out_dt, out = self._run_single_arg_fwd(model, x)
        self.assertEqual(out_dt.shape, out.shape)
        self.assertEqual(out_dt, out)

    @with_comms
    def test_conv3d(self):
        model = nn.Conv3d(64, 64, 3, padding=1)
        x = torch.randn(1, 64, 8, 8, 8, device=self.device_type)
        out_dt, out = self._run_single_arg_fwd(model, x)
        self.assertEqual(out_dt.shape, out.shape)
        out_dt, out = self._run_single_arg_fwd(model, x, [Shard(0)])
        self.assertEqual(out_dt, out)


DistConvolutionOpsTestWithLocalTensor = create_local_tensor_test_class(
    DistConvolutionOpsTest,
    # Send / recv ops are not supported
    skipped_tests=[
        "test_conv1d",
        "test_conv3d",
        "test_conv_backward_none_grad_inp",
        "test_depthwise_convolution",
        "test_downsampling_convolution",
@@ -535,7 +535,7 @@ class DTensorExportTest(TestCase):

        self.assertEqual(fn(z), gm(z)[0])

    def test_dtensor_data_dependent_index(self):
    def test_dtensor_data_dependent_index_and_slice(self):
        device_mesh = init_device_mesh(self.device_type, mesh_shape=(self.world_size,))

        class Foo(torch.nn.Module):
@@ -548,6 +548,35 @@ class DTensorExportTest(TestCase):
        y_dt = distribute_tensor(y, device_mesh, placements=[Replicate()])
        _dynamo_graph_capture_for_export(Foo())(x_dt, y_dt)

        class Bar(torch.nn.Module):
            def forward(self, x):
                val = torch.clamp(x.max(), min=1).item()
                torch._check(val >= 1)
                return x[:val]

        x = torch.randint(1000, (4, 64, 16))
        x_dt = distribute_tensor(x, device_mesh, placements=[Replicate()])
        gm = _dynamo_graph_capture_for_export(Bar())(x_dt)
        self.assertExpectedInline(
            """\
graph():
    %l_flat_args_0_ : [num_users=2] = placeholder[target=arg_0]
    %max_1 : [num_users=1] = call_method[target=max](args = (%l_flat_args_0_,), kwargs = {})
    %clamp : [num_users=1] = call_function[target=torch.clamp](args = (%max_1,), kwargs = {min: 1})
    %item : [num_users=2] = call_method[target=item](args = (%clamp,), kwargs = {})
    %ge_1 : [num_users=1] = call_function[target=operator.ge](args = (%item, 1), kwargs = {})
    %_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge_1, Runtime assertion failed for expression u0 >= 1 on node 'ge_1'), kwargs = {})
    %res : [num_users=2] = call_function[target=operator.getitem](args = (%l_flat_args_0_, slice(None, item, None)), kwargs = {})
    %getattr_1 : [num_users=1] = call_function[target=builtins.getattr](args = (%res, _local_tensor), kwargs = {})
    %sym_size_int : [num_users=2] = call_function[target=torch.ops.aten.sym_size.int](args = (%getattr_1, 0), kwargs = {})
    %ge_2 : [num_users=1] = call_function[target=operator.ge](args = (%sym_size_int, 0), kwargs = {})
    %_assert_scalar_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge_2, Runtime assertion failed for expression u2 >= 0 on node 'ge_2'), kwargs = {})
    %le : [num_users=1] = call_function[target=operator.le](args = (%sym_size_int, 4), kwargs = {})
    %_assert_scalar_default_2 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%le, Runtime assertion failed for expression u2 <= 4 on node 'le'), kwargs = {})
    return (res,)""",  # noqa: B950
            str(gm.graph).strip(),
        )


instantiate_parametrized_tests(DTensorExportTest)
@@ -2,7 +2,6 @@
# Owner(s): ["oncall: distributed"]

import contextlib
import copy
import itertools
import unittest

@@ -22,9 +21,8 @@ from torch.distributed.tensor import (
)
from torch.distributed.tensor._collective_utils import shard_dim_alltoall
from torch.distributed.tensor._dtensor_spec import ShardOrderEntry
from torch.distributed.tensor._redistribute import redistribute_local_tensor
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.placement_types import _StridedShard
from torch.distributed.tensor.placement_types import _StridedShard, MaskPartial
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
@@ -35,7 +33,11 @@ from torch.testing._internal.common_utils import (
from torch.testing._internal.distributed._tensor.common_dtensor import (
    create_local_tensor_test_class,
    DTensorTestBase,
    generate_shard_orders,
    make_full_tensor,
    map_local_tensor_for_rank,
    patched_distribute_tensor as _distribute_tensor,
    redistribute,
    with_comms,
)
from torch.utils._debug_mode import DebugMode
@@ -785,88 +787,6 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
        else:
            return ""

    # TODO(zpcore): remove once the native redistribute supports shard_order arg
    def redistribute(
        self,
        dtensor_input,
        device_mesh,
        placements,
        shard_order,
        use_graph_based_transform=True,
    ):
        """
        Wrapper function to support shard_order for redistribution.
        This is a simpler version of Redistribute that only considers the forward pass.
        """
        if placements is None:
            placements = self._shard_order_to_placement(shard_order, device_mesh)
        placements = tuple(placements)
        old_spec = dtensor_input._spec
        new_spec = copy.deepcopy(old_spec)
        new_spec.placements = placements
        if shard_order is not None:
            new_spec.shard_order = shard_order
        else:
            new_spec.shard_order = ()
        if old_spec == new_spec:
            return dtensor_input
        dtensor_input = DTensor.from_local(
            redistribute_local_tensor(
                dtensor_input.to_local(),
                old_spec,
                new_spec,
                use_graph_based_transform=use_graph_based_transform,
            ),
            device_mesh,
        )
        dtensor_input._spec = copy.deepcopy(new_spec)
        return dtensor_input  # returns DTensor

    # TODO(zpcore): remove once the native distribute_tensor supports
    # shard_order arg
    def distribute_tensor(
        self,
        input_tensor,
        device_mesh,
        placements,
        shard_order,
        use_graph_based_transform=True,
    ):
        """Wrapper function to support shard_order for tensor distribution."""
        if placements is None:
            placements = self._shard_order_to_placement(shard_order, device_mesh)
        placements = tuple(placements)
        tensor_dt = distribute_tensor(input_tensor, device_mesh, placements)
        # fix the shard order
        return self.redistribute(
            tensor_dt, device_mesh, placements, shard_order, use_graph_based_transform
        )

    # TODO(zpcore): remove once the native redistribute supports shard_order arg
    def full_tensor(self, dtensor_input):
        """Wrapper function to support DTensor.full_tensor."""
        return self.redistribute(
            dtensor_input, dtensor_input.device_mesh, placements=None, shard_order=()
        ).to_local()

    def _shard_order_to_placement(self, shard_order, mesh):
        """Convert shard_order to a placement with only Replicate() and Shard()."""
        placements = [Replicate() for _ in range(mesh.ndim)]
        if shard_order is not None:
            for entry in shard_order:
                tensor_dim = entry.tensor_dim
                mesh_dims = entry.mesh_dims
                for mesh_dim in mesh_dims:
                    placements[mesh_dim] = Shard(tensor_dim)
        return tuple(placements)

    def _convert_shard_order_dict_to_ShardOrder(self, shard_order):
        """Convert a shard_order dict to a ShardOrder."""
        return tuple(
            ShardOrderEntry(tensor_dim=tensor_dim, mesh_dims=tuple(mesh_dims))
            for tensor_dim, mesh_dims in shard_order.items()
        )

    @with_comms
    def test_ordered_redistribute(self):
        """Test ordered redistribution with various sharding syntaxes"""
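To make the helper concrete, a worked reading of `_shard_order_to_placement` (values illustrative):

```python
from torch.distributed.tensor._dtensor_spec import ShardOrderEntry

# On a 3-D mesh, an order saying tensor dim 0 is sharded over mesh dims 1 and 2:
shard_order = (ShardOrderEntry(tensor_dim=0, mesh_dims=(1, 2)),)
# The helper starts from all-Replicate, [Replicate(), Replicate(), Replicate()],
# then overwrites the slots named by mesh_dims with Shard(tensor_dim), yielding:
#   (Replicate(), Shard(0), Shard(0))
```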
@@ -927,13 +847,11 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
        for idx, ((src_placement, src_order), (dst_placement, dst_order)) in enumerate(
            sharding_src_dst_pairs_with_expected_trace
        ):
            sharded_dt = self.distribute_tensor(
            sharded_dt = _distribute_tensor(
                input_data.clone(), mesh, src_placement, shard_order=src_order
            )
            with DebugMode(record_torchfunction=False) as debug_mode:
                sharded_dt = self.redistribute(
                    sharded_dt, mesh, dst_placement, dst_order
                )
                sharded_dt = redistribute(sharded_dt, mesh, dst_placement, dst_order)
            trace_str = self._extract_redistribute_trace_from_debug_mode(
                debug_mode.debug_string()
            )
@@ -957,49 +875,11 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
                trace_str,
                """S(0)[0]S(0)[1]R->S(0)S(1)R->RS(1)R->RS(1)S(0)""",
            )
            expected_dt = self.distribute_tensor(
            expected_dt = _distribute_tensor(
                input_data.clone(), mesh, dst_placement, shard_order=dst_order
            )
            self.assertEqual(sharded_dt.to_local(), expected_dt.to_local())

    def generate_shard_orders(self, mesh, tensor_rank):
        # Generate all possible sharding placements of a tensor with rank
        # `tensor_rank` over the mesh.
        def _split_list(lst: list, N: int):
            def compositions(n, k):
                if k == 1:
                    yield [n]
                else:
                    for i in range(1, n - k + 2):
                        for tail in compositions(n - i, k - 1):
                            yield [i] + tail

            length = len(lst)
            for comp in compositions(length, N):
                result = []
                start = 0
                for size in comp:
                    result.append(lst[start : start + size])
                    start += size
                yield result

        all_mesh = list(range(mesh.ndim))
        all_device_order = list(itertools.permutations(all_mesh))
        for device_order in all_device_order:
            # split the device orders, and assign each device order segment to a tensor dim
            for num_split in range(1, mesh.ndim + 1):
                for splitted_list in _split_list(list(range(mesh.ndim)), num_split):
                    for tensor_dims in itertools.combinations(
                        range(tensor_rank), len(splitted_list)
                    ):
                        shard_order = {}
                        assert len(tensor_dims) == len(splitted_list)
                        for tensor_dim, mesh_dims in zip(tensor_dims, splitted_list):
                            shard_order[tensor_dim] = device_order[
                                mesh_dims[0] : mesh_dims[-1] + 1
                            ]
                        yield self._convert_shard_order_dict_to_ShardOrder(shard_order)

    @with_comms
    def test_generate_shard_orders(self):
        """Check if `generate_shard_orders` generates unique sharding combinations"""
@@ -1012,7 +892,7 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
        ]
        for test_input in test_inputs:
            all_combinations = []
            for shard_order in self.generate_shard_orders(
            for shard_order in generate_shard_orders(
                test_input["mesh"], test_input["tensor_rank"]
            ):
                all_combinations.append(shard_order)  # noqa: PERF402
@@ -1062,12 +942,12 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
        input_data = torch.randn(tensor_shape, device=self.device_type)
        tensor_rank = input_data.ndim
        with maybe_disable_local_tensor_mode():
            shard_orders = self.generate_shard_orders(mesh, tensor_rank)
            shard_orders = generate_shard_orders(mesh, tensor_rank)
        for shard_order in shard_orders:
            sharded_dt = self.distribute_tensor(
            sharded_dt = _distribute_tensor(
                input_data.clone(), mesh, placements=None, shard_order=shard_order
            )
            self.assertEqual(self.full_tensor(sharded_dt), input_data)
            self.assertEqual(make_full_tensor(sharded_dt), input_data)

        # 2. Verify the correctness of redistribution from DTensor to DTensor.
        # This test repeatedly redistributes a DTensor to various ordered
@@ -1078,20 +958,20 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
        tensor_rank = input_data.ndim
        prev_sharded_dt = None
        with maybe_disable_local_tensor_mode():
            shard_orders = self.generate_shard_orders(mesh, tensor_rank)
            shard_orders = generate_shard_orders(mesh, tensor_rank)
        for shard_order in shard_orders:
            if prev_sharded_dt is None:
                prev_sharded_dt = self.distribute_tensor(
                prev_sharded_dt = _distribute_tensor(
                    input_data.clone(),
                    mesh,
                    placements=None,
                    shard_order=shard_order,
                )
            else:
                sharded_dt = self.redistribute(
                sharded_dt = redistribute(
                    prev_sharded_dt, mesh, placements=None, shard_order=shard_order
                )
                self.assertEqual(self.full_tensor(sharded_dt), input_data)
                self.assertEqual(make_full_tensor(sharded_dt), input_data)
                prev_sharded_dt = sharded_dt

    @with_comms
@@ -1136,13 +1016,13 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
        local_tensor = torch.randn(shape, device=self.device_type)
        full_tensor = DTensor.from_local(local_tensor, mesh, placements)
        with maybe_disable_local_tensor_mode():
            shard_orders = self.generate_shard_orders(mesh, len(shape))
            shard_orders = generate_shard_orders(mesh, len(shape))
        for shard_order in shard_orders:
            sharded_dt = self.redistribute(
            sharded_dt = redistribute(
                full_tensor, mesh, placements=None, shard_order=shard_order
            )
            self.assertEqual(
                self.full_tensor(sharded_dt), self.full_tensor(full_tensor)
                make_full_tensor(sharded_dt), make_full_tensor(full_tensor)
            )

    @unittest.skip(
@@ -1152,24 +1032,20 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
    @with_comms
    def test_ordered_redistribute_for_special_placement(self):
        """Test ordered redistribution with a special placement"""
        from torch.distributed.tensor._ops._embedding_ops import _MaskPartial

        torch.manual_seed(21)
        mesh = init_device_mesh(self.device_type, (8,))
        input_data = torch.randn((8, 8), device=self.device_type)
        src_placement = [Shard(1)]
        tgt_placement = [
            (_MaskPartial(offset_shape=torch.Size([10, 20]), offset_dim=0),)
            (MaskPartial(offset_shape=torch.Size([10, 20]), offset_dim=0),)
        ]
        sharded_dt = self.distribute_tensor(
        sharded_dt = _distribute_tensor(
            input_data.clone(),
            mesh,
            src_placement,
            shard_order=(ShardOrderEntry(tensor_dim=1, mesh_dims=(0,)),),
        )
        sharded_dt = self.redistribute(
            sharded_dt, mesh, tgt_placement, shard_order=None
        )
        sharded_dt = redistribute(sharded_dt, mesh, tgt_placement, shard_order=None)

    @with_comms
    def test_shard_order_same_data_as_strided_shard(self):
@@ -1179,7 +1055,7 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
        strided_placement = [_StridedShard(-2, split_factor=2), Shard(-2)]
        x_strided_dt = distribute_tensor(x, device_mesh, strided_placement)
        # specify right-to-left order using ordered shard
        x_ordered_dt = self.distribute_tensor(
        x_ordered_dt = _distribute_tensor(
            x,
            device_mesh,
            placements=[Shard(0), Shard(0)],
@@ -706,11 +706,11 @@ class DistTensorOpsTest(DTensorTestBase):
    @with_comms
    def test_dtensor_dtype_conversion(self):
        from torch.distributed.tensor.debug import (
            _clear_sharding_prop_cache,
            _get_sharding_prop_cache_info,
            _clear_fast_path_sharding_prop_cache,
            _get_fast_path_sharding_prop_cache_stats,
        )

        _clear_sharding_prop_cache()
        _clear_fast_path_sharding_prop_cache()
        device_mesh = self.build_device_mesh()
        shard_spec = [Shard(0)]
        # by default we start from the bf16 dtype
@@ -730,13 +730,13 @@ class DistTensorOpsTest(DTensorTestBase):
        self.assertEqual(bf16_sharded_dtensor1.to_local().dtype, torch.bfloat16)

        # by this point we only have cache misses
        hits, misses, _, _ = _get_sharding_prop_cache_info()
        hits, misses = _get_fast_path_sharding_prop_cache_stats()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 2)

        # convert to fp32 again and see if there's a cache hit
        bf16_sharded_dtensor1.float()
        hits, misses, _, _ = _get_sharding_prop_cache_info()
        hits, misses = _get_fast_path_sharding_prop_cache_stats()
        # by now we should have a cache hit
        self.assertEqual(hits, 1)
        self.assertEqual(misses, 2)
@@ -34,6 +34,10 @@ from torch.distributed.tensor.placement_types import (
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    generate_shard_orders,
    LocalDTensorTestBase,
    patched_distribute_tensor as _distribute_tensor,
    shard_order_to_placement,
    with_comms,
)

@@ -774,6 +778,63 @@ class TestStridedSharding(DTensorTestBase):
        self.assertEqual(dtensor.full_tensor(), tensor)


class Test_StridedShard_with_shard_order(LocalDTensorTestBase):
    @property
    def world_size(self) -> int:
        return 32

    @with_comms
    def test_StridedShard_to_shard_order(self):
        with LocalTensorMode(ranks=self.world_size):
            mesh = DeviceMesh("cpu", torch.arange(self.world_size).view(2, 2, 2, 2, 2))
            shard_iter = generate_shard_orders(mesh, 3)
            # It takes ~4.8h to complete all 2520 shard order combinations here
            # using LocalTensor, so we only randomly pick 25 shard orders to test.
            all_shard_order = list(shard_iter)
            import random

            random.seed(42)
            shard_order_choices = random.sample(
                all_shard_order, min(25, len(all_shard_order))
            )

            x = torch.randn(32, 32, 32)
            for shard_order in shard_order_choices:
                a = _distribute_tensor(x, mesh, None, shard_order)

                placement_without_stridedshard = shard_order_to_placement(
                    shard_order, mesh
                )
                placements_with_stridedshard = (
                    DTensorSpec._convert_shard_order_to_StridedShard(
                        shard_order, placement_without_stridedshard, mesh
                    )
                )
                b = distribute_tensor(x, mesh, placements_with_stridedshard)
                shard_order_from_stridedshard = (
                    DTensorSpec._maybe_convert_StridedShard_to_shard_order(
                        placements_with_stridedshard, mesh
                    )
                )
                self.assertEqual(shard_order, shard_order_from_stridedshard)
                self.assertEqual(a.to_local(), b.to_local())

    @with_comms
    def test_StridedShard_not_convertible_to_shard_order(self):
        with LocalTensorMode(ranks=self.world_size):
            mesh = DeviceMesh("cpu", torch.arange(self.world_size).view(4, 8))
            unconvertible_placements_list = [
                [_StridedShard(0, split_factor=2), _StridedShard(1, split_factor=2)],
                [_StridedShard(0, split_factor=2), Shard(1)],
                [_StridedShard(1, split_factor=16), Shard(1)],
            ]
            for placements in unconvertible_placements_list:
                shard_order = DTensorSpec._maybe_convert_StridedShard_to_shard_order(
                    tuple(placements), mesh
                )
                self.assertIsNone(shard_order)


class Test2DStridedLocalShard(DTensorTestBase):
    @property
    def world_size(self):
@@ -938,13 +999,25 @@ class TestExplicitRedistribute(LocalTensorTestBase):

        dx = distribute_tensor(x, device_mesh, [Shard(0)])
        dA = distribute_tensor(A, device_mesh, [Replicate()])
        with ExplicitRedistributionContext():
        with ExplicitRedistributionContext(strict=True):
            dY = torch.matmul(dx, dA_repl)
            loss = dY.sum()

            # we now see the error during backwards
            with self.assertRaisesRegex(RuntimeError, "Implicit redistribution"):
                loss.backward()
                loss.backward(retain_graph=True)

        with ExplicitRedistributionContext(strict=False):
            # but since it's a 'free' redistribute, we can still do it under non-strict mode
            loss.backward(retain_graph=True)

        with ExplicitRedistributionContext(enable=False):
            # and we can disable
            loss.backward(retain_graph=True)

        # and re-enable
        with self.assertRaisesRegex(RuntimeError, "Implicit redistribution"):
            loss.backward(retain_graph=True)


if __name__ == "__main__":
@@ -1062,6 +1062,307 @@ class TestComputeCommReorderingBucketing(TestComputeCommReorderingMultiProc):
        self.assertTrue(same(out, correct))


def get_toy_model(device_type: str):
    """
    Helper to construct a small multi-layer ToyModel.
    """

    class ToyBlock(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.wq = torch.nn.Linear(4, 4)
            self.wk = torch.nn.Linear(4, 4)
            self.proj = torch.nn.Linear(4, 4)

        def forward(self, x):
            attn = self.wq(x) + self.wk(x)
            return self.proj(torch.nn.functional.relu(attn))

    class ToyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.layers = torch.nn.ModuleList([ToyBlock() for _ in range(2)])
            self.norm = torch.nn.LayerNorm(4)

        def forward(self, x):
            for blk in self.layers:
                x = blk(x)
            return self.norm(x)

    model = ToyModel().to(device_type)
    return model


def apply_manual_reordering_and_get_graph(graph, module_bucket_plans, out_li) -> None:
    gm = graph.owning_module
    from torch._inductor.fx_passes.overlap_manual_scheduling import (
        ManualOverlapScheduler,
    )

    for node in list(gm.graph.nodes):
        if (
            node.name == "all_gather_into_tensor"
            or node.name == "all_gather_into_tensor_1"
            or node.name == "wait_tensor"
            or node.name == "wait_tensor_1"
        ):
            node.meta["nn_module_stack"] = {"test": ["module_1", ""]}
        if (
            node.name == "all_gather_into_tensor_2"
            or node.name == "all_gather_into_tensor_3"
            or node.name == "wait_tensor_2"
            or node.name == "wait_tensor_3"
        ):
            node.meta["nn_module_stack"] = {"test": ["module_2", ""]}

    overlapped_gm = ManualOverlapScheduler(
        gm, module_bucket_plans, insert_overlap_deps=False
    ).run()
    overlapped_gm.graph.lint()
    out_li.append(overlapped_gm.graph)


def run_and_get_manual_aten_graph(fn, module_bucket_plans, *inputs):
    li = []
    apply = functools.partial(
        apply_manual_reordering_and_get_graph,
        module_bucket_plans=module_bucket_plans,
        out_li=li,
    )
    with torch._inductor.config.patch(post_grad_custom_post_pass=apply):
        out = fn(*inputs)

    return out, li[0]
|
||||
class TestManualOverlapBucketing(TestComputeCommReorderingMultiProc):
|
||||
"""
|
||||
Tests for manual overlap scheduling and subgraph utilities.
|
||||
"""
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
def test_make_graph_view_and_get_subgraph_by_path(self):
|
||||
from torch._inductor.fx_passes.graph_view import (
|
||||
get_subgraph_by_path,
|
||||
make_graph_view,
|
||||
)
|
||||
|
||||
model = get_toy_model(device_type)
|
||||
gm = torch.fx.symbolic_trace(model)
|
||||
graph_view = make_graph_view(gm.graph)
|
||||
# Fetch subgraph for first transformer layer
|
||||
sub_nodes = get_subgraph_by_path(graph_view, "layers.0.wq")
|
||||
self.assertEqual([n.name for n in sub_nodes], ["layers_0_wq"])
|
||||
|
||||
# Fetch multiple paths at once
|
||||
multi_nodes = get_subgraph_by_path(graph_view, ["layers.0.wq", "layers.0.proj"])
|
||||
self.assertEqual(
|
||||
[n.name for n in multi_nodes], ["layers_0_wq", "layers_0_proj"]
|
||||
)
|
||||
|
||||
# Fetch non existing paths
|
||||
non_exist_nodes = get_subgraph_by_path(graph_view, "nonexistent.module.path")
|
||||
self.assertEqual(non_exist_nodes, [])
|
||||
|
||||
# Fetch mixed of existing and non existing paths
|
||||
mixed_nodes = get_subgraph_by_path(
|
||||
graph_view, ["layers.0.wq", "nonexistent.module.path"]
|
||||
)
|
||||
self.assertEqual([n.name for n in mixed_nodes], ["layers_0_wq"])
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
def test_manual_reordering_bucketing_pass_separate_buckets(
|
||||
self,
|
||||
):
|
||||
def func(a, b, c, d, *, ranks):
|
||||
# All 4 all-gathers are independent - COULD be bucketed together
|
||||
ag1 = _functional_collectives.all_gather_tensor(a, 0, ranks)
|
||||
ag2 = _functional_collectives.all_gather_tensor(b, 0, ranks)
|
||||
ag3 = _functional_collectives.all_gather_tensor(c[:4], 0, ranks)
|
||||
ag4 = _functional_collectives.all_gather_tensor(d[:4], 0, ranks)
|
||||
|
||||
# First compute - can hide ag1 and ag2
|
||||
e = a * 5 # Use a to avoid fusion
|
||||
mm1 = torch.matmul(e, e.T)
|
||||
|
||||
# Force ag1/ag2 to complete before mm2 (but ag3/ag4 can still be deferred)
|
||||
# Use first 8x8 elements to match mm1's shape
|
||||
intermediate = ag1[:8, :8] + ag2[:8, :8]
|
||||
|
||||
# Second compute - depends on ag1/ag2 through intermediate, can hide ag3/ag4
|
||||
mm2 = torch.matmul(mm1 + intermediate, c[:8])
|
||||
|
||||
# Use all results
|
||||
result = (
|
||||
ag1.sum() * 1.1
|
||||
+ ag2.sum() * 1.2
|
||||
+ ag3.sum() * 1.3
|
||||
+ ag4.sum() * 1.4
|
||||
+ mm1.sum()
|
||||
+ mm2.sum()
|
||||
)
|
||||
return result
|
||||
|
||||
with _dynamo_dist_per_rank_init(
|
||||
self.rank,
|
||||
self.world_size,
|
||||
self.backend(device_type),
|
||||
fake_pg=not at_least_x_gpu(2),
|
||||
):
|
||||
a = torch.ones(8, 8, dtype=torch.float, device=device_type)
|
||||
b = torch.ones(8, 8, dtype=torch.float, device=device_type) * 2
|
||||
c = torch.ones(8, 8, dtype=torch.float, device=device_type) * 3
|
||||
d = torch.ones(8, 8, dtype=torch.float, device=device_type) * 4
|
||||
ranks = list(range(self.world_size))
|
||||
|
||||
func_c = functools.partial(func, ranks=ranks)
|
||||
compiled = torch.compile(func_c)
|
||||
out, aten_graph = run_and_get_manual_aten_graph(
|
||||
compiled, ["module_1", "module_2"], a, b, c, d
|
||||
)
|
||||
|
||||
(
|
||||
FileCheck()
|
||||
.check("_pre_bucket_all_gather")
|
||||
.check("all_gather_into_tensor_out")
|
||||
.check("_pre_bucket_all_gather_1")
|
||||
.check("all_gather_into_tensor_out_1")
|
||||
.check("wait_tensor_4")
|
||||
.check("wait_tensor_5")
|
||||
.run(str(aten_graph))
|
||||
)
|
||||
|
||||
correct = func(a, b, c, d, ranks=ranks)
|
||||
self.assertTrue(same(out, correct))

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    def test_bucketing_reordering_pass_no_bucket(
        self,
    ):
        def func(a, b, c, d, *, ranks):
            # All 4 all-gathers are independent - COULD be bucketed together
            ag1 = _functional_collectives.all_gather_tensor(a, 0, ranks)
            ag2 = _functional_collectives.all_gather_tensor(b, 0, ranks)
            ag3 = _functional_collectives.all_gather_tensor(c[:4], 0, ranks)
            ag4 = _functional_collectives.all_gather_tensor(d[:4], 0, ranks)

            # First compute - can hide ag1 and ag2
            e = a * 5  # Use a to avoid fusion
            mm1 = torch.matmul(e, e.T)

            # Force ag1/ag2 to complete before mm2 (but ag3/ag4 can still be deferred)
            # Use first 8x8 elements to match mm1's shape
            intermediate = ag1[:8, :8] + ag2[:8, :8]

            # Second compute - depends on ag1/ag2 through intermediate, can hide ag3/ag4
            mm2 = torch.matmul(mm1 + intermediate, c[:8])

            # Use all results
            result = (
                ag1.sum() * 1.1
                + ag2.sum() * 1.2
                + ag3.sum() * 1.3
                + ag4.sum() * 1.4
                + mm1.sum()
                + mm2.sum()
            )
            return result

        with _dynamo_dist_per_rank_init(
            self.rank,
            self.world_size,
            self.backend(device_type),
            fake_pg=not at_least_x_gpu(2),
        ):
            a = torch.ones(8, 8, dtype=torch.float, device=device_type)
            b = torch.ones(8, 8, dtype=torch.float, device=device_type) * 2
            c = torch.ones(8, 8, dtype=torch.float, device=device_type) * 3
            d = torch.ones(8, 8, dtype=torch.float, device=device_type) * 4
            ranks = list(range(self.world_size))

            func_c = functools.partial(func, ranks=ranks)
            compiled = torch.compile(func_c)
            out, aten_graph = run_and_get_manual_aten_graph(compiled, [], a, b, c, d)

            (
                FileCheck()
                .check("all_gather_into_tensor")
                .check("all_gather_into_tensor_1")
                .check("all_gather_into_tensor_2")
                .check("all_gather_into_tensor_3")
                .check("wait_tensor")
                .check("wait_tensor_1")
                .check("wait_tensor_2")
                .check("wait_tensor_3")
                .run(str(aten_graph))
            )
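
            # With an empty bucketing plan ([]), nothing is coalesced: the graph
            # keeps four separate all_gather_into_tensor calls, each paired with
            # its own wait_tensor, which is exactly what the pattern above checks.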

            correct = func(a, b, c, d, ranks=ranks)
            self.assertTrue(same(out, correct))

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    def test_bucketing_reordering_pass_single_bucket(
        self,
    ):
        def func(a, b, c, d, *, ranks):
            # All 4 all-gathers are independent - COULD be bucketed together
            ag1 = _functional_collectives.all_gather_tensor(a, 0, ranks)
            ag2 = _functional_collectives.all_gather_tensor(b, 0, ranks)
            ag3 = _functional_collectives.all_gather_tensor(c[:4], 0, ranks)
            ag4 = _functional_collectives.all_gather_tensor(d[:4], 0, ranks)

            # First compute - can hide ag1 and ag2
            e = a * 5  # Use a to avoid fusion
            mm1 = torch.matmul(e, e.T)

            # Force ag1/ag2 to complete before mm2 (but ag3/ag4 can still be deferred)
            # Use first 8x8 elements to match mm1's shape
            intermediate = ag1[:8, :8] + ag2[:8, :8]

            # Second compute - depends on ag1/ag2 through intermediate, can hide ag3/ag4
            mm2 = torch.matmul(mm1 + intermediate, c[:8])

            # Use all results
            result = (
                ag1.sum() * 1.1
                + ag2.sum() * 1.2
                + ag3.sum() * 1.3
                + ag4.sum() * 1.4
                + mm1.sum()
                + mm2.sum()
            )
            return result

        with _dynamo_dist_per_rank_init(
            self.rank,
            self.world_size,
            self.backend(device_type),
            fake_pg=not at_least_x_gpu(2),
        ):
            a = torch.ones(8, 8, dtype=torch.float, device=device_type)
            b = torch.ones(8, 8, dtype=torch.float, device=device_type) * 2
            c = torch.ones(8, 8, dtype=torch.float, device=device_type) * 3
            d = torch.ones(8, 8, dtype=torch.float, device=device_type) * 4
            ranks = list(range(self.world_size))

            func_c = functools.partial(func, ranks=ranks)
            compiled = torch.compile(func_c)
            out, aten_graph = run_and_get_manual_aten_graph(
                compiled, [["module_1", "module_2"]], a, b, c, d
            )

            (
                FileCheck()
                .check("_pre_bucket_all_gather")
                .check("all_gather_into_tensor_out")
                .check("wait_tensor_4")
                .run(str(aten_graph))
            )
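
            # The nested plan [["module_1", "module_2"]] merges both groups into a
            # single bucket: one _pre_bucket_all_gather copy-in, one fused
            # all_gather_into_tensor_out, and a single wait_tensor_4 covering all
            # four gathers.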

            correct = func(a, b, c, d, ranks=ranks)
            self.assertTrue(same(out, correct))
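
    # Taken together, the three tests above pin down the bucketing plan argument of
    # run_and_get_manual_aten_graph: ["module_1", "module_2"] yields two buckets,
    # [] disables bucketing entirely, and [["module_1", "module_2"]] merges both
    # groups into a single bucket.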


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()

@@ -1341,13 +1341,11 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
        assert counter.op_count == 3  # it generates 2 getattr calls to unpack the array
        assert same(out, correct)

    # This doesn't work in all cases, and we now properly raise a loud error.
    # See: https://github.com/pytorch/pytorch/issues/151240
    # This can be reverted once differentiable funcols are implemented.
    @unittest.expectedFailure
    def test_backwards(self):
        """
        It's probably not that common to need backwards support for collectives.

        However, I wanted to at least see if it was possible to support it as a design goal.
        """

        def func(inp):
            ar = _functional_collectives.all_reduce(inp, "sum", "0")
            return ar
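
        # The rest of this test is cut off by this diff hunk. Presumably it runs a
        # backward pass through the compiled collective, which is the part that
        # currently errors loudly; a minimal sketch of that shape (names and sizes
        # assumed, not from the source):
        #
        #     inp = torch.ones(4, 4, requires_grad=True)
        #     out = torch.compile(func)(inp)
        #     out.sum().backward()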

@@ -2363,6 +2363,34 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
        self.assertTrue(same(output, expected))
        assert cnt.frame_count == 1

    @unittest.skipIf(sys.version_info < (3, 13), "math.fma introduced in Python 3.13")
    def test_math_fma(self):
        def fma_func(a, b, c):
            return math.fma(a, b, c)

        # Test with scalar constants (constant folding path)
        cnt = torch._dynamo.testing.CompileCounter()
        cfma_scalars = torch._dynamo.optimize_assert(cnt)(fma_func)

        assert cnt.frame_count == 0
        expected = fma_func(2.0, 3.0, 4.0)
        output = cfma_scalars(2.0, 3.0, 4.0)
        self.assertEqual(output, expected)
        assert cnt.frame_count == 0

        # Test with tensors (Inductor path)
        cnt2 = torch._dynamo.testing.CompileCounter()
        cfma_tensors = torch._dynamo.optimize_assert(cnt2)(fma_func)

        assert cnt2.frame_count == 0
        x = torch.tensor(2.0)
        y = torch.tensor(3.0)
        z = torch.tensor(4.0)
        expected_tensors = x * y + z
        output_tensors = cfma_tensors(x, y, z)
        torch.testing.assert_close(output_tensors, expected_tensors)
        assert cnt2.frame_count == 1
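
    # math.fma(a, b, c) computes a * b + c as a single fused multiply-add with one
    # rounding step, e.g. math.fma(2.0, 3.0, 4.0) == 10.0. With plain float inputs
    # the call never produces a compiled frame (frame_count stays 0); with tensor
    # inputs Dynamo compiles one frame whose output matches the x * y + z reference.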

    @make_test
    def test_numpy_meshgrid(x, y):
        r1, r2 = np.meshgrid(x.numpy(), y.numpy())

@@ -5788,6 +5788,20 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):

        self.assertTrue(torch.allclose(dynamo_output, output))

    def test_repr(self):
        class Config:
            def __repr__(self):
                return "Config()"

        def forward(x, config):
            return x * len(repr(config))

        config = Config()
        x = torch.randn(2, 2)

        compiled = torch.compile(forward, fullgraph=True)
        compiled(x, config)
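
        # With fullgraph=True, Dynamo has to trace through the user-defined __repr__
        # rather than graph-breaking on repr(); repr(config) is "Config()", so
        # len(repr(config)) folds to the constant 8 and the graph reduces to x * 8.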

    def test_nn_functional_reduction(self):
        def fn(loss, reduction):
            reduction_enum = F._Reduction.get_enum(reduction)

@@ -335,6 +335,59 @@ class <lambda>(torch.nn.Module):
""",
        )

    @requires_cuda
    @requires_multigpu()
    def test_new_event_api(self) -> None:
        from torch._dynamo.graph_bytecode_inputs import get_external_object_by_index
        from torch._dynamo.variables.streams import new_event

        def event_generation_backend(gm, *args, **kwargs):  # type: ignore[no-untyped-def]
            e0_ind = new_event()
            with torch.Stream(device="cuda:1"):
                get_external_object_by_index(e0_ind).record()
            e1_ind = new_event()
            self.assertNotEqual(e0_ind, e1_ind)
            self.assertNotEqual(
                get_external_object_by_index(e0_ind),
                get_external_object_by_index(e1_ind),
            )
            with gm.graph.inserting_after(next(iter(gm.graph.nodes))):
                gm.graph.call_function(
                    get_external_object_by_index, args=(1,), kwargs={}
                )
            return gm

        @torch.compile(backend=event_generation_backend)
        def fn(x):
            return x + 1

        fn(torch.ones(2, 2, device="cuda:0"))

    @requires_cuda
    def test_new_stream_api(self) -> None:
        from torch._dynamo.graph_bytecode_inputs import get_external_object_by_index
        from torch._dynamo.variables.streams import new_stream

        def stream_generation_backend(gm, *args, **kwargs):  # type: ignore[no-untyped-def]
            s0_ind = new_stream()
            s1_ind = new_stream()
            self.assertNotEqual(s0_ind, s1_ind)
            self.assertNotEqual(
                get_external_object_by_index(s0_ind),
                get_external_object_by_index(s1_ind),
            )
            with gm.graph.inserting_after(next(iter(gm.graph.nodes))):
                gm.graph.call_function(
                    get_external_object_by_index, args=(1,), kwargs={}
                )
            return gm

        @torch.compile(backend=stream_generation_backend)
        def fn(x):
            return x + 1

        fn(torch.ones(2, 2, device="cuda:0"))
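
    # Both tests exercise the same side-table mechanism: new_event() / new_stream()
    # create the object, register it with Dynamo's external-object tracking, and
    # return an integer index; get_external_object_by_index(ind) resolves an index
    # back to the live object. Because the getter is a plain function, the backend
    # can also splice it into the FX graph as a call_function node (index 1 here,
    # presumably the second registered object).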

    @requires_cuda
    def test_stream_with_mutation(self):
        def fn(x, y):
@@ -523,6 +576,23 @@ class <lambda>(torch.nn.Module):
            torch.accelerator.set_stream(original_stream)
            reset_user_object_tracking()

    @requires_cuda
    def test_run_opcheck_wait_record_stream(self):
        from torch._dynamo.variables.streams import wait_stream
        from torch.library import opcheck

        s0 = torch.Stream()
        s1 = torch.Stream()
        s2 = torch.Stream()
        store_user_object_weakrefs(s0, s1, s2)

        sample_inputs = [
            (0, 1),
            (2, 0),
        ]
        for args in sample_inputs:
            opcheck(wait_stream, args)
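
        # opcheck runs the standard custom-op test suite (schema, fake-tensor, and
        # autograd-registration checks) against the wait_stream op. The integer
        # sample inputs are side-table indices into the streams stored above, so
        # (0, 1) presumably makes s0 wait on s1, and (2, 0) makes s2 wait on s0.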

    @requires_cuda
    def test_inductor_lowering(self):
        with patch("torch._inductor.config.implicit_fallbacks", False):