Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-27 09:04:53 +08:00)
Compare commits
132 Commits
ciflow/tru
...
document
| SHA1 | Author | Date | |
|---|---|---|---|
| 778d522b96 | |||
| 50c338c2da | |||
| 3faee20067 | |||
| cafca357fb | |||
| 1e35b3c4e0 | |||
| f363114852 | |||
| 0ec0120b19 | |||
| 8360f34c36 | |||
| 370b1c12d2 | |||
| 6fd1ca28e1 | |||
| 0055f07997 | |||
| 4f8a986b8f | |||
| 5c3fe9fb30 | |||
| 306b344a18 | |||
| 94e634942a | |||
| a4925c0ce0 | |||
| d16627f4d0 | |||
| 8f78999d77 | |||
| 7cddda1234 | |||
| 98b53961b9 | |||
| a3eb275d3c | |||
| 6f31406723 | |||
| f2ae7084eb | |||
| 12d7cc5cd3 | |||
| a2e2e1d8c0 | |||
| b67785d9eb | |||
| 4cd06dc82c | |||
| 41936f4cf6 | |||
| dec9a59992 | |||
| f975bd58af | |||
| af42256db4 | |||
| 39161e73fc | |||
| 3ed90f5a09 | |||
| d41aa187ec | |||
| 8b2137e74a | |||
| a70ef954b9 | |||
| 01a2812f48 | |||
| 3f27100d3e | |||
| 253fd765bd | |||
| abb2f7179e | |||
| b57ab9a3f2 | |||
| fb64da0791 | |||
| 10a9fb641b | |||
| 9420944033 | |||
| 55f01a48af | |||
| 68913d8f2a | |||
| b8be796a57 | |||
| 238dd5517d | |||
| d272ed4b3e | |||
| 70925bdf82 | |||
| 960b0d5f0d | |||
| e0abcee3b5 | |||
| 77bf23d85c | |||
| d2cb183344 | |||
| 38095fbd13 | |||
| ffc9559d9f | |||
| 172d6ed8b8 | |||
| 9a3c4b917e | |||
| df514a6d5a | |||
| 48fe858fef | |||
| 7ab00c7c17 | |||
| 44b1ff54e9 | |||
| daea35df5c | |||
| 7f2a902ea2 | |||
| 9c057d9863 | |||
| 938869e7d3 | |||
| ce6b589545 | |||
| ae25dd51fc | |||
| a61d0de9f9 | |||
| 3ad88924ad | |||
| 3241b9c15f | |||
| 25d4d5107e | |||
| e4fe811be8 | |||
| 82c71af59a | |||
| 7bd704a346 | |||
| ae139b73e0 | |||
| cbaa07e438 | |||
| bc0e2a0d2b | |||
| 0747d95994 | |||
| 0a2cde2f06 | |||
| c7b57d9349 | |||
| 7614338b69 | |||
| a6fa4f9c28 | |||
| 344e6365a0 | |||
| a3c700656f | |||
| 600db525bd | |||
| f6de195616 | |||
| 4a0df39f81 | |||
| 34ac9b61cb | |||
| 9aa92f246f | |||
| a57a14868d | |||
| 47956196d9 | |||
| 6d27a8e509 | |||
| cd62a73dcb | |||
| 4d7f9f3aed | |||
| 2b9ff99535 | |||
| 98a081a24c | |||
| 6c0125dbc0 | |||
| 0fd976b65c | |||
| 9944cac6e6 | |||
| e7fd296930 | |||
| fac85fcfb5 | |||
| 228973df7f | |||
| ed2d514ad8 | |||
| a2f29bcd63 | |||
| 5390324984 | |||
| ae25ec569c | |||
| 8e1f409b8c | |||
| ee6a1ecb0a | |||
| 3c0577bd15 | |||
| 688efd9741 | |||
| 91040f4934 | |||
| 87eccf10e8 | |||
| 5d459dd609 | |||
| 24d69c57cb | |||
| eaa02655ea | |||
| aea57b3aa3 | |||
| 3d1fa40ae1 | |||
| a7fa1a91e3 | |||
| afeec56a5a | |||
| 724463d5a2 | |||
| f79e212733 | |||
| b28b24a9fc | |||
| 17c7170ca6 | |||
| 6a7f5c0d21 | |||
| 512b6b59f0 | |||
| bc1690c7e8 | |||
| 53f5af8c92 | |||
| 4412026949 | |||
| 06d86e58d0 | |||
| 874efa2d72 | |||
| e09fb44ef1 |
@@ -181,7 +181,7 @@ case "$tag" in
 KATEX=yes
 UCX_COMMIT=${_UCX_COMMIT}
 UCC_COMMIT=${_UCC_COMMIT}
-PYTORCH_ROCM_ARCH="gfx90a;gfx942;gfx950"
+PYTORCH_ROCM_ARCH="gfx90a;gfx942;gfx950;gfx1100"
 if [[ $tag =~ "benchmarks" ]]; then
   INDUCTOR_BENCHMARKS=yes
 fi
@@ -344,7 +344,7 @@ docker build \
 --build-arg "NINJA_VERSION=${NINJA_VERSION:-}" \
 --build-arg "KATEX=${KATEX:-}" \
 --build-arg "ROCM_VERSION=${ROCM_VERSION:-}" \
---build-arg "PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH:-gfx90a;gfx942;gfx1100}" \
+--build-arg "PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}" \
 --build-arg "IMAGE_NAME=${IMAGE_NAME}" \
 --build-arg "UCX_COMMIT=${UCX_COMMIT}" \
 --build-arg "UCC_COMMIT=${UCC_COMMIT}" \
@@ -10,11 +10,6 @@ BAD_SSL = "https://self-signed.badssl.com"

 print("Testing SSL certificate checking for Python:", sys.version)

-if sys.version_info[:2] < (2, 7) or sys.version_info[:2] < (3, 4):
-    print("This version never checks SSL certs; skipping tests")
-    sys.exit(0)

 EXC = OSError

 print(f"Connecting to {GOOD_SSL} should work")
@@ -233,7 +233,9 @@ if [[ "${BUILD_ENVIRONMENT}" != *cuda* ]]; then
   export BUILD_STATIC_RUNTIME_BENCHMARK=ON
 fi

-if [[ "$BUILD_ENVIRONMENT" == *-debug* ]]; then
+if [[ "$BUILD_ENVIRONMENT" == *-full-debug* ]]; then
+  export CMAKE_BUILD_TYPE=Debug
+elif [[ "$BUILD_ENVIRONMENT" == *-debug* ]]; then
   export CMAKE_BUILD_TYPE=RelWithAssert
 fi

@@ -299,6 +301,11 @@ else
   python -m build --wheel --no-isolation
 fi
 pip_install_whl "$(echo dist/*.whl)"
+if [[ "$BUILD_ENVIRONMENT" == *full-debug* ]]; then
+  # Regression test for https://github.com/pytorch/pytorch/issues/164297
+  # Torch should be importable and that's about it
+  pushd /; python -c "import torch;print(torch.__config__.show(), torch.randn(5) + 1.7)"; popd
+fi

 if [[ "${BUILD_ADDITIONAL_PACKAGES:-}" == *vision* ]]; then
   install_torchvision
@@ -337,13 +337,13 @@ test_python() {

 test_python_smoke() {
   # Smoke tests for H100/B200
-  time python test/run_test.py --include test_matmul_cuda inductor/test_fp8 inductor/test_max_autotune $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
+  time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
   assert_git_not_dirty
 }

 test_python_smoke_b200() {
   # Targeted smoke tests for B200 - staged approach to avoid too many failures
-  time python test/run_test.py --include test_matmul_cuda inductor/test_fp8 $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
+  time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
   assert_git_not_dirty
 }
@@ -15,37 +15,35 @@ if errorlevel 1 exit /b 1
if not errorlevel 0 exit /b 1

cd %TMP_DIR_WIN%\build\torch\test

:: Enable delayed variable expansion to make the list
setlocal enabledelayedexpansion
set EXE_LIST=
for /r "." %%a in (*.exe) do (
    call :libtorch_check "%%~na" "%%~fa"
    if "%%~na" == "c10_intrusive_ptr_benchmark" (
        @REM NB: This is not a gtest executable file, thus couldn't be handled by
        @REM pytest-cpp and is excluded from test discovery by run_test
        call "%%~fa"
        if errorlevel 1 goto fail
        if not errorlevel 0 goto fail
    ) else (
        if "%%~na" == "verify_api_visibility" (
            @REM Skip verify_api_visibility as it is a compile-level test
        ) else (
            set EXE_LIST=!EXE_LIST! cpp/%%~na
        )
    )
)

goto :eof

:libtorch_check

cd %CWD%
set CPP_TESTS_DIR=%TMP_DIR_WIN%\build\torch\test

:: Skip verify_api_visibility as it a compile level test
if "%~1" == "verify_api_visibility" goto :eof
:: Run python test\run_test.py on the list
set NO_TD=True && python test\run_test.py --cpp --verbose -i !EXE_LIST!
if errorlevel 1 goto fail
if not errorlevel 0 goto fail

echo Running "%~2"
if "%~1" == "c10_intrusive_ptr_benchmark" (
    :: NB: This is not a gtest executable file, thus couldn't be handled by pytest-cpp
    call "%~2"
    goto :eof
)

python test\run_test.py --cpp --verbose -i "cpp/%~1"
if errorlevel 1 (
    echo %1 failed with exit code %errorlevel%
    goto fail
)
if not errorlevel 0 (
    echo %1 failed with exit code %errorlevel%
    goto fail
)
goto :eof

:eof
exit /b 0
.flake8 (2 changed lines)

@@ -12,7 +12,7 @@ ignore =
     # to line this up with executable bit
     EXE001,
     # these ignores are from flake8-bugbear; please fix!
-    B007,B008,B017,B019,B023,B028,B903,B904,B905,B906,B907,B908,B910
+    B007,B008,B017,B019,B023,B028,B903,B905,B906,B907,B908,B910
     # these ignores are from flake8-comprehensions; please fix!
     C407,
     # these ignores are from flake8-logging-format; please fix!
.github/actions/linux-test/action.yml (vendored, 2 changed lines)

@@ -274,8 +274,6 @@ runs:
           -w /var/lib/jenkins/workspace \
           "${DOCKER_IMAGE}"
         )
-        # Propagate download.pytorch.org IP to container
-        grep download.pytorch.org /etc/hosts | docker exec -i "${container_name}" sudo bash -c "/bin/cat >> /etc/hosts"
         echo "DOCKER_CONTAINER_ID=${container_name}" >> "${GITHUB_ENV}"
         docker exec -t "${container_name}" sh -c "pip install $(echo dist/*.whl)[opt-einsum] && ${TEST_COMMAND}"
.github/actions/setup-rocm/action.yml (vendored, 13 changed lines)

@@ -111,3 +111,16 @@ runs:
       # This video group ID maps to subgid 1 inside the docker image due to the /etc/subgid entries.
       # The group name corresponding to group ID 1 can change depending on the OS, so both are necessary.
       echo "GPU_FLAG=--device=/dev/mem --device=/dev/kfd $DEVICE_FLAG --group-add video --group-add $render_gid --group-add daemon --group-add bin --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --network=host" >> "${GITHUB_ENV}"
+
+  - name: configure aws credentials
+    id: aws_creds
+    uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0
+    with:
+      role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
+      aws-region: us-east-1
+      role-duration-seconds: 18000
+
+  - name: Login to Amazon ECR
+    id: login-ecr
+    continue-on-error: true
+    uses: aws-actions/amazon-ecr-login@062b18b96a7aff071d4dc91bc00c4c1a7945b076 # v2.0.1
@@ -33,10 +33,6 @@ runs:
       )

       echo "CONTAINER_NAME=${container_name}" >> "$GITHUB_ENV"
-      if [[ "${GPU_ARCH_TYPE}" != "rocm" && "${BUILD_ENVIRONMENT}" != "linux-aarch64-binary-manywheel" && "${BUILD_ENVIRONMENT}" != "linux-s390x-binary-manywheel" && "${GPU_ARCH_TYPE}" != "xpu" ]]; then
-        # Propagate download.pytorch.org IP to container. This is only needed on Linux non aarch64 runner
-        grep download.pytorch.org /etc/hosts | docker exec -i "${container_name}" bash -c "/bin/cat >> /etc/hosts"
-      fi

       docker exec -t -w "${PYTORCH_ROOT}" "${container_name}" bash -c "bash .circleci/scripts/binary_populate_env.sh"
       # Generate test script
.github/pytorch-probot.yml (vendored, 1 changed line)

@@ -30,6 +30,7 @@ ciflow_push_tags:
 - ciflow/riscv64
 - ciflow/rocm
 - ciflow/rocm-mi300
+- ciflow/rocm-mi355
 - ciflow/s390
 - ciflow/slow
 - ciflow/torchbench
@@ -177,6 +177,9 @@ jobs:
   runs-on: linux.rocm.gpu.mi250
   timeout-minutes: !{{ common.timeout_minutes }}
   !{{ upload.binary_env(config) }}
+  permissions:
+    id-token: write
+    contents: read
   steps:
     - name: Setup ROCm
       uses: ./.github/actions/setup-rocm
.github/workflows/_linux-test.yml (vendored, 2 changed lines)

@@ -389,8 +389,6 @@ jobs:
           "${DOCKER_IMAGE}" \
           ${DOCKER_SHELL_CMD}
         )
-        # Propagate download.pytorch.org IP to container
-        grep download.pytorch.org /etc/hosts | docker exec -i "${container_name}" sudo bash -c "/bin/cat >> /etc/hosts"
         echo "DOCKER_CONTAINER_ID=${container_name}" >> "${GITHUB_ENV}"

         if [[ ${BUILD_ENVIRONMENT} == *"s390x"* ]]; then
.github/workflows/_rocm-test.yml (vendored, 13 changed lines)

@@ -102,19 +102,6 @@ jobs:
           exit 1
         fi

-      - name: configure aws credentials
-        id: aws_creds
-        uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0
-        with:
-          role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
-          aws-region: us-east-1
-          role-duration-seconds: 18000
-
-      - name: Login to Amazon ECR
-        id: login-ecr
-        continue-on-error: true
-        uses: aws-actions/amazon-ecr-login@062b18b96a7aff071d4dc91bc00c4c1a7945b076 # v2.0.1
-
       - name: Calculate docker image
         id: calculate-docker-image
         uses: pytorch/test-infra/.github/actions/calculate-docker-image@main
.github/workflows/generated-linux-binary-libtorch-nightly.yml (generated, vendored, 6 changed lines)

@@ -358,6 +358,9 @@ jobs:
   DOCKER_IMAGE_TAG_PREFIX: rocm6.4
   LIBTORCH_CONFIG: release
   LIBTORCH_VARIANT: shared-with-deps
+  permissions:
+    id-token: write
+    contents: read
   steps:
     - name: Setup ROCm
       uses: ./.github/actions/setup-rocm
@@ -473,6 +476,9 @@ jobs:
   DOCKER_IMAGE_TAG_PREFIX: rocm7.0
   LIBTORCH_CONFIG: release
   LIBTORCH_VARIANT: shared-with-deps
+  permissions:
+    id-token: write
+    contents: read
   steps:
     - name: Setup ROCm
       uses: ./.github/actions/setup-rocm
.github/workflows/generated-linux-binary-manywheel-nightly.yml (generated, vendored, 42 changed lines)

Fourteen hunks (at old lines -347, -459, -941, -1053, -1535, -1647, -2129, -2241, -2723, -2835, -3317, -3429, -3911 and -4023) apply the same change: each ROCm test job on DOCKER_IMAGE: manylinux2_28-builder gains a permissions block ahead of its steps. The hunks cover DOCKER_IMAGE_TAG_PREFIX rocm6.4 and rocm7.0 for DESIRED_PYTHON "3.10", "3.11", "3.12", "3.13", "3.13t", "3.14" and "3.14t". The first hunk, representative of all fourteen:

@@ -347,6 +347,9 @@ jobs:
   DOCKER_IMAGE: manylinux2_28-builder
   DOCKER_IMAGE_TAG_PREFIX: rocm6.4
   DESIRED_PYTHON: "3.10"
+  permissions:
+    id-token: write
+    contents: read
   steps:
     - name: Setup ROCm
       uses: ./.github/actions/setup-rocm
.github/workflows/h100-distributed.yml (vendored, 2 changed lines)

@@ -37,7 +37,7 @@ jobs:
   needs: get-label-type
   with:
     runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
-    runner: "linux.12xlarge"
+    runner: "linux.c7i.12xlarge"
     build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90-dist
     docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
     cuda-arch-list: '9.0'

@@ -130,7 +130,7 @@ jobs:
   name: test-periodically
   uses: ./.github/workflows/_linux-test.yml
   needs: build
-  if: github.event.schedule == '15 0,12 * * 1-6'
+  if: github.event.schedule == '15 0 * * 1-6'
   with:
     build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm90
     dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-cudagraphs_low_precision-true
.github/workflows/lint.yml (vendored, 7 changed lines)

@@ -12,6 +12,7 @@ on:
       - landchecks/*
   tags:
     - ciflow/pull/*
+    - ciflow/trunk/*
   workflow_dispatch:

 permissions: read-all
@@ -32,10 +33,12 @@ jobs:
     name: Get changed files
     uses: ./.github/workflows/_get-changed-files.yml
     with:
-      all_files: ${{ contains(github.event.pull_request.labels.*.name, 'lint-all-files') || contains(github.event.pull_request.labels.*.name, 'Reverted') }}
+      all_files: ${{ contains(github.event.pull_request.labels.*.name, 'lint-all-files') || contains(github.event.pull_request.labels.*.name, 'Reverted') || github.event_name == 'push' }}

   lintrunner-clang:
     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+    # Needed to prevent deduping on HUD
+    name: lintrunner-clang-${{ needs.get-changed-files.outputs.changed-files == '*' && 'all' || 'partial' }}
     needs: [get-label-type, get-changed-files]
     # Only run if there are changed files relevant to clangtidy / clangformat
     if: |
@@ -75,6 +78,7 @@ jobs:
   # fails to find types when it should
   lintrunner-mypy:
     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+    name: lintrunner-mypy-${{ needs.get-changed-files.outputs.changed-files == '*' && 'all' || 'partial' }}
     needs: [get-label-type, get-changed-files]
     # Only run if there are changed files relevant to mypy
     if: |
@@ -99,6 +103,7 @@ jobs:

   lintrunner-noclang:
     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+    name: lintrunner-noclang-${{ needs.get-changed-files.outputs.changed-files == '*' && 'all' || 'partial' }}
     needs: [get-label-type, get-changed-files]
     with:
       timeout: 120
.github/workflows/periodic.yml (vendored, 10 changed lines)

@@ -182,11 +182,11 @@ jobs:
     docker-image-name: ci-image:pytorch-linux-jammy-cuda13.0-cudnn9-py3-gcc11
     test-matrix: |
       { include: [
-        { config: "nogpu_AVX512", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.4xlarge.nvidia.gpu" },
-        { config: "nogpu_AVX512", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.4xlarge.nvidia.gpu" },
-        { config: "nogpu_AVX512", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.4xlarge.nvidia.gpu" },
-        { config: "nogpu_NO_AVX2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.4xlarge.nvidia.gpu" },
-        { config: "nogpu_NO_AVX2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.4xlarge.nvidia.gpu" },
+        { config: "nogpu_AVX512", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
+        { config: "nogpu_AVX512", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
+        { config: "nogpu_AVX512", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
+        { config: "nogpu_NO_AVX2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
+        { config: "nogpu_NO_AVX2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
         { config: "jit_legacy", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.4xlarge.nvidia.gpu" },
       ]}
   secrets: inherit
.github/workflows/pull.yml (vendored, 1 changed line)

@@ -127,6 +127,7 @@ jobs:
     uses: ./.github/workflows/_linux-build.yml
     needs: get-label-type
     with:
+      runner: linux.2xlarge.memory
       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
       build-environment: linux-jammy-py3.10-clang18-asan
       docker-image-name: ci-image:pytorch-linux-jammy-py3-clang18-asan
.github/workflows/rocm-mi355.yml (vendored, 7 changed lines)

@@ -1,6 +1,9 @@
 name: rocm-mi355

 on:
+  push:
+    tags:
+      - ciflow/rocm-mi355/*
   workflow_dispatch:
   schedule:
     - cron: 30 11,1 * * * # about 4:30am PDT and 6:30pm PDT
@@ -64,5 +67,7 @@ jobs:
     build-environment: linux-noble-rocm-py3.12-mi355
     docker-image: ${{ needs.linux-noble-rocm-py3_12-build.outputs.docker-image }}
     test-matrix: ${{ needs.linux-noble-rocm-py3_12-build.outputs.test-matrix }}
-    tests-to-include: "test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs test_autograd inductor/test_torchinductor"
+    tests-to-include: >-
+      ${{ github.event_name == 'schedule' && 'test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs test_autograd inductor/test_torchinductor test_matmul_cuda test_scaled_matmul_cuda'
+      || '' }}
   secrets: inherit
.github/workflows/slow.yml (vendored, 1 changed line)

@@ -140,6 +140,7 @@ jobs:
     uses: ./.github/workflows/_linux-build.yml
     needs: get-label-type
     with:
+      runner: linux.2xlarge.memory
       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
       build-environment: linux-jammy-py3.10-clang18-asan
       docker-image-name: ci-image:pytorch-linux-jammy-py3-clang18-asan
.github/workflows/trunk.yml (vendored, 13 changed lines)

@@ -56,7 +56,7 @@ jobs:
     docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
     build-generates-artifacts: false
     runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
-    runner: "linux.4xlarge"
+    runner: "linux.c7i.4xlarge"
     test-matrix: |
       { include: [
         { config: "default", shard: 1, num_shards: 1 },
@@ -249,3 +249,14 @@ jobs:
       docker-image: ${{ needs.linux-jammy-py3-clang12-executorch-build.outputs.docker-image }}
       test-matrix: ${{ needs.linux-jammy-py3-clang12-executorch-build.outputs.test-matrix }}
     secrets: inherit
+
+  linux-jammy-py3_10-gcc11-full-debug-build-only:
+    name: linux-jammy-py3.10-gcc11-full-debug-build-only
+    uses: ./.github/workflows/_linux-build.yml
+    needs: get-label-type
+    with:
+      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
+      runner: linux.2xlarge.memory
+      build-environment: linux-jammy-py3.10-gcc11-full-debug-build-only
+      docker-image-name: ci-image:pytorch-linux-jammy-py3.10-gcc11
+    secrets: inherit
.github/workflows/xpu.yml (vendored, 4 changed lines)

@@ -35,7 +35,7 @@ jobs:
     runner_prefix: ${{ needs.get-label-type.outputs.label-type }}
     build-environment: linux-jammy-xpu-n-1-py3.10
     docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-1-py3
-    runner: linux.12xlarge
+    runner: linux.c7i.12xlarge
     test-matrix: |
       { include: [
         { config: "default", shard: 1, num_shards: 6, runner: "linux.idc.xpu" },
@@ -56,7 +56,7 @@ jobs:
     runner_prefix: ${{ needs.get-label-type.outputs.label-type }}
     build-environment: linux-jammy-xpu-n-py3.10
     docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-py3
-    runner: linux.12xlarge
+    runner: linux.c7i.12xlarge
     test-matrix: |
       { include: [
         { config: "default", shard: 1, num_shards: 8, runner: "linux.idc.xpu" },
@@ -388,9 +388,9 @@ cmake_dependent_option(USE_PRIORITIZED_TEXT_FOR_LD "Use prioritized text linker

 option(USE_MIMALLOC "Use mimalloc" OFF)
 # Enable third party mimalloc library to improve memory allocation performance
-# on Windows.
+# on Windows and AArch64.
 option(USE_MIMALLOC_ON_MKL "Use mimalloc on MKL" OFF)
-if(WIN32)
+if(WIN32 OR (CPU_AARCH64 AND NOT APPLE))
   set(USE_MIMALLOC ON)

   # Not enable USE_MIMALLOC_ON_MKL due to it caused issue:
@@ -28,4 +28,19 @@ inline std::ostream& operator<<(std::ostream& stream, at::BlasBackend backend) {
   return stream << BlasBackendToString(backend);
 }

+namespace blas {
+
+enum class ScalingType : std::uint8_t {
+  TensorWise, // fp32 scales
+  RowWise, // fp32 scales
+  BlockWise1x16, // fp8_e4m3fn scales
+  BlockWise1x32, // fp8_e8m0fnu scales
+  BlockWise1x128, // fp32 scales
+  BlockWise128x128, // fp32 scales
+};
+
+enum class SwizzleType : std::uint8_t { NO_SWIZZLE = 0, SWIZZLE_32_4_4 = 1 };
+
+} // namespace blas
+
 } // namespace at
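The comments on the new at::blas::ScalingType values pin down which scale element type each scaling mode expects. Below is a minimal, self-contained C++ sketch of that mapping; the mirrored enum and the expected_scale_dtype helper are illustrative only and not part of the ATen API.

```cpp
#include <cstdint>
#include <string>

// Free-standing mirror of at::blas::ScalingType, used only for illustration.
enum class ScalingType : std::uint8_t {
  TensorWise,
  RowWise,
  BlockWise1x16,
  BlockWise1x32,
  BlockWise1x128,
  BlockWise128x128,
};

// Returns the scale element type each mode expects, per the comments in the
// hunk above: the 1x16 and 1x32 block layouts carry fp8 scales, the rest fp32.
std::string expected_scale_dtype(ScalingType t) {
  switch (t) {
    case ScalingType::BlockWise1x16:
      return "fp8_e4m3fn";
    case ScalingType::BlockWise1x32:
      return "fp8_e8m0fnu";
    default:
      return "fp32";
  }
}
```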
@@ -16,8 +16,8 @@ inline void check_size_nonnegative(ArrayRef<int64_t> size) {

 inline void check_size_nonnegative(ArrayRef<c10::SymInt> size) {
   for (const auto& x : size) {
-    TORCH_CHECK(
-        x.expect_size(__FILE__, __LINE__),
+    TORCH_SYM_CHECK(
+        x.sym_ge(0),
         "Trying to create tensor with negative dimension ",
         x,
         ": ",
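The hunk keeps the same guard (no negative dimensions) but expresses it through the symbolic comparison x.sym_ge(0) under TORCH_SYM_CHECK. For readers without the ATen context, here is a stand-alone sketch of the concrete-integer version of that check; the function name and the use of std::invalid_argument are assumptions for illustration, not the ATen helper.

```cpp
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

// Concrete-shape analogue of check_size_nonnegative: reject any negative
// dimension before a buffer of that shape is allocated.
void check_size_nonnegative_concrete(const std::vector<int64_t>& size) {
  for (int64_t x : size) {
    if (x < 0) {
      throw std::invalid_argument(
          "Trying to create tensor with negative dimension " +
          std::to_string(x));
    }
  }
}
```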
@@ -4,6 +4,7 @@
 #include <c10/core/ScalarType.h>
 #include <c10/core/SymIntArrayRef.h>
 #include <c10/util/DimVector.h>
+#include <c10/util/Exception.h>
 #include <optional>
 #include <sstream>
 #include <vector>
@@ -26,9 +27,7 @@ inline void infer_size_impl(
   std::optional<int64_t> infer_dim;
   for (int64_t dim = 0, ndim = shape.size(); dim != ndim; dim++) {
     if (TORCH_GUARD_OR_FALSE(sym_eq(shape[dim], -1))) {
-      if (infer_dim) {
-        throw std::runtime_error("only one dimension can be inferred");
-      }
+      TORCH_CHECK(!infer_dim, "only one dimension can be inferred");
       infer_dim = dim;
     } else {
       // in case of unbacked shape[dim] we assume it's not -1 and add a runtime
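For context, infer_size_impl resolves a single -1 placeholder in a requested shape from the total element count; the hunk only changes how the "more than one -1" error is raised. A self-contained sketch of that inference for concrete (non-symbolic) shapes follows; it illustrates the rule, it is not the ATen implementation, and the error strings are assumptions.

```cpp
#include <cstdint>
#include <optional>
#include <stdexcept>
#include <vector>

// Resolve a single -1 entry in `shape` so the total element count matches `numel`.
std::vector<int64_t> infer_concrete_size(std::vector<int64_t> shape, int64_t numel) {
  std::optional<size_t> infer_dim;
  int64_t known = 1;
  for (size_t d = 0; d < shape.size(); ++d) {
    if (shape[d] == -1) {
      // At most one dimension may be left for inference.
      if (infer_dim) {
        throw std::runtime_error("only one dimension can be inferred");
      }
      infer_dim = d;
    } else {
      known *= shape[d];
    }
  }
  if (infer_dim) {
    // The product of the known dimensions must evenly divide numel.
    if (known == 0 || numel % known != 0) {
      throw std::runtime_error("shape is invalid for the given element count");
    }
    shape[*infer_dim] = numel / known;
  } else if (known != numel) {
    throw std::runtime_error("shape is invalid for the given element count");
  }
  return shape;
}
```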
@@ -103,9 +103,7 @@ std::string get_cpu_capability() {
#elif defined(HAVE_ZVECTOR_CPU_DEFINITION)
    case native::CPUCapability::ZVECTOR:
      return "Z VECTOR";
#elif defined(HAVE_SVE_CPU_DEFINITION) && defined(HAVE_ARM_BF16_CPU_DEFINITION)
    case native::CPUCapability::SVE128:
      return "SVE128";
#elif defined(HAVE_SVE256_CPU_DEFINITION) && defined(HAVE_ARM_BF16_CPU_DEFINITION)
    case native::CPUCapability::SVE256:
      return "SVE256";
#else
@ -102,31 +102,8 @@ struct VecReduceAllSIMD<float, Op> {
|
||||
#endif // defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) &&
|
||||
// !defined(C10_MOBILE)
|
||||
|
||||
#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__)
|
||||
#if defined(CPU_CAPABILITY_SVE256)
|
||||
template <typename Op>
|
||||
struct VecReduceAllSIMD<float, Op> {
|
||||
static inline float apply(
|
||||
const Op& vec_fun,
|
||||
const Vectorized<float>& acc_vec) {
|
||||
using Vec = Vectorized<float>;
|
||||
Vec v = acc_vec;
|
||||
// 128-bit shuffle
|
||||
svuint32_t ind = svdupq_n_u32(4, 5, 6, 7);
|
||||
Vec v1 = svtbl_f32(v, ind);
|
||||
v = vec_fun(v, v1);
|
||||
// 64-bit shuffle
|
||||
ind = svdupq_n_u32(2, 3, 0, 1);
|
||||
v1 = svtbl_f32(v, ind);
|
||||
v = vec_fun(v, v1);
|
||||
// 32-bit shuffle
|
||||
ind = svdupq_n_u32(1, 0, 2, 3);
|
||||
v1 = svtbl_f32(v, ind);
|
||||
v = vec_fun(v, v1);
|
||||
return svlasta(svpfalse(), v);
|
||||
}
|
||||
};
|
||||
#else
|
||||
#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && \
|
||||
!defined(CPU_CAPABILITY_SVE)
|
||||
template <typename Op>
|
||||
struct VecReduceAllSIMD<float, Op> {
|
||||
static inline float apply(
|
||||
@ -163,8 +140,35 @@ struct VecReduceAllSIMD<float, std::plus<Vectorized<float>>> {
|
||||
return vaddvq_f32(acc_vec);
|
||||
}
|
||||
};
|
||||
#endif // defined(CPU_CAPABILITY_SVE256)
|
||||
#endif // defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__)
|
||||
// && !defined(CPU_CAPABILITY_SVE)
|
||||
|
||||
#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && \
|
||||
defined(CPU_CAPABILITY_SVE256)
|
||||
template <typename Op>
|
||||
struct VecReduceAllSIMD<float, Op> {
|
||||
static inline float apply(
|
||||
const Op& vec_fun,
|
||||
const Vectorized<float>& acc_vec) {
|
||||
using Vec = Vectorized<float>;
|
||||
Vec v = acc_vec;
|
||||
// 128-bit shuffle
|
||||
svuint32_t ind = svdupq_n_u32(4, 5, 6, 7);
|
||||
Vec v1 = svtbl_f32(v, ind);
|
||||
v = vec_fun(v, v1);
|
||||
// 64-bit shuffle
|
||||
ind = svdupq_n_u32(2, 3, 0, 1);
|
||||
v1 = svtbl_f32(v, ind);
|
||||
v = vec_fun(v, v1);
|
||||
// 32-bit shuffle
|
||||
ind = svdupq_n_u32(1, 0, 2, 3);
|
||||
v1 = svtbl_f32(v, ind);
|
||||
v = vec_fun(v, v1);
|
||||
return svlasta(svpfalse(), v);
|
||||
}
|
||||
};
|
||||
#endif // defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__)
|
||||
// && defined(CPU_CAPABILITY_SVE256)
|
||||
|
||||
template <typename scalar_t, typename Op>
|
||||
inline scalar_t vec_reduce_all(
|
||||
|
||||
@ -1,21 +1,9 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/cpu/vec/intrinsics.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <cstdint>
|
||||
|
||||
#include <ATen/cpu/vec/vec_base.h>
|
||||
|
||||
#if defined(__aarch64__) && \
|
||||
(defined(AT_BUILD_ARM_VEC256_WITH_SLEEF) || \
|
||||
defined(AT_BUILD_ARM_VECSVE_WITH_SLEEF))
|
||||
#define SLEEF_STATIC_LIBS
|
||||
#include <sleef.h>
|
||||
#define USE_SLEEF(sleef_code, non_sleef_code) sleef_code
|
||||
#else
|
||||
#define USE_SLEEF(sleef_code, non_sleef_code) non_sleef_code
|
||||
#endif
|
||||
|
||||
#if defined(CPU_CAPABILITY_SVE)
|
||||
|
||||
// Define the data type of VLS(vector-length specific).
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
|
||||
#include <ATen/cpu/vec/intrinsics.h>
|
||||
#include <ATen/cpu/vec/sve/sve_helper.h>
|
||||
#include <ATen/cpu/vec/sve/vec_common_sve.h>
|
||||
#include <ATen/cpu/vec/sve/vec_float.h>
|
||||
#include <ATen/cpu/vec/vec_base.h>
|
||||
#include <c10/util/bit_cast.h>
|
||||
@ -307,8 +308,8 @@ Vectorized<c10::BFloat16> inline operator/(
|
||||
}
|
||||
|
||||
inline Vectorized<BFloat16>::Vectorized() {
|
||||
const short zero = 0;
|
||||
values = svdup_n_bf16(c10::bit_cast<bfloat16_t>(zero));
|
||||
auto vals_f = svdup_n_f32(0);
|
||||
values = convert_float_bfloat16(vals_f, vals_f);
|
||||
}
|
||||
|
||||
inline Vectorized<BFloat16>::Vectorized(int val) {
|
||||
|
||||
@ -8,48 +8,13 @@
|
||||
#include <ATen/cpu/vec/sve/sve_helper.h>
|
||||
#include <ATen/cpu/vec/vec_base.h>
|
||||
|
||||
#ifdef CPU_CAPABILITY_SVE128
|
||||
|
||||
#include <ATen/cpu/vec/vec128/vec128_float_neon.h>
|
||||
|
||||
#include <ATen/cpu/vec/vec128/vec128_bfloat16_neon.h>
|
||||
|
||||
#include <ATen/cpu/vec/vec128/vec128_half_neon.h>
|
||||
|
||||
#include <ATen/cpu/vec/vec128/vec128_convert.h>
|
||||
|
||||
#include <ATen/cpu/vec/sve/vec_qint.h>
|
||||
|
||||
#elif defined(CPU_CAPABILITY_SVE)
|
||||
|
||||
#include <ATen/cpu/vec/sve/vec_float.h>
|
||||
|
||||
#if defined(CPU_CAPABILITY_SVE)
|
||||
#include <ATen/cpu/vec/sve/vec_bfloat16.h>
|
||||
|
||||
#include <ATen/cpu/vec/sve/vec_double.h>
|
||||
#include <ATen/cpu/vec/sve/vec_float.h>
|
||||
#include <ATen/cpu/vec/sve/vec_int.h>
|
||||
|
||||
#include <ATen/cpu/vec/sve/vec_qint.h>
|
||||
|
||||
#include <ATen/cpu/vec/vec256/vec256_half.h>
|
||||
|
||||
#include <ATen/cpu/vec/vec256/vec256_convert.h>
|
||||
|
||||
#else // NEON
|
||||
|
||||
#include <ATen/cpu/vec/vec128/vec128_float_neon.h>
|
||||
|
||||
#include <ATen/cpu/vec/vec128/vec128_half_neon.h>
|
||||
|
||||
#include <ATen/cpu/vec/vec128/vec128_bfloat16_neon.h>
|
||||
|
||||
#include <ATen/cpu/vec/vec128/vec128_convert.h>
|
||||
|
||||
#include <ATen/cpu/vec/vec256/vec256_qint.h>
|
||||
|
||||
#endif // defined(CPU_CAPABILITY_SVE128)
|
||||
|
||||
#include <ATen/cpu/vec/functional.h>
|
||||
#endif
|
||||
|
||||
namespace at::vec {
|
||||
// Note [CPU_CAPABILITY namespace]
|
||||
@ -83,6 +48,12 @@ DEFINE_SVE_CAST(int32_t, s32, float, f32)
|
||||
DEFINE_SVE_CAST(int16_t, s16, float, f32)
|
||||
DEFINE_SVE_CAST(float, f32, double, f64)
|
||||
|
||||
#ifdef __ARM_FEATURE_BF16
|
||||
DEFINE_SVE_CAST(int64_t, s64, c10::BFloat16, bf16)
|
||||
DEFINE_SVE_CAST(int32_t, s32, c10::BFloat16, bf16)
|
||||
DEFINE_SVE_CAST(int16_t, s16, c10::BFloat16, bf16)
|
||||
#endif // __ARM_FEATURE_BF16
|
||||
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
template <int64_t scale = 1>
|
||||
@ -202,11 +173,9 @@ std::pair<
|
||||
// group cols crossing lanes:
|
||||
// return {a0, b0, a1, b1, a2, b2, a3, b3}
|
||||
// {a4, b4, a5, b5, a6, b6, a7, b7}
|
||||
svbfloat16_t aReg = a;
|
||||
svbfloat16_t bReg = b;
|
||||
Vectorized<c10::BFloat16> c = svzip1_bf16(aReg, bReg);
|
||||
Vectorized<c10::BFloat16> d = svzip2_bf16(aReg, bReg);
|
||||
return std::make_pair(c, d);
|
||||
return std::make_pair(
|
||||
Vectorized<c10::BFloat16>(svzip1_bf16(a, b)),
|
||||
Vectorized<c10::BFloat16>(svzip2_bf16(a, b)));
|
||||
}
|
||||
#endif // __ARM_FEATURE_BF16
|
||||
|
||||
@ -255,27 +224,12 @@ std::pair<
|
||||
// swap lanes:
|
||||
// return {a0, a1, a2, a3, a4, a5, a6, a7}
|
||||
// {b0, b1, b2, b3, b4, b5, b6, b7}
|
||||
svbfloat16_t aReg = a;
|
||||
svbfloat16_t bReg = b;
|
||||
Vectorized<c10::BFloat16> c = svuzp1_bf16(aReg, bReg);
|
||||
Vectorized<c10::BFloat16> d = svuzp2_bf16(aReg, bReg);
|
||||
return std::make_pair(c, d);
|
||||
return std::make_pair(
|
||||
Vectorized<c10::BFloat16>(svuzp1_bf16((svbfloat16_t)a, (svbfloat16_t)b)),
|
||||
Vectorized<c10::BFloat16>(svuzp2_bf16((svbfloat16_t)a, (svbfloat16_t)b)));
|
||||
}
|
||||
#endif // __ARM_FEATURE_BF16
|
||||
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FLIP ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
#define DEFINE_FLIP_FUNC(type, sve_func) \
|
||||
inline Vectorized<type> flip(const Vectorized<type>& v) { \
|
||||
return Vectorized<type>(sve_func(v)); \
|
||||
}
|
||||
// Use the macro to define the flip functions
|
||||
DEFINE_FLIP_FUNC(float, svrev_f32)
|
||||
DEFINE_FLIP_FUNC(double, svrev_f64)
|
||||
DEFINE_FLIP_FUNC(int64_t, svrev_s64)
|
||||
DEFINE_FLIP_FUNC(int32_t, svrev_s32)
|
||||
DEFINE_FLIP_FUNC(int16_t, svrev_s16)
|
||||
DEFINE_FLIP_FUNC(int8_t, svrev_s8)
|
||||
|
||||
#endif // defined(CPU_CAPABILITY_SVE)
|
||||
|
||||
} // namespace CPU_CAPABILITY
|
||||
@ -1,8 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#if defined(__aarch64__)
|
||||
#include <ATen/cpu/vec/vec_common_aarch64.h>
|
||||
#elif defined(CPU_CAPABILITY_AVX512)
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
#include <ATen/cpu/vec/vec512/vec512.h>
|
||||
#else
|
||||
#include <ATen/cpu/vec/vec128/vec128.h>
|
||||
@ -13,34 +11,6 @@ namespace at::vec {
|
||||
// See Note [CPU_CAPABILITY namespace]
|
||||
inline namespace CPU_CAPABILITY {
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& stream, const c10::qint32& val) {
|
||||
stream << val.val_;
|
||||
return stream;
|
||||
}
|
||||
inline std::ostream& operator<<(std::ostream& stream, const c10::qint8& val) {
|
||||
stream << static_cast<int>(val.val_);
|
||||
return stream;
|
||||
}
|
||||
inline std::ostream& operator<<(std::ostream& stream, const c10::quint8& val) {
|
||||
stream << static_cast<unsigned int>(val.val_);
|
||||
return stream;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::ostream& operator<<(std::ostream& stream, const Vectorized<T>& vec) {
|
||||
T buf[Vectorized<T>::size()];
|
||||
vec.store(buf);
|
||||
stream << "vec[";
|
||||
for (int i = 0; i != Vectorized<T>::size(); i++) {
|
||||
if (i != 0) {
|
||||
stream << ", ";
|
||||
}
|
||||
stream << buf[i];
|
||||
}
|
||||
stream << "]";
|
||||
return stream;
|
||||
}
|
||||
|
||||
inline Vectorized<bool> convert_to_bool(Vectorized<int8_t> x) {
|
||||
__at_align__ bool buffer[x.size()];
|
||||
x.ne(Vectorized<int8_t>(0)).store(buffer);
|
||||
|
||||
@ -2,7 +2,6 @@
|
||||
|
||||
// DO NOT DEFINE STATIC DATA IN THIS HEADER!
|
||||
// See Note [Do not compile initializers with AVX]
|
||||
#include <ATen/cpu/vec/sve/sve_helper.h>
|
||||
#include <ATen/cpu/vec/vec128/vec128_float_neon.h>
|
||||
#include <ATen/cpu/vec/vec128/vec128_reduced_precision_common_neon.h>
|
||||
#include <ATen/cpu/vec/vec_base.h>
|
||||
@ -263,13 +262,6 @@ class Vectorized<c10::BFloat16> : public Vectorized16<
|
||||
c10::bit_cast<at_bfloat16_t>(val6.x),
|
||||
c10::bit_cast<at_bfloat16_t>(val7.x)}) {}
|
||||
|
||||
#ifdef CPU_CAPABILITY_SVE128
|
||||
Vectorized(svbfloat16_t v) : Vectorized16(svget_neonq(v)) {}
|
||||
operator svbfloat16_t() const {
|
||||
return svset_neonq(svundef_bf16(), values);
|
||||
}
|
||||
#endif
|
||||
|
||||
static Vectorized<c10::BFloat16> blendv(
|
||||
const Vectorized<c10::BFloat16>& a,
|
||||
const Vectorized<c10::BFloat16>& b,
|
||||
@ -382,23 +374,6 @@ class Vectorized<c10::BFloat16> : public Vectorized16<
|
||||
Vectorized ge(const Vectorized& other) const;
|
||||
Vectorized lt(const Vectorized& other) const;
|
||||
Vectorized le(const Vectorized& other) const;
|
||||
|
||||
#ifdef CPU_CAPABILITY_SVE128
|
||||
|
||||
template <typename step_t>
|
||||
static Vectorized<BFloat16> arange(
|
||||
BFloat16 base = 0.f,
|
||||
step_t step = static_cast<step_t>(1)) {
|
||||
__at_align__ BFloat16 buffer[size()];
|
||||
for (int64_t i = 0; i < size(); i++) {
|
||||
buffer[i] = base + i * step;
|
||||
}
|
||||
return svget_neonq(
|
||||
svld1_bf16(ptrue, reinterpret_cast<bfloat16_t*>(buffer)));
|
||||
}
|
||||
|
||||
#endif // CPU_CAPABILITY_SVE128
|
||||
|
||||
}; // Vectorized<c10::BFloat16>
|
||||
|
||||
inline std::tuple<Vectorized<float>, Vectorized<float>> convert_bfloat16_float(
|
||||
@ -422,24 +397,6 @@ inline Vectorized<c10::BFloat16> convert_float_bfloat16(
|
||||
return Vectorized<c10::BFloat16>(at_vcombine_bf16(x1, x2));
|
||||
}
|
||||
|
||||
inline void load_fp32_from_bf16(const BFloat16* data, Vectorized<float>& out) {
|
||||
__at_align__ float values[Vectorized<float>::size()];
|
||||
for (const auto k : c10::irange(Vectorized<float>::size())) {
|
||||
values[k] = data[k];
|
||||
}
|
||||
out = Vectorized<float>::loadu(values);
|
||||
}
|
||||
|
||||
inline void load_fp32_from_bf16(
|
||||
const BFloat16* data,
|
||||
Vectorized<float>& out1,
|
||||
Vectorized<float>& out2) {
|
||||
Vectorized<BFloat16> bf16_vec = Vectorized<BFloat16>::loadu(data);
|
||||
auto floats = convert_bfloat16_float(bf16_vec);
|
||||
out1 = std::get<0>(floats);
|
||||
out2 = std::get<1>(floats);
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
Vectorized<c10::BFloat16> binary_operator_via_float(
|
||||
Op op,
|
||||
@ -622,12 +579,6 @@ Vectorized<c10::BFloat16> inline fnmsub(
|
||||
return -a * b - c;
|
||||
}
|
||||
|
||||
#else //
|
||||
|
||||
CONVERT_NON_VECTORIZED_INIT(BFloat16, bfloat16)
|
||||
|
||||
LOAD_FP32_NON_VECTORIZED_INIT(BFloat16, bf16)
|
||||
|
||||
#endif // !defined(C10_MOBILE) && defined(__aarch64__)
|
||||
|
||||
} // namespace CPU_CAPABILITY
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
|
||||
namespace at::vec {
|
||||
inline namespace CPU_CAPABILITY {
|
||||
#if defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256)
|
||||
#if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256))
|
||||
template <typename src_t>
|
||||
struct VecConvert<
|
||||
float,
|
||||
@ -60,7 +60,6 @@ struct VecConvert<float, 1, BFloat16, 1> {
|
||||
}
|
||||
};
|
||||
|
||||
#endif // defined(__aarch64__) && (!defined(CPU_CAPABILITY_SVE) ||
|
||||
// defined(CPU_CAPABILITY_SVE128))
|
||||
#endif // defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256)
|
||||
} // namespace CPU_CAPABILITY
|
||||
} // namespace at::vec
|
||||
|
||||
@ -4,10 +4,13 @@
|
||||
// See Note [Do not compile initializers with AVX]
|
||||
|
||||
#include <ATen/cpu/vec/intrinsics.h>
|
||||
#include <ATen/cpu/vec/sve/sve_helper.h>
|
||||
#include <ATen/cpu/vec/vec_base.h>
|
||||
#include <c10/util/irange.h>
|
||||
|
||||
#if defined(__aarch64__) && defined(AT_BUILD_ARM_VEC256_WITH_SLEEF)
|
||||
#include <sleef.h>
|
||||
#endif
|
||||
|
||||
// Sleef offers vectorized versions of some transcedentals
|
||||
// such as sin, cos, tan etc..
|
||||
// However for now opting for STL, since we are not building
|
||||
@ -32,6 +35,12 @@ inline namespace CPU_CAPABILITY {
|
||||
#error "Big endian is not supported."
|
||||
#endif
|
||||
|
||||
#if defined(AT_BUILD_ARM_VEC256_WITH_SLEEF)
|
||||
#define USE_SLEEF(sleef_code, non_sleef_code) sleef_code
|
||||
#else
|
||||
#define USE_SLEEF(sleef_code, non_sleef_code) non_sleef_code
|
||||
#endif
|
||||
|
||||
template <int index, bool mask_val>
|
||||
struct BlendRegs {
|
||||
static float32x4_t impl(
|
||||
@ -85,12 +94,6 @@ class Vectorized<float> {
|
||||
operator float32x4_t() const {
|
||||
return values;
|
||||
}
|
||||
#ifdef CPU_CAPABILITY_SVE128
|
||||
Vectorized(svfloat32_t v) : values(svget_neonq(v)) {}
|
||||
operator svfloat32_t() const {
|
||||
return svset_neonq(svundef_f32(), values);
|
||||
}
|
||||
#endif
|
||||
template <int64_t mask>
|
||||
static Vectorized<float> blend(
|
||||
const Vectorized<float>& a,
|
||||
|
||||
@ -4,6 +4,7 @@
|
||||
// See Note [Do not compile initializers with AVX]
|
||||
|
||||
#include <ATen/cpu/vec/intrinsics.h>
|
||||
#include <ATen/cpu/vec/vec128/vec128_convert.h>
|
||||
#include <ATen/cpu/vec/vec128/vec128_float_neon.h>
|
||||
#include <ATen/cpu/vec/vec128/vec128_reduced_precision_common_neon.h>
|
||||
#include <ATen/cpu/vec/vec_base.h>
|
||||
@ -24,6 +25,7 @@ inline namespace CPU_CAPABILITY {
|
||||
// https://bugs.llvm.org/show_bug.cgi?id=45824
|
||||
// Most likely we will do aarch32 support with inline asm.
|
||||
#if !defined(C10_MOBILE) && defined(__aarch64__)
|
||||
|
||||
#ifdef __BIG_ENDIAN__
|
||||
#error "Big endian is not supported."
|
||||
#endif
|
||||
@ -419,24 +421,6 @@ Vectorized<c10::Half> inline operator+(
|
||||
#endif
|
||||
}
|
||||
|
||||
inline void load_fp32_from_fp16(const c10::Half* data, Vectorized<float>& out) {
|
||||
__at_align__ float values[Vectorized<float>::size()];
|
||||
for (const auto k : c10::irange(Vectorized<float>::size())) {
|
||||
values[k] = data[k];
|
||||
}
|
||||
out = Vectorized<float>::loadu(values);
|
||||
}
|
||||
|
||||
inline void load_fp32_from_fp16(
|
||||
const c10::Half* data,
|
||||
Vectorized<float>& out1,
|
||||
Vectorized<float>& out2) {
|
||||
Vectorized<c10::Half> f16_vec = Vectorized<c10::Half>::loadu(data);
|
||||
auto floats = convert_half_float(f16_vec);
|
||||
out1 = std::get<0>(floats);
|
||||
out2 = std::get<1>(floats);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<c10::Half> inline operator-(
|
||||
const Vectorized<c10::Half>& a,
|
||||
@ -672,53 +656,6 @@ Vectorized<c10::Half> inline fnmsub(
|
||||
return -a * b - c;
|
||||
#endif
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
#define CONVERT_NON_VECTORIZED_INIT(type, name) \
|
||||
inline std::tuple<Vectorized<float>, Vectorized<float>> \
|
||||
convert_##name##_float(const Vectorized<type>& a) { \
|
||||
constexpr int64_t K = Vectorized<type>::size(); \
|
||||
__at_align__ float arr[K]; \
|
||||
__at_align__ type arr2[K]; \
|
||||
a.store(arr2); \
|
||||
convert(arr2, arr, K); \
|
||||
return std::make_tuple( \
|
||||
Vectorized<float>::loadu(arr), \
|
||||
Vectorized<float>::loadu(arr + Vectorized<float>::size())); \
|
||||
} \
|
||||
inline Vectorized<type> convert_float_##name( \
|
||||
const Vectorized<float>& a, const Vectorized<float>& b) { \
|
||||
constexpr int64_t K = Vectorized<type>::size(); \
|
||||
__at_align__ float arr[K]; \
|
||||
__at_align__ type arr2[K]; \
|
||||
a.store(arr); \
|
||||
b.store(arr + Vectorized<float>::size()); \
|
||||
convert(arr, arr2, K); \
|
||||
return Vectorized<type>::loadu(arr2); \
|
||||
}
|
||||
|
||||
#define LOAD_FP32_NON_VECTORIZED_INIT(type, name) \
|
||||
inline void load_fp32_from_##name( \
|
||||
const type* data, Vectorized<float>& out) { \
|
||||
__at_align__ float values[Vectorized<float>::size()]; \
|
||||
for (const auto k : c10::irange(Vectorized<float>::size())) { \
|
||||
values[k] = data[k]; \
|
||||
} \
|
||||
out = Vectorized<float>::loadu(values); \
|
||||
} \
|
||||
\
|
||||
inline void load_fp32_from_##name( \
|
||||
const type* data, Vectorized<float>& out1, Vectorized<float>& out2) { \
|
||||
load_fp32_from_##name(data, out1); \
|
||||
data += Vectorized<float>::size(); \
|
||||
load_fp32_from_##name(data, out2); \
|
||||
}
|
||||
|
||||
CONVERT_NON_VECTORIZED_INIT(Half, half)
|
||||
|
||||
LOAD_FP32_NON_VECTORIZED_INIT(Half, fp16)
|
||||
|
||||
#endif // !defined(C10_MOBILE) && defined(__aarch64__)
|
||||
|
||||
} // namespace CPU_CAPABILITY
|
||||
|
||||
@ -9,16 +9,21 @@
|
||||
#if !( \
|
||||
defined(__VSX__) || defined(CPU_CAPABILITY_VSX) || \
|
||||
defined(CPU_CAPABILITY_ZVECTOR))
|
||||
#include <ATen/cpu/vec/vec256/vec256_double.h>
|
||||
#if defined(CPU_CAPABILITY_SVE256)
|
||||
#include <ATen/cpu/vec/sve/vec_common_sve.h>
|
||||
#else
|
||||
// clang-format off
|
||||
#include <ATen/cpu/vec/vec256/vec256_float.h>
|
||||
#include <ATen/cpu/vec/vec256/vec256_double.h>
|
||||
#include <ATen/cpu/vec/vec256/vec256_int.h>
|
||||
#include <ATen/cpu/vec/vec256/vec256_qint.h>
|
||||
#endif
|
||||
#if !defined(CPU_CAPABILITY_SVE256) || !defined(__ARM_FEATURE_BF16)
|
||||
#include <ATen/cpu/vec/vec256/vec256_bfloat16.h>
|
||||
#endif
|
||||
#include <ATen/cpu/vec/vec256/vec256_complex_double.h>
|
||||
#include <ATen/cpu/vec/vec256/vec256_complex_float.h>
|
||||
#include <ATen/cpu/vec/vec256/vec256_half.h>
|
||||
#include <ATen/cpu/vec/vec256/vec256_complex_float.h>
|
||||
#include <ATen/cpu/vec/vec256/vec256_complex_double.h>
|
||||
// clang-format on
|
||||
#elif defined(__VSX__) || defined(CPU_CAPABILITY_VSX)
|
||||
#include <ATen/cpu/vec/vec256/vsx/vec256_common_vsx.h>
|
||||
@ -51,6 +56,34 @@ namespace at::vec {
|
||||
// accessed as `at::vec`.
|
||||
inline namespace CPU_CAPABILITY {
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& stream, const c10::qint32& val) {
|
||||
stream << val.val_;
|
||||
return stream;
|
||||
}
|
||||
inline std::ostream& operator<<(std::ostream& stream, const c10::qint8& val) {
|
||||
stream << static_cast<int>(val.val_);
|
||||
return stream;
|
||||
}
|
||||
inline std::ostream& operator<<(std::ostream& stream, const c10::quint8& val) {
|
||||
stream << static_cast<unsigned int>(val.val_);
|
||||
return stream;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::ostream& operator<<(std::ostream& stream, const Vectorized<T>& vec) {
|
||||
T buf[Vectorized<T>::size()];
|
||||
vec.store(buf);
|
||||
stream << "vec[";
|
||||
for (int i = 0; i != Vectorized<T>::size(); i++) {
|
||||
if (i != 0) {
|
||||
stream << ", ";
|
||||
}
|
||||
stream << buf[i];
|
||||
}
|
||||
stream << "]";
|
||||
return stream;
|
||||
}
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX2)
|
||||
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST (AVX2) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
@ -268,7 +268,9 @@ LOAD_FP32_VECTORIZED_INIT(BFloat16, bf16)
|
||||
|
||||
#else // defined(CPU_CAPABILITY_AVX2)
|
||||
|
||||
#if !(defined(__aarch64__))
|
||||
#if !( \
|
||||
defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && \
|
||||
!defined(CPU_CAPABILITY_SVE256))
|
||||
CONVERT_NON_VECTORIZED_INIT(BFloat16, bfloat16)
|
||||
#endif
|
||||
|
||||
|
||||
@ -268,7 +268,9 @@ LOAD_FP32_VECTORIZED_INIT(Half, fp16)
|
||||
|
||||
#else // defined(CPU_CAPABILITY_AVX2)
|
||||
|
||||
#if !defined(__aarch64__) || defined(CPU_CAPABILITY_SVE256)
|
||||
#if !( \
|
||||
defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && \
|
||||
!defined(CPU_CAPABILITY_SVE256))
|
||||
CONVERT_NON_VECTORIZED_INIT(Half, half)
|
||||
#endif
|
||||
|
||||
|
||||
@ -5,13 +5,6 @@
|
||||
|
||||
#include <ATen/cpu/vec/intrinsics.h>
|
||||
#include <ATen/cpu/vec/vec_base.h>
|
||||
|
||||
#ifdef __aarch64__
|
||||
#if defined(CPU_CAPABILITY_SVE128) || !defined(CPU_CAPABILITY_SVE)
|
||||
#include <ATen/cpu/vec/vec128/vec128_float_neon.h>
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#include <ATen/native/quantized/AffineQuantizerBase.h>
|
||||
|
||||
#include <c10/util/irange.h>
|
||||
@ -922,7 +915,7 @@ Vectorized<c10::quint8> inline maximum(
|
||||
return a.maximum(b);
|
||||
}
|
||||
|
||||
#else
|
||||
#elif !defined(CPU_CAPABILITY_SVE256)
|
||||
|
||||
// NOTE: These are low-performance implementations that we fall back on
|
||||
// if we are not building with AVX2. This may not be an issue, because
|
||||
@ -1379,18 +1372,12 @@ Vectorized<c10::quint8> inline maximum(
|
||||
return a.maximum(b);
|
||||
}
|
||||
|
||||
#if defined(__aarch64__) && \
|
||||
(defined(CPU_CAPABILITY_SVE128) || !defined(CPU_CAPABILITY_SVE))
|
||||
#endif // if defined(CPU_CAPABILITY_AVX2)
|
||||
|
||||
#if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256))
|
||||
std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float(
|
||||
at::vec::Vectorized<int8_t> src) {
|
||||
|
||||
#ifdef CPU_CAPABILITY_SVE
|
||||
svint8_t x = src;
|
||||
auto s8x8 = vget_low_s8(svget_neonq(x));
|
||||
#else
|
||||
auto s8x8 = vld1_s8(src.operator const int8_t*());
|
||||
#endif
|
||||
|
||||
auto s16x8 = vmovl_s8(s8x8);
|
||||
|
||||
auto s32x4_hi = vmovl_s16(vget_high_s16(s16x8));
|
||||
@ -1415,14 +1402,7 @@ std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float(
|
||||
|
||||
Vectorized<float> inline convert_int8_half_register_to_float(
|
||||
at::vec::Vectorized<int8_t> src) {
|
||||
|
||||
#ifdef CPU_CAPABILITY_SVE
|
||||
svint8_t x = src;
|
||||
auto s8x8 = vget_low_s8(svget_neonq(x));
|
||||
#else
|
||||
auto s8x8 = vld1_s8(src.operator const int8_t*());
|
||||
#endif
|
||||
|
||||
auto s16x8 = vmovl_s8(s8x8);
|
||||
|
||||
auto s32x4_lo = vmovl_s16(vget_low_s16(s16x8));
|
||||
@ -1440,8 +1420,5 @@ Vectorized<float> inline convert_int8_half_register_to_float(
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#endif // if defined(CPU_CAPABILITY_AVX2)
|
||||
|
||||
} // namespace CPU_CAPABILITY
|
||||
} // namespace at::vec
|
||||
|
||||
@ -31,6 +31,34 @@ namespace vec {
|
||||
// See Note [CPU_CAPABILITY namespace]
|
||||
inline namespace CPU_CAPABILITY {
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& stream, const c10::qint32& val) {
|
||||
stream << val.val_;
|
||||
return stream;
|
||||
}
|
||||
inline std::ostream& operator<<(std::ostream& stream, const c10::qint8& val) {
|
||||
stream << static_cast<int>(val.val_);
|
||||
return stream;
|
||||
}
|
||||
inline std::ostream& operator<<(std::ostream& stream, const c10::quint8& val) {
|
||||
stream << static_cast<unsigned int>(val.val_);
|
||||
return stream;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::ostream& operator<<(std::ostream& stream, const Vectorized<T>& vec) {
|
||||
T buf[Vectorized<T>::size()];
|
||||
vec.store(buf);
|
||||
stream << "vec[";
|
||||
for (int i = 0; i != Vectorized<T>::size(); i++) {
|
||||
if (i != 0) {
|
||||
stream << ", ";
|
||||
}
|
||||
stream << buf[i];
|
||||
}
|
||||
stream << "]";
|
||||
return stream;
|
||||
}
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST (AVX512)
|
||||
|
||||
@ -67,7 +67,18 @@ Windows llvm will not have this definition.
|
||||
#endif
|
||||
#define VECTOR_WIDTH 64
|
||||
#define int_vector __m512i
|
||||
#elif defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_SVE256)
|
||||
#elif defined(__aarch64__) && \
|
||||
!defined(CPU_CAPABILITY_SVE) // CPU_CAPABILITY_AVX512
|
||||
// SVE code expects 256-vectors; leave that set for SVE?
|
||||
#if defined(__GNUC__)
|
||||
#define __at_align__ __attribute__((aligned(16)))
|
||||
#elif defined(_WIN32)
|
||||
#define __at_align__ __declspec(align(16))
|
||||
#else
|
||||
#define __at_align__
|
||||
#endif
|
||||
#define VECTOR_WIDTH 16
|
||||
#else // CPU_CAPABILITY_AVX512
|
||||
#if defined(__GNUC__)
|
||||
#define __at_align__ __attribute__((aligned(32)))
|
||||
#elif defined(_WIN32)
|
||||
@ -77,27 +88,7 @@ Windows llvm will not have this definition.
|
||||
#endif
|
||||
#define VECTOR_WIDTH 32
|
||||
#define int_vector __m256i
|
||||
#elif defined(__aarch64__)
|
||||
// Define alignment and vector width for SVE128/Default (e.g., NEON)
|
||||
#if defined(__GNUC__)
|
||||
#define __at_align__ __attribute__((aligned(16)))
|
||||
#elif defined(_WIN32)
|
||||
#define __at_align__ __declspec(align(16))
|
||||
#else
|
||||
#define __at_align__
|
||||
#endif
|
||||
#define VECTOR_WIDTH 16
|
||||
#else
|
||||
// Fallback: define default alignment and vector width
|
||||
#if defined(__GNUC__)
|
||||
#define __at_align__ __attribute__((aligned(32)))
|
||||
#elif defined(_WIN32)
|
||||
#define __at_align__ __declspec(align(32))
|
||||
#else
|
||||
#define __at_align__
|
||||
#endif
|
||||
#define VECTOR_WIDTH 32
|
||||
#endif
|
||||
#endif // CPU_CAPABILITY_AVX512
|
||||
|
||||
namespace at::vec {
|
||||
// See Note [CPU_CAPABILITY namespace]
|
||||
|
||||
@@ -1861,6 +1861,8 @@ template bool gemm_and_bias(
     int64_t result_ld,
     GEMMAndBiasActivationEpilogue activation);

+using at::blas::ScalingType;
+
 int get_scale_mode(ScalingType scaling_type, ScalarType scale_dtype, bool use_fast_accum) {
   switch (scaling_type) {
     case ScalingType::BlockWise1x32:
@@ -14,6 +14,7 @@
 */

 #include <ATen/cuda/CUDAContext.h>
+#include <ATen/BlasBackend.h>
 #include <ATen/OpMathType.h>

 namespace at::cuda::blas {
@@ -136,15 +137,6 @@ void int8_gemm(
     int32_t* result_ptr,
     int64_t result_ld);

-enum class ScalingType : std::uint8_t {
-  TensorWise, // fp32 scales
-  RowWise, // fp32 scales
-  BlockWise1x16, // fp8_e4m3fn scales
-  BlockWise1x32, // fp8_e8m0fnu scales
-  BlockWise1x128, // fp32 scales
-  BlockWise128x128, // fp32 scales
-};
-
 void scaled_gemm(
     char transa,
     char transb,
@@ -156,13 +148,13 @@ void scaled_gemm(
     int64_t mat1_ld,
     ScalarType mat1_dtype,
     ScalarType mat1_scale_dtype,
-    ScalingType mat1_scaling_type,
+    at::blas::ScalingType mat1_scaling_type,
     const void* mat2_ptr,
     const void* mat2_scale_ptr,
     int64_t mat2_ld,
     ScalarType mat2_dtype,
     ScalarType mat2_scale_dtype,
-    ScalingType mat2_scaling_type,
+    at::blas::ScalingType mat2_scaling_type,
     const void* bias_ptr,
     ScalarType bias_dtype,
     void* result_ptr,
@ -29,7 +29,7 @@
|
||||
|
||||
namespace at::cuda::tunable {
|
||||
|
||||
using at::cuda::blas::ScalingType;
|
||||
using at::blas::ScalingType;
|
||||
|
||||
enum class BlasOp {
|
||||
N = 0,
|
||||
|
||||
@@ -2,6 +2,8 @@

#include <ATen/ATen.h>

#include <c10/util/Exception.h>

namespace at::native {

cudnnDataType_t getCudnnDataTypeFromScalarType(const at::ScalarType dtype) {
@@ -20,9 +22,10 @@ cudnnDataType_t getCudnnDataTypeFromScalarType(const at::ScalarType dtype) {
} else if (dtype == at::kByte) {
return CUDNN_DATA_UINT8;
}
std::string msg("getCudnnDataTypeFromScalarType() not supported for ");
msg += toString(dtype);
throw std::runtime_error(msg);
TORCH_CHECK(false,
"getCudnnDataTypeFromScalarType() not supported for ",
toString(dtype)
);
}

cudnnDataType_t getCudnnDataType(const at::Tensor& tensor) {

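Several hunks in this comparison swap `throw std::runtime_error(...)` for `TORCH_CHECK(...)`, which concatenates heterogeneous message parts and carries extra context. A simplified standalone sketch of that style of check macro; the real macro lives in c10/util/Exception.h and does considerably more (source location, error types), so treat this only as an illustration of the shape:

```cpp
#include <sstream>
#include <stdexcept>
#include <string>

// Stream all message parts into one string.
template <typename... Parts>
std::string join_parts(const Parts&... parts) {
  std::ostringstream oss;
  (oss << ... << parts);
  return oss.str();
}

// Simplified stand-in for a TORCH_CHECK-style macro (requires at least one message part).
#define SKETCH_CHECK(cond, ...)                                                  \
  do {                                                                           \
    if (!(cond)) {                                                               \
      throw std::runtime_error(join_parts("Check failed: ", #cond, ". ",         \
                                          __VA_ARGS__));                         \
    }                                                                            \
  } while (0)

int main() {
  int numel = 4;
  SKETCH_CHECK(numel == 4, "expected 4 elements, got ", numel);  // passes
  // SKETCH_CHECK(numel == 1, "a Tensor with ", numel,
  //              " elements cannot be converted to Scalar");    // would throw
}
```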
@@ -12,6 +12,7 @@
#include <ATen/native/IndexKernel.h>
#include <ATen/native/IndexingUtils.h>
#include <torch/library.h>
#include <c10/util/Exception.h>


// NOLINTBEGIN(bugprone-unchecked-optional-access)
@@ -94,9 +95,10 @@ static std::vector<std::optional<Tensor>> batchIndices(
if (index.has_value() && index->sym_numel() != 0) {
const auto idx_bdim = indices_bdims[i];
indices_.emplace_back(maybePadToLogicalRank(moveBatchDimToFront(index.value(), idx_bdim), idx_bdim, maxLogicalRank));
if (index.value().dtype() == kBool && indices_bdims[i].has_value()) {
throw std::runtime_error("vmap: We do not support batching operators that can support dynamic shape. Attempting to batch over indexing with a boolean mask.");
}
TORCH_CHECK(
!(index.value().dtype() == kBool) || !indices_bdims[i].has_value(),
"vmap: We do not support batching operators that can support dynamic shape. Attempting to batch over indexing with a boolean mask."
);
} else {
indices_.push_back(index);
}

@@ -3,6 +3,7 @@
#include <ATen/functorch/Macros.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <c10/util/Exception.h>
#include <optional>
#include <bitset>
#include <utility>
@@ -106,9 +107,10 @@ struct VmapInterpreterMeta {

template <typename T>
friend void to_json(T& json_j, const VmapInterpreterMeta& json_t) {
if (json_t.batchSize_.is_heap_allocated()) {
throw std::runtime_error("Serialization for heap-allocated SymInt is not implemented yet");
}
TORCH_CHECK(
!json_t.batchSize_.is_heap_allocated(),
"Serialization for heap-allocated SymInt is not implemented yet"
);
json_j["batchSize"] = json_t.batchSize_.as_int_unchecked();
json_j["randomness"] = static_cast<int64_t>(json_t.randomness_);
}
@@ -302,7 +304,7 @@ struct Interpreter {
} else if (meta.contains("Functionalize")) {
json_t.meta_.emplace<FunctionalizeInterpreterMeta>(meta["Functionalize"].template get<FunctionalizeInterpreterMeta>());
} else {
throw std::runtime_error("unknown interpreter metadata type");
TORCH_CHECK(false, "unknown interpreter metadata type");
}
}


@@ -6,6 +6,7 @@
#include <ATen/functorch/BatchedTensorImpl.h>
#include <ATen/Dispatch.h>
#include <c10/util/irange.h>
#include <c10/util/Exception.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/native/LinearAlgebraUtils.h>
#include <ATen/native/xnnpack/Engine.h>
@@ -108,9 +109,7 @@ Tensor binary_cross_entropy_with_logits_hack(
}

Tensor trace_backward_decomp(const Tensor& grad, IntArrayRef sizes) {
if (sizes.size() != 2) {
throw std::runtime_error("expected matrix input");
}
TORCH_CHECK(sizes.size() == 2, "expected matrix input");
auto grad_input = at::zeros(sizes[0] * sizes[1], grad.options());
auto indices = at::arange(0, grad_input.numel(), sizes[1] + 1, grad.options().dtype(at::kLong));
// Workaround using index_put instead of yet unsupported index_fill_

@ -1157,103 +1157,103 @@ REGISTER_AVX512_DISPATCH(cholesky_stub, &cholesky_kernel)
|
||||
REGISTER_AVX2_DISPATCH(cholesky_stub, &cholesky_kernel)
|
||||
REGISTER_VSX_DISPATCH(cholesky_stub, &cholesky_kernel)
|
||||
REGISTER_ZVECTOR_DISPATCH(cholesky_stub, &cholesky_kernel)
|
||||
REGISTER_SVE_DISPATCH(cholesky_stub, &cholesky_kernel)
|
||||
REGISTER_SVE256_DISPATCH(cholesky_stub, &cholesky_kernel)
|
||||
|
||||
REGISTER_ARCH_DISPATCH(cholesky_inverse_stub, DEFAULT, &cholesky_inverse_kernel_impl)
|
||||
REGISTER_AVX512_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl)
|
||||
REGISTER_AVX2_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl)
|
||||
REGISTER_VSX_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl)
|
||||
REGISTER_ZVECTOR_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl)
|
||||
REGISTER_SVE_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl)
|
||||
REGISTER_SVE256_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl)
|
||||
|
||||
REGISTER_ARCH_DISPATCH(linalg_eig_stub, DEFAULT, &linalg_eig_kernel)
|
||||
REGISTER_AVX512_DISPATCH(linalg_eig_stub, &linalg_eig_kernel)
|
||||
REGISTER_AVX2_DISPATCH(linalg_eig_stub, &linalg_eig_kernel)
|
||||
REGISTER_VSX_DISPATCH(linalg_eig_stub, &linalg_eig_kernel)
|
||||
REGISTER_ZVECTOR_DISPATCH(linalg_eig_stub, &linalg_eig_kernel)
|
||||
REGISTER_SVE_DISPATCH(linalg_eig_stub, &linalg_eig_kernel)
|
||||
REGISTER_SVE256_DISPATCH(linalg_eig_stub, &linalg_eig_kernel)
|
||||
|
||||
REGISTER_ARCH_DISPATCH(linalg_eigh_stub, DEFAULT, &linalg_eigh_kernel)
|
||||
REGISTER_AVX512_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel)
|
||||
REGISTER_AVX2_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel)
|
||||
REGISTER_VSX_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel)
|
||||
REGISTER_ZVECTOR_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel)
|
||||
REGISTER_SVE_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel)
|
||||
REGISTER_SVE256_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel)
|
||||
|
||||
REGISTER_ARCH_DISPATCH(geqrf_stub, DEFAULT, &geqrf_kernel)
|
||||
REGISTER_AVX512_DISPATCH(geqrf_stub, &geqrf_kernel)
|
||||
REGISTER_AVX2_DISPATCH(geqrf_stub, &geqrf_kernel)
|
||||
REGISTER_VSX_DISPATCH(geqrf_stub, &geqrf_kernel)
|
||||
REGISTER_ZVECTOR_DISPATCH(geqrf_stub, &geqrf_kernel)
|
||||
REGISTER_SVE_DISPATCH(geqrf_stub, &geqrf_kernel)
|
||||
REGISTER_SVE256_DISPATCH(geqrf_stub, &geqrf_kernel)
|
||||
|
||||
REGISTER_ARCH_DISPATCH(orgqr_stub, DEFAULT, &orgqr_kernel_impl)
|
||||
REGISTER_AVX512_DISPATCH(orgqr_stub, &orgqr_kernel_impl)
|
||||
REGISTER_AVX2_DISPATCH(orgqr_stub, &orgqr_kernel_impl)
|
||||
REGISTER_VSX_DISPATCH(orgqr_stub, &orgqr_kernel_impl)
|
||||
REGISTER_ZVECTOR_DISPATCH(orgqr_stub, &orgqr_kernel_impl)
|
||||
REGISTER_SVE_DISPATCH(orgqr_stub, &orgqr_kernel_impl)
|
||||
REGISTER_SVE256_DISPATCH(orgqr_stub, &orgqr_kernel_impl)
|
||||
|
||||
REGISTER_ARCH_DISPATCH(ormqr_stub, DEFAULT, &ormqr_kernel)
|
||||
REGISTER_AVX512_DISPATCH(ormqr_stub, &ormqr_kernel)
|
||||
REGISTER_AVX2_DISPATCH(ormqr_stub, &ormqr_kernel)
|
||||
REGISTER_VSX_DISPATCH(ormqr_stub, &ormqr_kernel)
|
||||
REGISTER_ZVECTOR_DISPATCH(ormqr_stub, &ormqr_kernel)
|
||||
REGISTER_SVE_DISPATCH(ormqr_stub, &ormqr_kernel)
|
||||
REGISTER_SVE256_DISPATCH(ormqr_stub, &ormqr_kernel)
|
||||
|
||||
REGISTER_ARCH_DISPATCH(lstsq_stub, DEFAULT, &lstsq_kernel)
|
||||
REGISTER_AVX512_DISPATCH(lstsq_stub, &lstsq_kernel)
|
||||
REGISTER_AVX2_DISPATCH(lstsq_stub, &lstsq_kernel)
|
||||
REGISTER_VSX_DISPATCH(lstsq_stub, &lstsq_kernel)
|
||||
REGISTER_ZVECTOR_DISPATCH(lstsq_stub, &lstsq_kernel)
|
||||
REGISTER_SVE_DISPATCH(lstsq_stub, &lstsq_kernel)
|
||||
REGISTER_SVE256_DISPATCH(lstsq_stub, &lstsq_kernel)
|
||||
|
||||
REGISTER_ARCH_DISPATCH(triangular_solve_stub, DEFAULT, &triangular_solve_kernel)
|
||||
REGISTER_AVX512_DISPATCH(triangular_solve_stub, &triangular_solve_kernel)
|
||||
REGISTER_AVX2_DISPATCH(triangular_solve_stub, &triangular_solve_kernel)
|
||||
REGISTER_VSX_DISPATCH(triangular_solve_stub, &triangular_solve_kernel)
|
||||
REGISTER_ZVECTOR_DISPATCH(triangular_solve_stub, &triangular_solve_kernel)
|
||||
REGISTER_SVE_DISPATCH(triangular_solve_stub, &triangular_solve_kernel)
|
||||
REGISTER_SVE256_DISPATCH(triangular_solve_stub, &triangular_solve_kernel)
|
||||
|
||||
REGISTER_ARCH_DISPATCH(lu_factor_stub, DEFAULT, &lu_factor_kernel)
|
||||
REGISTER_AVX512_DISPATCH(lu_factor_stub, &lu_factor_kernel)
|
||||
REGISTER_AVX2_DISPATCH(lu_factor_stub, &lu_factor_kernel)
|
||||
REGISTER_VSX_DISPATCH(lu_factor_stub, &lu_factor_kernel)
|
||||
REGISTER_ZVECTOR_DISPATCH(lu_factor_stub, &lu_factor_kernel)
|
||||
REGISTER_SVE_DISPATCH(lu_factor_stub, &lu_factor_kernel)
|
||||
REGISTER_SVE256_DISPATCH(lu_factor_stub, &lu_factor_kernel)
|
||||
|
||||
REGISTER_ARCH_DISPATCH(ldl_factor_stub, DEFAULT, &ldl_factor_kernel)
|
||||
REGISTER_AVX512_DISPATCH(ldl_factor_stub, &ldl_factor_kernel)
|
||||
REGISTER_AVX2_DISPATCH(ldl_factor_stub, &ldl_factor_kernel)
|
||||
REGISTER_VSX_DISPATCH(ldl_factor_stub, &ldl_factor_kernel)
|
||||
REGISTER_ZVECTOR_DISPATCH(ldl_factor_stub, &ldl_factor_kernel)
|
||||
REGISTER_SVE_DISPATCH(ldl_factor_stub, &ldl_factor_kernel)
|
||||
REGISTER_SVE256_DISPATCH(ldl_factor_stub, &ldl_factor_kernel)
|
||||
|
||||
REGISTER_ARCH_DISPATCH(ldl_solve_stub, DEFAULT, &ldl_solve_kernel)
|
||||
REGISTER_AVX512_DISPATCH(ldl_solve_stub, &ldl_solve_kernel)
|
||||
REGISTER_AVX2_DISPATCH(ldl_solve_stub, &ldl_solve_kernel)
|
||||
REGISTER_VSX_DISPATCH(ldl_solve_stub, &ldl_solve_kernel)
|
||||
REGISTER_ZVECTOR_DISPATCH(ldl_solve_stub, &ldl_solve_kernel)
|
||||
REGISTER_SVE_DISPATCH(ldl_solve_stub, &ldl_solve_kernel)
|
||||
REGISTER_SVE256_DISPATCH(ldl_solve_stub, &ldl_solve_kernel)
|
||||
|
||||
REGISTER_ARCH_DISPATCH(lu_solve_stub, DEFAULT, &lu_solve_kernel)
|
||||
REGISTER_AVX512_DISPATCH(lu_solve_stub, &lu_solve_kernel)
|
||||
REGISTER_AVX2_DISPATCH(lu_solve_stub, &lu_solve_kernel)
|
||||
REGISTER_VSX_DISPATCH(lu_solve_stub, &lu_solve_kernel)
|
||||
REGISTER_ZVECTOR_DISPATCH(lu_solve_stub, &lu_solve_kernel)
|
||||
REGISTER_SVE_DISPATCH(lu_solve_stub, &lu_solve_kernel)
|
||||
REGISTER_SVE256_DISPATCH(lu_solve_stub, &lu_solve_kernel)
|
||||
|
||||
REGISTER_ARCH_DISPATCH(svd_stub, DEFAULT, &svd_kernel)
|
||||
REGISTER_AVX512_DISPATCH(svd_stub, &svd_kernel)
|
||||
REGISTER_AVX2_DISPATCH(svd_stub, &svd_kernel)
|
||||
REGISTER_VSX_DISPATCH(svd_stub, &svd_kernel)
|
||||
REGISTER_ZVECTOR_DISPATCH(svd_stub, &svd_kernel)
|
||||
REGISTER_SVE_DISPATCH(svd_stub, &svd_kernel)
|
||||
REGISTER_SVE256_DISPATCH(svd_stub, &svd_kernel)
|
||||
|
||||
REGISTER_ARCH_DISPATCH(unpack_pivots_stub, DEFAULT, &unpack_pivots_cpu_kernel)
|
||||
REGISTER_AVX512_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel)
|
||||
REGISTER_AVX2_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel)
|
||||
REGISTER_VSX_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel)
|
||||
REGISTER_ZVECTOR_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel)
|
||||
REGISTER_SVE_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel)
|
||||
REGISTER_SVE256_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel)
|
||||
} // namespace at::native
|
||||
|
||||
@@ -39,21 +39,19 @@ static CPUCapability compute_cpu_capability() {
}
#elif defined(HAVE_SVE_CPU_DEFINITION)
int sve_vl = cpuinfo_get_max_arm_sve_length(); //Returns maximum SVE VL supported by your HW.
if (envar == "sve") {
// Select SVE capability based on the maximum SVE VL supported by the HW.
#ifdef HAVE_SVE256_CPU_DEFINITION
if (envar == "sve256") {
if (sve_vl == 256) {
#ifdef HAVE_ARM_BF16_CPU_DEFINITION
if (cpuinfo_has_arm_bf16()) {
return CPUCapability::SVE256;
}
} else if (sve_vl == 128) {
if (cpuinfo_has_arm_bf16()) {
return CPUCapability::SVE128;
}
} else {
TORCH_WARN("SVE capability not available on hardware. Falling back to DEFAULT");
return CPUCapability::DEFAULT;
#endif
}
TORCH_WARN("SVE256 capability not available on hardware. Falling back to DEFAULT");
return CPUCapability::DEFAULT;
}
#endif
#else
#ifdef HAVE_AVX512_CPU_DEFINITION
if (envar == "avx512") {
@@ -115,11 +113,6 @@ static CPUCapability compute_cpu_capability() {
#endif
}
#endif
#ifdef HAVE_SVE128_CPU_DEFINITION
if (sve_vl == 128) { // Check for SVE128
return CPUCapability::SVE128;
}
#endif
// Return the default CPU capability.
return CPUCapability::DEFAULT;
}
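The capability selection above is driven by an environment-variable override plus a hardware probe, falling back to DEFAULT with a warning when the requested ISA is not actually available. A standalone sketch of that env-var-plus-fallback pattern; the capability names, the `DEMO_CPU_CAPABILITY` variable, and the stubbed hardware probe are all hypothetical:

```cpp
#include <cstdlib>
#include <iostream>
#include <string>

enum class Capability { Default, Sve128, Sve256 };

// Stubbed hardware probe; the real code asks cpuinfo for the SVE vector length.
int detected_sve_vector_length() { return 128; }

Capability choose_capability() {
  const char* raw = std::getenv("DEMO_CPU_CAPABILITY");
  const std::string envar = raw ? raw : "";
  const int sve_vl = detected_sve_vector_length();
  if (envar == "sve256") {
    if (sve_vl == 256) return Capability::Sve256;
    std::cerr << "SVE256 not available on hardware. Falling back to DEFAULT\n";
    return Capability::Default;
  }
  // No (or unrecognized) override: pick the widest capability the hardware supports.
  if (sve_vl == 256) return Capability::Sve256;
  if (sve_vl == 128) return Capability::Sve128;
  return Capability::Default;
}

int main() { std::cout << static_cast<int>(choose_capability()) << "\n"; }
```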
@ -154,9 +147,6 @@ DispatchResult DispatchStubImpl::try_get_call_ptr(
|
||||
#ifdef HAVE_SVE256_CPU_DEFINITION
|
||||
, void *SVE256
|
||||
#endif
|
||||
#ifdef HAVE_SVE128_CPU_DEFINITION
|
||||
, void *SVE128
|
||||
#endif
|
||||
) {
|
||||
constexpr auto supported_devices = c10::array_of<c10::DeviceType>(
|
||||
c10::DeviceType::CPU,
|
||||
@ -194,9 +184,6 @@ DispatchResult DispatchStubImpl::try_get_call_ptr(
|
||||
#endif
|
||||
#ifdef HAVE_SVE256_CPU_DEFINITION
|
||||
, SVE256
|
||||
#endif
|
||||
#ifdef HAVE_SVE128_CPU_DEFINITION
|
||||
, SVE128
|
||||
#endif
|
||||
);
|
||||
if (!std::holds_alternative<ErrorType>(result)) {
|
||||
@ -255,9 +242,6 @@ void* DispatchStubImpl::get_call_ptr(
|
||||
#ifdef HAVE_SVE256_CPU_DEFINITION
|
||||
, void *SVE256
|
||||
#endif
|
||||
#ifdef HAVE_SVE128_CPU_DEFINITION
|
||||
, void *SVE128
|
||||
#endif
|
||||
) {
|
||||
|
||||
auto result = try_get_call_ptr(
|
||||
@ -282,10 +266,6 @@ void* DispatchStubImpl::get_call_ptr(
|
||||
#ifdef HAVE_SVE256_CPU_DEFINITION
|
||||
,
|
||||
SVE256
|
||||
#endif
|
||||
#ifdef HAVE_SVE128_CPU_DEFINITION
|
||||
,
|
||||
SVE128
|
||||
#endif
|
||||
);
|
||||
if (std::holds_alternative<ErrorType>(result)) {
|
||||
@ -320,9 +300,6 @@ DispatchResult DispatchStubImpl::try_choose_cpu_impl(
|
||||
#endif
|
||||
#ifdef HAVE_SVE256_CPU_DEFINITION
|
||||
, void *SVE256
|
||||
#endif
|
||||
#ifdef HAVE_SVE128_CPU_DEFINITION
|
||||
, void *SVE128
|
||||
#endif
|
||||
){
|
||||
|
||||
@ -365,16 +342,6 @@ DispatchResult DispatchStubImpl::try_choose_cpu_impl(
|
||||
return DispatchResult(SVE256);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#ifdef HAVE_SVE128_CPU_DEFINITION
|
||||
if (capability >= static_cast<int>(CPUCapability::SVE128)) {
|
||||
if (C10_UNLIKELY(!SVE128)) {
|
||||
// dispatch to DEFAULT, since the SVE kernel is missing
|
||||
return DEFAULT != nullptr ? DispatchResult(DEFAULT) : ErrorType::MissingDeviceKernel;
|
||||
} else {
|
||||
return DispatchResult(SVE128);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
return DEFAULT != nullptr ? DispatchResult(DEFAULT) : ErrorType::MissingDeviceKernel;
|
||||
}
|
||||
@ -396,9 +363,6 @@ void* DispatchStubImpl::choose_cpu_impl(
|
||||
#ifdef HAVE_SVE256_CPU_DEFINITION
|
||||
, void *SVE256
|
||||
#endif
|
||||
#ifdef HAVE_SVE128_CPU_DEFINITION
|
||||
, void *SVE128
|
||||
#endif
|
||||
) {
|
||||
auto capability = static_cast<int>(get_cpu_capability());
|
||||
(void)capability;
|
||||
@ -444,17 +408,6 @@ void* DispatchStubImpl::choose_cpu_impl(
|
||||
return SVE256;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#ifdef HAVE_SVE128_CPU_DEFINITION
|
||||
if (capability >= static_cast<int>(CPUCapability::SVE128)) {
|
||||
if (C10_UNLIKELY(!SVE128)) {
|
||||
// dispatch to DEFAULT, since the SVE kernel is missing
|
||||
TORCH_INTERNAL_ASSERT(DEFAULT, "DispatchStub: missing default kernel");
|
||||
return DEFAULT;
|
||||
} else {
|
||||
return SVE128;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
TORCH_INTERNAL_ASSERT(DEFAULT, "DispatchStub: missing default kernel");
|
||||
return DEFAULT;
|
||||
|
||||
@ -64,9 +64,8 @@ enum class CPUCapability {
|
||||
VSX = 1,
|
||||
#elif defined(HAVE_ZVECTOR_CPU_DEFINITION)
|
||||
ZVECTOR = 1,
|
||||
#elif defined(HAVE_SVE_CPU_DEFINITION) && defined(HAVE_ARM_BF16_CPU_DEFINITION)
|
||||
#elif defined(HAVE_SVE256_CPU_DEFINITION) && defined(HAVE_ARM_BF16_CPU_DEFINITION)
|
||||
SVE256 = 1,
|
||||
SVE128 = 2,
|
||||
#else
|
||||
AVX2 = 1,
|
||||
AVX512 = 2,
|
||||
@ -118,9 +117,6 @@ struct TORCH_API DispatchStubImpl {
|
||||
#endif
|
||||
#ifdef HAVE_SVE256_CPU_DEFINITION
|
||||
, void *SVE256
|
||||
#endif
|
||||
#ifdef HAVE_SVE128_CPU_DEFINITION
|
||||
, void *SVE128
|
||||
#endif
|
||||
);
|
||||
|
||||
@ -142,9 +138,6 @@ struct TORCH_API DispatchStubImpl {
|
||||
#endif
|
||||
#ifdef HAVE_SVE256_CPU_DEFINITION
|
||||
, void *SVE256
|
||||
#endif
|
||||
#ifdef HAVE_SVE128_CPU_DEFINITION
|
||||
, void *SVE128
|
||||
#endif
|
||||
);
|
||||
|
||||
@ -166,9 +159,6 @@ struct TORCH_API DispatchStubImpl {
|
||||
#endif
|
||||
#ifdef HAVE_SVE256_CPU_DEFINITION
|
||||
, void *SVE256
|
||||
#endif
|
||||
#ifdef HAVE_SVE128_CPU_DEFINITION
|
||||
, void *SVE128
|
||||
#endif
|
||||
);
|
||||
|
||||
@ -193,9 +183,6 @@ struct TORCH_API DispatchStubImpl {
|
||||
#endif
|
||||
#ifdef HAVE_SVE256_CPU_DEFINITION
|
||||
, void *SVE256
|
||||
#endif
|
||||
#ifdef HAVE_SVE128_CPU_DEFINITION
|
||||
, void *SVE128
|
||||
#endif
|
||||
);
|
||||
|
||||
@ -253,9 +240,6 @@ private:
|
||||
#endif
|
||||
#ifdef HAVE_SVE256_CPU_DEFINITION
|
||||
, reinterpret_cast<void*>(SVE256)
|
||||
#endif
|
||||
#ifdef HAVE_SVE128_CPU_DEFINITION
|
||||
, reinterpret_cast<void*>(SVE128)
|
||||
#endif
|
||||
)
|
||||
);
|
||||
@ -317,9 +301,6 @@ public:
|
||||
#endif
|
||||
#ifdef HAVE_SVE256_CPU_DEFINITION
|
||||
, reinterpret_cast<void*>(SVE256)
|
||||
#endif
|
||||
#ifdef HAVE_SVE128_CPU_DEFINITION
|
||||
, reinterpret_cast<void*>(SVE128)
|
||||
#endif
|
||||
);
|
||||
if (std::holds_alternative<ErrorType>(result)){
|
||||
@ -344,9 +325,6 @@ public:
|
||||
#ifdef HAVE_SVE256_CPU_DEFINITION
|
||||
static TORCH_API FnPtr SVE256;
|
||||
#endif
|
||||
#ifdef HAVE_SVE128_CPU_DEFINITION
|
||||
static TORCH_API FnPtr SVE128;
|
||||
#endif
|
||||
private:
|
||||
DispatchStubImpl impl;
|
||||
};
|
||||
@@ -454,12 +432,6 @@ struct RegisterPRIVATEUSE1Dispatch {
#define REGISTER_SVE256_DISPATCH(name, fn)
#endif

#ifdef HAVE_SVE128_CPU_DEFINITION
#define REGISTER_SVE128_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, SVE128, fn)
#else
#define REGISTER_SVE128_DISPATCH(name, fn)
#endif

// Macro to register the same kernel for all CPU arch types. This is useful
// if a kernel does not benefit from being recompiled across different arch types.
#define REGISTER_ALL_CPU_DISPATCH(name, fn) \
@@ -468,11 +440,6 @@ struct RegisterPRIVATEUSE1Dispatch {
REGISTER_AVX2_DISPATCH(name, fn) \
REGISTER_VSX_DISPATCH(name, fn) \
REGISTER_ZVECTOR_DISPATCH(name, fn) \
REGISTER_SVE256_DISPATCH(name, fn) \
REGISTER_SVE128_DISPATCH(name, fn)

#define REGISTER_SVE_DISPATCH(name, fn) \
REGISTER_SVE128_DISPATCH(name, fn) \
REGISTER_SVE256_DISPATCH(name, fn)

#define REGISTER_NO_CPU_DISPATCH(name) \
@@ -515,7 +482,6 @@ struct RegisterPRIVATEUSE1Dispatch {
// REGISTER_DISPATCH now dispatches an AVX512 kernel to nullptr but registers other dispatches.
// ALSO_REGISTER_AVX512_DISPATCH should be used for ensuring AVX512 dispatch, among others.
// ALSO_REGISTER_SVE256_DISPATCH should be used for ensuring SVE256 dispatch, among others.
// ALSO_REGISTER_SVE128_DISPATCH should be used for ensuring SVE128 dispatch, among others.
#ifdef CPU_CAPABILITY_AVX512
#define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, ((void*)(fn) ? nullptr : nullptr))
#else
@@ -523,7 +489,6 @@ struct RegisterPRIVATEUSE1Dispatch {
#endif
#define ALSO_REGISTER_AVX512_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
#define ALSO_REGISTER_SVE256_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
#define ALSO_REGISTER_SVE128_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
#endif
} // namespace at::native

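The `REGISTER_SVE_DISPATCH` helper shown in these hunks simply fans one registration out to both SVE widths. A standalone sketch of the fan-out-macro idea, using a hypothetical string-keyed registry instead of the real `REGISTER_ARCH_DISPATCH` machinery:

```cpp
#include <iostream>
#include <map>
#include <string>

// Hypothetical per-ISA registry: stub name -> kernel pointer.
using Kernel = void (*)();
std::map<std::string, Kernel>& registry(const std::string& isa) {
  static std::map<std::string, std::map<std::string, Kernel>> tables;
  return tables[isa];
}

// One macro per ISA, plus a convenience macro that fans out to both SVE widths,
// mirroring how a single REGISTER_SVE_DISPATCH expands to SVE128 and SVE256.
#define REG_SVE128(name, fn) registry("SVE128")[#name] = (fn);
#define REG_SVE256(name, fn) registry("SVE256")[#name] = (fn);
#define REG_SVE(name, fn) REG_SVE128(name, fn) REG_SVE256(name, fn)

void cholesky_kernel() { std::cout << "cholesky kernel\n"; }

int main() {
  REG_SVE(cholesky_stub, &cholesky_kernel)
  registry("SVE128")["cholesky_stub"]();  // both tables now hold the same kernel
  registry("SVE256")["cholesky_stub"]();
}
```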
||||
@@ -15,7 +15,11 @@ namespace at::native {

Scalar item(const Tensor& self) {
auto numel = self.sym_numel();
TORCH_CHECK(numel == 1, "a Tensor with ", numel, " elements cannot be converted to Scalar");
TORCH_SYM_CHECK(
numel.sym_eq(1),
"a Tensor with ",
numel,
" elements cannot be converted to Scalar");
if (self.is_sparse()) {
if (self._nnz() == 0) return Scalar(0);
if (self.is_coalesced()) return at::_local_scalar_dense(self._values());

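For context, `item()` is the C++ counterpart of `Tensor.item()` in Python and only succeeds for one-element tensors; the hunk above just makes the check symbolic-shape aware. A short libtorch usage sketch (assumes a libtorch build is available):

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  auto one = torch::tensor({3.5});   // a single element: item() is fine
  std::cout << one.item<double>() << "\n";

  auto four = torch::ones({2, 2});   // four elements: item() would throw
  // four.item<double>();            // "a Tensor with 4 elements cannot be converted to Scalar"
}
```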
||||
@ -466,7 +466,7 @@ REGISTER_AVX2_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cp
|
||||
REGISTER_AVX512_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel)
|
||||
REGISTER_VSX_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel)
|
||||
REGISTER_ZVECTOR_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel)
|
||||
REGISTER_SVE_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel)
|
||||
REGISTER_SVE256_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel)
|
||||
|
||||
// offsets dispatches
|
||||
REGISTER_ARCH_DISPATCH(
|
||||
@ -477,7 +477,7 @@ REGISTER_AVX2_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cp
|
||||
REGISTER_AVX512_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel)
|
||||
REGISTER_VSX_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel)
|
||||
REGISTER_ZVECTOR_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel)
|
||||
REGISTER_SVE_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel)
|
||||
REGISTER_SVE256_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel)
|
||||
|
||||
// Currently some computation is being duplicated across forward and backward.
|
||||
// TODO: Cache indices in forward pass to reuse in backward
|
||||
@ -548,7 +548,7 @@ REGISTER_VSX_DISPATCH(
|
||||
REGISTER_ZVECTOR_DISPATCH(
|
||||
_segment_reduce_lengths_backward_stub,
|
||||
&_segment_reduce_cpu_lengths_backward_kernel)
|
||||
REGISTER_SVE_DISPATCH(
|
||||
REGISTER_SVE256_DISPATCH(
|
||||
_segment_reduce_lengths_backward_stub,
|
||||
&_segment_reduce_cpu_lengths_backward_kernel)
|
||||
|
||||
@ -568,7 +568,7 @@ REGISTER_VSX_DISPATCH(
|
||||
REGISTER_ZVECTOR_DISPATCH(
|
||||
_segment_reduce_offsets_backward_stub,
|
||||
&_segment_reduce_cpu_offsets_backward_kernel)
|
||||
REGISTER_SVE_DISPATCH(
|
||||
REGISTER_SVE256_DISPATCH(
|
||||
_segment_reduce_offsets_backward_stub,
|
||||
&_segment_reduce_cpu_offsets_backward_kernel)
|
||||
|
||||
|
||||
@@ -23,6 +23,14 @@
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_cast_Byte_native.h>
#include <ATen/ops/_cast_Char_native.h>
#include <ATen/ops/_cast_Double_native.h>
#include <ATen/ops/_cast_Float_native.h>
#include <ATen/ops/_cast_Half_native.h>
#include <ATen/ops/_cast_Int_native.h>
#include <ATen/ops/_cast_Long_native.h>
#include <ATen/ops/_cast_Short_native.h>
#include <ATen/ops/_dim_arange_native.h>
#include <ATen/ops/_efficientzerotensor_native.h>
#include <ATen/ops/_empty_affine_quantized.h>

||||
@@ -406,7 +406,7 @@ scalar_t cubic_convolution2(scalar_t x, scalar_t A) {
}

template <typename scalar_t>
void get_cubic_upsample_coefficients(
static inline void get_cubic_upsample_coefficients(
scalar_t coeffs[4],
scalar_t t) {
scalar_t A = -0.75;

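The function being marked `static inline` here computes four bicubic weights for a fractional offset `t`, using the Keys cubic convolution kernel with A = -0.75. A standalone sketch of that computation under those assumptions (formulas from the standard kernel definition; function and variable names below are mine, not the library's):

```cpp
#include <cstdio>

// Keys cubic convolution kernel pieces, A = -0.75.
static double cubic1(double x, double A) {            // for |x| <= 1
  return ((A + 2.0) * x - (A + 3.0)) * x * x + 1.0;
}
static double cubic2(double x, double A) {            // for 1 < |x| < 2
  return ((A * x - 5.0 * A) * x + 8.0 * A) * x - 4.0 * A;
}

static void cubic_coefficients(double coeffs[4], double t) {
  const double A = -0.75;
  const double x1 = t;
  coeffs[0] = cubic2(x1 + 1.0, A);   // sample at offset -1
  coeffs[1] = cubic1(x1, A);         // sample at offset  0
  const double x2 = 1.0 - t;
  coeffs[2] = cubic1(x2, A);         // sample at offset +1
  coeffs[3] = cubic2(x2 + 1.0, A);   // sample at offset +2
}

int main() {
  double c[4];
  cubic_coefficients(c, 0.5);
  // Weights always sum to 1 (here: -0.09375, 0.59375, 0.59375, -0.09375).
  std::printf("%f %f %f %f (sum=%f)\n", c[0], c[1], c[2], c[3], c[0] + c[1] + c[2] + c[3]);
}
```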
||||
@@ -212,7 +212,7 @@ std::pair<vec::Vectorized<float>, vec::Vectorized<float>> fmadd(
const vec::Vectorized<c10::Half>& b,
const vec::Vectorized<float>& acc_low,
const vec::Vectorized<float>& acc_high) {
#if defined(__aarch64__) && ((defined(__ARM_FEATURE_FP16_FML) && !defined(__ARM_FEATURE_SVE)) || (defined(CPU_CAPABILITY_SVE128)))
#if defined(__ARM_FEATURE_FP16_FML) && !defined(CPU_CAPABILITY_SVE)
return std::make_pair(vfmlalq_low_f16(acc_low, a, b), vfmlalq_high_f16(acc_high, a, b));
#else
const auto [a_float_low, a_float_high] = convert_half_float(a);
@@ -233,7 +233,7 @@ std::pair<vec::Vectorized<float>, vec::Vectorized<float>> fmadd(

// Return a + b_low * c_low + b_high * c_high
vec::Vectorized<float> fmadd(vec::Vectorized<float> a, vec::Vectorized<Half> b, vec::Vectorized<Half> c) {
#if defined(__aarch64__) && ((defined(__ARM_FEATURE_FP16_FML) && !defined(__ARM_FEATURE_SVE)) || (defined(CPU_CAPABILITY_SVE128)))
#if defined(__aarch64__) && defined(__ARM_FEATURE_FP16_FML) && !defined(__ARM_FEATURE_SVE)
// NOTE: this instruction is an optional instruction in ARM v8.2 and
// v8.3, but mandatory in v8.4 per
// https://developer.arm.com/documentation/ddi0596/2021-03/SIMD-FP-Instructions/FMLAL--FMLAL2--vector---Floating-point-fused-Multiply-Add-Long-to-accumulator--vector--?lang=en

(File diff suppressed because it is too large.)
@@ -88,9 +88,9 @@ __global__ void compute_grad_weight_bags(
const int64_t stride_warped) {

int64_t num_of_segments = *num_of_segments_ptr;
const int gid = blockIdx.x * blockDim.x + threadIdx.x;
const int id = gid / stride_warped;
const int startFeature = gid % stride_warped;
const int64_t gid = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
const int64_t id = gid / stride_warped;
const int64_t startFeature = gid % stride_warped;
if (startFeature >= stride) {
return;
}
@@ -134,9 +134,9 @@ __global__ void compute_grad_weight(

int64_t num_of_segments = *num_of_segments_ptr;
using accscalar_t = acc_type<scalar_t, true>;
const int gid = blockIdx.x * blockDim.x + threadIdx.x;
const int id = gid / stride_warped;
const int startFeature = gid % stride_warped;
const int64_t gid = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
const int64_t id = gid / stride_warped;
const int64_t startFeature = gid % stride_warped;
if (startFeature >= stride) {
return;
}
@@ -167,9 +167,9 @@ __global__ void sum_and_scatter(

int64_t num_of_segments = *num_of_segments_ptr;
int64_t num_of_partial_segments = *num_of_partial_segments_ptr;
const int gid = blockIdx.x * blockDim.x + threadIdx.x;
const int id = gid / stride_warped;
const int startFeature = gid % stride_warped;
const int64_t gid = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
const int64_t id = gid / stride_warped;
const int64_t startFeature = gid % stride_warped;
if (startFeature >= stride) {
return;
}

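The kernels above widen the flattened thread index to 64 bits before the multiply, because `blockIdx.x * blockDim.x` is otherwise evaluated in 32-bit arithmetic and can overflow on very large launches. A host-side sketch of the same arithmetic hazard (values are illustrative):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Stand-ins for blockIdx.x and blockDim.x on a very large kernel launch.
  unsigned int block_idx = 9'000'000;
  unsigned int block_dim = 512;

  // 32-bit product wraps around (9e6 * 512 = 4.608e9 > UINT32_MAX).
  unsigned int wrapped = block_idx * block_dim;
  // Casting one operand first keeps the whole product in 64 bits.
  int64_t correct = static_cast<int64_t>(block_idx) * block_dim;

  std::printf("32-bit: %u\n64-bit: %lld\n", wrapped, static_cast<long long>(correct));
}
```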
||||
@@ -710,6 +710,9 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
dim3 block(warp_size, indices_per_block);

#ifdef USE_ROCM
dim3 new_grid_many_indices(ceil_div(num_indices, (int64_t) (indices_per_block * warp_size)),
grid.y == 1 ? std::min<int>(at::cuda::getCurrentDeviceProperties()->maxGridSize[1], ceil_div(sliceSize, (int64_t) (warp_size))) : grid.y,
grid.z);
dim3 new_grid(ceil_div(num_indices, (int64_t) (indices_per_block * warp_size)), grid.y, grid.z);
size_t smem_dups_size = indices_per_block * warp_size * sizeof(int64_t);
#define KERNEL_GRID new_grid
@@ -788,7 +791,7 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
expandedValue.scalar_type(),
"indexing_backward_many_indices",
AT_WRAP([&] {
indexing_backward_kernel_many_indices<scalar_t, UNROLL><<<new_grid, block, smem_dups_size, stream>>>(
indexing_backward_kernel_many_indices<scalar_t, UNROLL><<<new_grid_many_indices, block, smem_dups_size, stream>>>(
sorted_indices.const_data_ptr<int64_t>(),
orig_indices.const_data_ptr<int64_t>(),
expandedValue.const_data_ptr<scalar_t>(),

||||
@ -488,15 +488,16 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
|
||||
}
|
||||
}
|
||||
|
||||
int cat_dim = dimension;
|
||||
if (memory_format != c10::MemoryFormat::Contiguous) {
|
||||
switch (dimension) {
|
||||
switch (cat_dim) {
|
||||
case 0:
|
||||
break;
|
||||
case 1:
|
||||
dimension = nDims - dimension;
|
||||
cat_dim = nDims - cat_dim;
|
||||
break;
|
||||
default:
|
||||
dimension--;
|
||||
cat_dim--;
|
||||
}
|
||||
}
|
||||
// Template Declarations for dim = 1, 2, 3, 4
|
||||
@ -505,23 +506,23 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
|
||||
constexpr auto elems_per_vec = alignment / sizeof(scalar_t); \
|
||||
CatArrayBatchedCopy_vectorized<scalar_t, unsigned int, DIMS, batch_size, stride_size, alignment, elems_per_vec><<<\
|
||||
catGrid, applyBlock, 0, stream.stream()>>>(\
|
||||
(char*)data, catMetaData, kernelOutputParam, dimension, trailingSize);\
|
||||
(char*)data, catMetaData, kernelOutputParam, cat_dim, trailingSize);\
|
||||
} else if (isContig && isAligned && sizeof(scalar_t) > 2 && sizeof(scalar_t) <= 8) {\
|
||||
CatArrayBatchedCopy_alignedK_contig<scalar_t, unsigned int, DIMS, batch_size, stride_size, ALIGNED_VEC_LOAD_BYTES_16><<<\
|
||||
catGrid, applyBlock, 0, stream.stream()>>>(\
|
||||
data, catMetaData, outputParam, dimension, outputParam.tensorStride[dimension]);\
|
||||
data, catMetaData, outputParam, cat_dim, outputParam.tensorStride[cat_dim]);\
|
||||
} else if (isContig && isAligned && sizeof(scalar_t) == 2) { \
|
||||
CatArrayBatchedCopy_alignedK_contig<scalar_t, unsigned int, DIMS, batch_size, stride_size, ALIGNED_VEC_LOAD_BYTES_8><<<\
|
||||
catGrid, applyBlock, 0, stream.stream()>>>(\
|
||||
data, catMetaData, outputParam, dimension, outputParam.tensorStride[dimension]);\
|
||||
data, catMetaData, outputParam, cat_dim, outputParam.tensorStride[cat_dim]);\
|
||||
} else if (isContig) {\
|
||||
CatArrayBatchedCopy_contig<scalar_t, unsigned int, DIMS, batch_size, stride_size><<<\
|
||||
catGrid, applyBlock, 0, stream.stream()>>>(\
|
||||
data, catMetaData, outputParam, dimension, outputParam.tensorStride[dimension]);\
|
||||
data, catMetaData, outputParam, cat_dim, outputParam.tensorStride[cat_dim]);\
|
||||
} else {\
|
||||
CatArrayBatchedCopy<scalar_t, unsigned int, DIMS, batch_size, stride_size><<<\
|
||||
catGrid, applyBlock, 0, stream.stream()>>>(\
|
||||
data, catMetaData, outputParam, dimension, outputParam.tensorStride[dimension]);\
|
||||
data, catMetaData, outputParam, cat_dim, outputParam.tensorStride[cat_dim]);\
|
||||
}\
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
switch (nDims) {
|
||||
|
||||
@@ -127,7 +127,7 @@ inline __host__ __device__ uint32_t getAlignmentRoundUp(const void* p) {
return diff == 0 ? 0 : uint32_t(Align) - diff;
}

#if defined (__gfx90a__) || defined(__gfx942__)
#if defined (__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)
#define CDNA2_OR_LATER 1
#else
#define CDNA2_OR_LATER 0
@@ -143,7 +143,7 @@ template<typename T, uint32_t Rank>
using VecT = T __attribute__((ext_vector_type(Rank)));

static bool isCDNA2orLater(int index) {
return at::detail::getCUDAHooks().isGPUArch({"gfx90a", "gfx942"}, index);
return at::detail::getCUDAHooks().isGPUArch({"gfx90a", "gfx942", "gfx950"}, index);
}

#else

@@ -341,16 +341,22 @@ struct MHACacheKeyWrapper : ParamsWrapper<MHAParams> {
}
};

template <typename T, typename KeyType>
struct MHAGraphCache {
std::unordered_map<KeyType, T, ParamsWrapperHash<KeyType>> engine_cache;
using KeyType = MHACacheKeyWrapper;
using ValueType = std::unique_ptr<fe::graph::Graph>;
using MapType =
std::unordered_map<KeyType, ValueType, ParamsWrapperHash<KeyType>>;
using iterator = typename MapType::iterator;
using const_iterator = typename MapType::const_iterator;

MapType engine_cache;
int count = 0;
int hits = 0;

// no mutexes here as caches are now thread local for v8, can also return a
// pointer to the Execution Plan if we know it will not be invalidated by
// another thread
T* find(const KeyType& key) {
iterator find(const KeyType& key) {
static bool flag =
c10::utils::check_env("TORCH_CUDNN_SDPA_CACHE_DEBUG") == true;
if (flag && count) {
@@ -363,15 +369,19 @@ struct MHAGraphCache {
}
count++;
auto it = engine_cache.find(key);
if (it == engine_cache.end()) {
return nullptr;
if (it != engine_cache.end()) {
hits++;
}
hits++;
return &(it->second);
return it;
}

void update(const KeyType& key, T& results) {
engine_cache.insert_or_assign(key, std::move(results));
const_iterator end() const {
return engine_cache.end();
}

template <typename... Args>
std::pair<iterator, bool> try_emplace(const KeyType& key, Args&&... args) {
return engine_cache.try_emplace(key, std::forward<Args>(args)...);
}
};

@@ -380,16 +390,14 @@ struct MHAGraphCache {
// https://docs.nvidia.com/deeplearning/cudnn/backend/latest/release-notes.html
// We also leak the caches to workaround potential teardown race issues.

auto& getMHAGraphCache_() {
thread_local auto& instance =
*new MHAGraphCache<std::shared_ptr<fe::graph::Graph>, MHACacheKeyWrapper>;
return instance;
MHAGraphCache& getMHAGraphCache_() {
thread_local MHAGraphCache* instance{new MHAGraphCache()};
return *instance;
}

auto& getMHAGraphBackwardCache_() {
thread_local auto& instance =
*new MHAGraphCache<std::shared_ptr<fe::graph::Graph>, MHACacheKeyWrapper>;
return instance;
MHAGraphCache& getMHAGraphBackwardCache_() {
thread_local MHAGraphCache* instance{new MHAGraphCache()};
return *instance;
}

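Two patterns in the cache refactor above are worth calling out: the cache is deliberately leaked through a `thread_local` raw pointer so no destructor runs during process teardown, and lookup-or-build is expressed with `try_emplace` so the key is hashed only once. A standalone sketch of both patterns, using a plain string-keyed cache rather than the cuDNN graph type:

```cpp
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

struct Expensive {
  explicit Expensive(std::string v) : value(std::move(v)) {}
  std::string value;
};

using Cache = std::unordered_map<std::string, std::unique_ptr<Expensive>>;

// Thread-local and deliberately leaked: no destructor runs at thread/process
// exit, sidestepping teardown-order races (mirrors the pattern above).
Cache& cache() {
  thread_local Cache* instance{new Cache()};
  return *instance;
}

const Expensive& get_or_build(const std::string& key) {
  // try_emplace only creates the (empty) slot if the key is missing.
  auto [it, inserted] = cache().try_emplace(key, nullptr);
  if (inserted) {
    it->second = std::make_unique<Expensive>("built for " + key);  // cache miss: build once
  }
  return *it->second;
}

int main() {
  std::cout << get_or_build("k1").value << "\n";  // builds
  std::cout << get_or_build("k1").value << "\n";  // cache hit, same object
}
```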
namespace {
|
||||
@ -437,7 +445,7 @@ auto fixSizeOneDimStrideSDPA(
|
||||
|
||||
} // namespace
|
||||
|
||||
auto build_graph(
|
||||
std::unique_ptr<fe::graph::Graph> build_graph(
|
||||
int64_t b,
|
||||
int64_t h,
|
||||
int64_t s_q,
|
||||
@ -461,7 +469,7 @@ auto build_graph(
|
||||
if (q.scalar_type() == kBFloat16) {
|
||||
dtype = fe::DataType_t::BFLOAT16;
|
||||
}
|
||||
auto mha_graph = std::make_shared<fe::graph::Graph>();
|
||||
auto mha_graph = std::make_unique<fe::graph::Graph>();
|
||||
// We're baking in float accumulation and scale types
|
||||
// in theory the graph may support other types, but they
|
||||
// have not been tested
|
||||
@ -531,15 +539,13 @@ auto build_graph(
|
||||
fe::graph::Tensor_attributes().set_uid(K).set_name("K"));
|
||||
auto V_ = mha_graph->tensor(
|
||||
fe::graph::Tensor_attributes().set_uid(V).set_name("V"));
|
||||
std::optional<std::shared_ptr<fe::graph::Tensor_attributes>> bias;
|
||||
if (attn_bias.has_value()) {
|
||||
bias =
|
||||
scaled_dot_product_flash_attention_options.set_bias(
|
||||
mha_graph->tensor(fe::graph::Tensor_attributes()
|
||||
.set_uid(BIAS)
|
||||
.set_name("bias")
|
||||
.set_dim(attn_bias.value().sizes().vec())
|
||||
.set_stride(attn_bias.value().strides().vec()));
|
||||
scaled_dot_product_flash_attention_options.set_bias(bias.value());
|
||||
.set_stride(attn_bias.value().strides().vec())));
|
||||
}
|
||||
|
||||
auto [O_, Stats] =
|
||||
@ -640,7 +646,7 @@ auto build_graph(
|
||||
return mha_graph;
|
||||
}
|
||||
|
||||
auto build_graph_nestedtensor(
|
||||
std::unique_ptr<fe::graph::Graph> build_graph_nestedtensor(
|
||||
int64_t b,
|
||||
int64_t h_q,
|
||||
int64_t h_k,
|
||||
@ -668,7 +674,7 @@ auto build_graph_nestedtensor(
|
||||
if (q.scalar_type() == kBFloat16) {
|
||||
dtype = fe::DataType_t::BFLOAT16;
|
||||
}
|
||||
auto mha_graph = std::make_shared<fe::graph::Graph>();
|
||||
auto mha_graph = std::make_unique<fe::graph::Graph>();
|
||||
// We're baking in float accumulation and scale types
|
||||
// in theory the graph may support other types, but they
|
||||
// have not been tested
|
||||
@ -766,18 +772,16 @@ auto build_graph_nestedtensor(
|
||||
v_strides[strideidx0],
|
||||
v_strides[strideidx1],
|
||||
v_strides[strideidx2]}));
|
||||
std::optional<std::shared_ptr<fe::graph::Tensor_attributes>> bias;
|
||||
if (attn_bias.has_value()) {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"attn_bias not yet supportd with cuDNN Attention and NestedTensor");
|
||||
bias =
|
||||
scaled_dot_product_flash_attention_options.set_bias(
|
||||
mha_graph->tensor(fe::graph::Tensor_attributes()
|
||||
.set_uid(BIAS)
|
||||
.set_name("bias")
|
||||
.set_dim(attn_bias.value().sizes().vec())
|
||||
.set_stride(attn_bias.value().strides().vec()));
|
||||
scaled_dot_product_flash_attention_options.set_bias(bias.value());
|
||||
.set_stride(attn_bias.value().strides().vec())));
|
||||
}
|
||||
auto RAG_Q_OFF_ =
|
||||
mha_graph->tensor(fe::graph::Tensor_attributes()
|
||||
@ -847,7 +851,7 @@ auto build_graph_nestedtensor(
|
||||
return mha_graph;
|
||||
}
|
||||
|
||||
auto build_graph_backward(
|
||||
std::unique_ptr<fe::graph::Graph> build_graph_backward(
|
||||
int64_t b,
|
||||
int64_t h,
|
||||
int64_t s_q,
|
||||
@ -874,7 +878,7 @@ auto build_graph_backward(
|
||||
if (q.scalar_type() == kBFloat16) {
|
||||
dtype = fe::DataType_t::BFLOAT16;
|
||||
}
|
||||
auto mha_graph = std::make_shared<fe::graph::Graph>();
|
||||
auto mha_graph = std::make_unique<fe::graph::Graph>();
|
||||
// We're baking in float accumulation and scale types
|
||||
// in theory the graph may support other types, but they
|
||||
// have not been tested
|
||||
@ -919,15 +923,13 @@ auto build_graph_backward(
|
||||
fe::graph::Tensor_attributes().set_uid(K).set_name("K"));
|
||||
auto V_ = mha_graph->tensor(
|
||||
fe::graph::Tensor_attributes().set_uid(V).set_name("V"));
|
||||
std::optional<std::shared_ptr<fe::graph::Tensor_attributes>> bias;
|
||||
if (attn_bias.has_value()) {
|
||||
bias =
|
||||
sdpa_backward_options.set_bias(
|
||||
mha_graph->tensor(fe::graph::Tensor_attributes()
|
||||
.set_uid(BIAS)
|
||||
.set_name("bias")
|
||||
.set_dim(attn_bias.value().sizes().vec())
|
||||
.set_stride(attn_bias.value().strides().vec()));
|
||||
sdpa_backward_options.set_bias(bias.value());
|
||||
.set_stride(attn_bias.value().strides().vec())));
|
||||
}
|
||||
if (dropout_probability != 0.0f) {
|
||||
auto seed = mha_graph->tensor(fe::graph::Tensor_attributes()
|
||||
@ -1061,7 +1063,7 @@ auto build_graph_backward(
|
||||
return mha_graph;
|
||||
}
|
||||
|
||||
auto build_graph_backward_nestedtensor(
|
||||
std::unique_ptr<fe::graph::Graph> build_graph_backward_nestedtensor(
|
||||
int64_t b,
|
||||
int64_t h_q,
|
||||
int64_t h_k,
|
||||
@ -1092,7 +1094,7 @@ auto build_graph_backward_nestedtensor(
|
||||
if (q.scalar_type() == kBFloat16) {
|
||||
dtype = fe::DataType_t::BFLOAT16;
|
||||
}
|
||||
auto mha_graph = std::make_shared<fe::graph::Graph>();
|
||||
auto mha_graph = std::make_unique<fe::graph::Graph>();
|
||||
// We're baking in float accumulation and scale types
|
||||
// in theory the graph may support other types, but they
|
||||
// have not been tested
|
||||
@ -1195,18 +1197,16 @@ auto build_graph_backward_nestedtensor(
|
||||
o_strides[strideidx1],
|
||||
o_strides[strideidx2]}));
|
||||
|
||||
std::optional<std::shared_ptr<fe::graph::Tensor_attributes>> bias;
|
||||
if (attn_bias.has_value()) {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"attn_bias not yet supportd with cuDNN Attention and NestedTensor");
|
||||
bias =
|
||||
sdpa_backward_options.set_bias(
|
||||
mha_graph->tensor(fe::graph::Tensor_attributes()
|
||||
.set_uid(BIAS)
|
||||
.set_name("bias")
|
||||
.set_dim(attn_bias.value().sizes().vec())
|
||||
.set_stride(attn_bias.value().strides().vec()));
|
||||
sdpa_backward_options.set_bias(bias.value());
|
||||
.set_stride(attn_bias.value().strides().vec())));
|
||||
}
|
||||
auto RAG_Q_OFF_ =
|
||||
mha_graph->tensor(fe::graph::Tensor_attributes()
|
||||
@ -1378,7 +1378,7 @@ void run_cudnn_SDP_fprop(
|
||||
// NB: The key initialization will round up sequence length, stride data etc.
|
||||
// if use_ragged_in_dense is enabled (to allow multiple sequence lengths to
|
||||
// reuse the same cached value/graph)
|
||||
auto key = MHACacheKeyWrapper(
|
||||
MHACacheKeyWrapper key(
|
||||
b,
|
||||
h,
|
||||
s_q,
|
||||
@ -1393,12 +1393,9 @@ void run_cudnn_SDP_fprop(
|
||||
is_causal,
|
||||
return_softmaxstats,
|
||||
false);
|
||||
auto graph_ptr = getMHAGraphCache_().find(key);
|
||||
std::shared_ptr<fe::graph::Graph> mha_graph;
|
||||
if (graph_ptr) {
|
||||
mha_graph = *graph_ptr;
|
||||
} else {
|
||||
mha_graph = build_graph(
|
||||
auto [cache_it, not_found] = getMHAGraphCache_().try_emplace(key, nullptr);
|
||||
if (not_found) {
|
||||
cache_it->second = build_graph(
|
||||
b,
|
||||
h,
|
||||
s_q,
|
||||
@ -1419,39 +1416,39 @@ void run_cudnn_SDP_fprop(
|
||||
_dropoutoffset,
|
||||
handle);
|
||||
}
|
||||
const fe::graph::Graph& mha_graph = *cache_it->second;
|
||||
std::unordered_map<int64_t, void*> variant_pack = {
|
||||
{Q, q.data_ptr()},
|
||||
{K, k.data_ptr()},
|
||||
{V, v.data_ptr()},
|
||||
{Q, q.mutable_data_ptr()},
|
||||
{K, k.mutable_data_ptr()},
|
||||
{V, v.mutable_data_ptr()},
|
||||
{SCALE, &scaling_factor},
|
||||
{O, o.data_ptr()}};
|
||||
{O, o.mutable_data_ptr()}};
|
||||
if (return_softmaxstats) {
|
||||
variant_pack[LSE] = softmaxstats.data_ptr();
|
||||
variant_pack[LSE] = softmaxstats.mutable_data_ptr();
|
||||
}
|
||||
if (attn_bias.has_value()) {
|
||||
variant_pack[BIAS] = attn_bias.value().data_ptr();
|
||||
variant_pack[BIAS] = attn_bias.value().mutable_data_ptr();
|
||||
}
|
||||
if (dropout_probability != 0.0f) {
|
||||
variant_pack[SEED] = _dropoutseed.data_ptr();
|
||||
variant_pack[OFFSET] = _dropoutoffset.data_ptr();
|
||||
variant_pack[SEED] = _dropoutseed.mutable_data_ptr();
|
||||
variant_pack[OFFSET] = _dropoutoffset.mutable_data_ptr();
|
||||
}
|
||||
if (use_ragged_in_dense(q, k, v, o, attn_bias.has_value())) {
|
||||
variant_pack[SEQ_LEN_Q] = seqlen_q.data_ptr();
|
||||
variant_pack[SEQ_LEN_KV] = seqlen_kv.data_ptr();
|
||||
variant_pack[RAG_Q_OFF] = rag_off_q.data_ptr();
|
||||
variant_pack[RAG_K_OFF] = rag_off_k.data_ptr();
|
||||
variant_pack[RAG_V_OFF] = rag_off_v.data_ptr();
|
||||
variant_pack[RAG_O_OFF] = rag_off_o.data_ptr();
|
||||
variant_pack[SEQ_LEN_Q] = seqlen_q.mutable_data_ptr();
|
||||
variant_pack[SEQ_LEN_KV] = seqlen_kv.mutable_data_ptr();
|
||||
variant_pack[RAG_Q_OFF] = rag_off_q.mutable_data_ptr();
|
||||
variant_pack[RAG_K_OFF] = rag_off_k.mutable_data_ptr();
|
||||
variant_pack[RAG_V_OFF] = rag_off_v.mutable_data_ptr();
|
||||
variant_pack[RAG_O_OFF] = rag_off_o.mutable_data_ptr();
|
||||
if (return_softmaxstats) {
|
||||
variant_pack[RAG_LSE_OFF] = rag_off_lse.data_ptr();
|
||||
variant_pack[RAG_LSE_OFF] = rag_off_lse.mutable_data_ptr();
|
||||
}
|
||||
}
|
||||
auto workspace_size = mha_graph->get_workspace_size();
|
||||
auto workspace_size = mha_graph.get_workspace_size();
|
||||
auto workspace_ptr =
|
||||
c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size);
|
||||
TORCH_CHECK(
|
||||
mha_graph->execute(handle, variant_pack, workspace_ptr.get()).is_good());
|
||||
getMHAGraphCache_().update(key, mha_graph);
|
||||
mha_graph.execute(handle, variant_pack, workspace_ptr.get()).is_good());
|
||||
}
|
||||
|
||||
void run_cudnn_SDP_fprop_nestedtensor(
|
||||
@ -1491,7 +1488,7 @@ void run_cudnn_SDP_fprop_nestedtensor(
|
||||
softmaxstats = at::empty({q.size(0), h_q, 1}, q.options().dtype(kFloat));
|
||||
}
|
||||
|
||||
auto key = MHACacheKeyWrapper(
|
||||
MHACacheKeyWrapper key(
|
||||
b,
|
||||
h_q,
|
||||
s_q, // max-seqlen-q
|
||||
@ -1506,13 +1503,12 @@ void run_cudnn_SDP_fprop_nestedtensor(
|
||||
is_causal,
|
||||
return_softmaxstats,
|
||||
true);
|
||||
auto graph_ptr = getMHAGraphCache_().find(key);
|
||||
std::shared_ptr<fe::graph::Graph> mha_graph;
|
||||
|
||||
if (graph_ptr) {
|
||||
mha_graph = *graph_ptr;
|
||||
} else {
|
||||
mha_graph = build_graph_nestedtensor(
|
||||
MHAGraphCache& cache = getMHAGraphCache_();
|
||||
auto cache_it = cache.find(key);
|
||||
std::unique_ptr<fe::graph::Graph> mha_graph_storage;
|
||||
if (cache_it == cache.end()) {
|
||||
mha_graph_storage = build_graph_nestedtensor(
|
||||
b,
|
||||
h_q,
|
||||
h_k,
|
||||
@ -1537,40 +1533,44 @@ void run_cudnn_SDP_fprop_nestedtensor(
|
||||
dropoutoffset,
|
||||
handle);
|
||||
}
|
||||
const fe::graph::Graph& mha_graph =
|
||||
mha_graph_storage ? *mha_graph_storage : *cache_it->second;
|
||||
|
||||
auto seqlen_q = at::diff(cum_seqlen_q, 1, 0);
|
||||
auto seqlen_kv = at::diff(cum_seqlen_kv, 1, 0);
|
||||
auto rag_q_off = cum_seqlen_q.mul(h_q * d_qk);
|
||||
auto rag_k_off = cum_seqlen_kv.mul(h_k * d_v);
|
||||
auto rag_v_off = cum_seqlen_kv.mul(h_v * d_v);
|
||||
auto rag_q_off = cum_seqlen_q.mul(q.stride(-3));
|
||||
auto rag_k_off = cum_seqlen_kv.mul(k.stride(-3));
|
||||
auto rag_v_off = cum_seqlen_kv.mul(v.stride(-3));
|
||||
auto rag_o_off = cum_seqlen_q.mul(o.stride(-3));
|
||||
auto rag_stats_off = cum_seqlen_q.mul(h_q);
|
||||
std::unordered_map<int64_t, void*> variant_pack = {
|
||||
{Q, q.data_ptr()},
|
||||
{K, k.data_ptr()},
|
||||
{V, v.data_ptr()},
|
||||
{Q, q.mutable_data_ptr()},
|
||||
{K, k.mutable_data_ptr()},
|
||||
{V, v.mutable_data_ptr()},
|
||||
{SCALE, &scaling_factor},
|
||||
{O, o.data_ptr()},
|
||||
{RAG_Q_OFF, rag_q_off.data_ptr()},
|
||||
{RAG_O_OFF, rag_q_off.data_ptr()},
|
||||
{RAG_K_OFF, rag_k_off.data_ptr()},
|
||||
{RAG_V_OFF, rag_v_off.data_ptr()},
|
||||
{SEQ_LEN_Q, seqlen_q.data_ptr()},
|
||||
{SEQ_LEN_KV, seqlen_kv.data_ptr()}};
|
||||
{O, o.mutable_data_ptr()},
|
||||
{RAG_Q_OFF, rag_q_off.mutable_data_ptr()},
|
||||
{RAG_O_OFF, rag_o_off.mutable_data_ptr()},
|
||||
{RAG_K_OFF, rag_k_off.mutable_data_ptr()},
|
||||
{RAG_V_OFF, rag_v_off.mutable_data_ptr()},
|
||||
{SEQ_LEN_Q, seqlen_q.mutable_data_ptr()},
|
||||
{SEQ_LEN_KV, seqlen_kv.mutable_data_ptr()}};
|
||||
if (return_softmaxstats) {
|
||||
variant_pack[LSE] = softmaxstats.data_ptr();
|
||||
variant_pack[RAG_LSE_OFF] = rag_stats_off.data_ptr();
|
||||
variant_pack[LSE] = softmaxstats.mutable_data_ptr();
|
||||
variant_pack[RAG_LSE_OFF] = rag_stats_off.mutable_data_ptr();
|
||||
}
|
||||
if (dropout_probability != 0.0f) {
|
||||
variant_pack[SEED] = dropoutseed.data_ptr();
|
||||
variant_pack[OFFSET] = dropoutoffset.data_ptr();
|
||||
variant_pack[SEED] = dropoutseed.mutable_data_ptr();
|
||||
variant_pack[OFFSET] = dropoutoffset.mutable_data_ptr();
|
||||
}
|
||||
if (attn_bias.has_value()) {
|
||||
TORCH_CHECK("bias not supported with nestedtensor");
|
||||
}
|
||||
auto workspace_size = mha_graph->get_workspace_size();
|
||||
auto workspace_size = mha_graph.get_workspace_size();
|
||||
auto workspace_ptr =
|
||||
c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size);
|
||||
TORCH_CHECK(
|
||||
mha_graph->execute(handle, variant_pack, workspace_ptr.get()).is_good());
|
||||
mha_graph.execute(handle, variant_pack, workspace_ptr.get()).is_good());
|
||||
}
|
||||
|
||||
void run_cudnn_SDP_bprop(
|
||||
@ -1652,7 +1652,7 @@ void run_cudnn_SDP_bprop(
|
||||
}
|
||||
|
||||
cudnnHandle_t handle = getCudnnHandle();
|
||||
auto key = MHACacheKeyWrapper(
|
||||
MHACacheKeyWrapper key(
|
||||
b,
|
||||
h,
|
||||
s_q,
|
||||
@ -1667,12 +1667,10 @@ void run_cudnn_SDP_bprop(
|
||||
is_causal,
|
||||
true,
|
||||
false);
|
||||
auto graph_backward_ptr = getMHAGraphBackwardCache_().find(key);
|
||||
std::shared_ptr<fe::graph::Graph> mha_graph;
|
||||
if (graph_backward_ptr) {
|
||||
mha_graph = *graph_backward_ptr;
|
||||
} else {
|
||||
mha_graph = build_graph_backward(
|
||||
auto [cache_it, not_found] =
|
||||
getMHAGraphBackwardCache_().try_emplace(key, nullptr);
|
||||
if (not_found) {
|
||||
cache_it->second = build_graph_backward(
|
||||
b,
|
||||
h,
|
||||
s_q,
|
||||
@ -1696,43 +1694,44 @@ void run_cudnn_SDP_bprop(
|
||||
_dropoutoffset,
|
||||
handle);
|
||||
}
|
||||
const fe::graph::Graph& mha_graph = *cache_it->second;
|
||||
|
||||
std::unordered_map<int64_t, void*> variant_pack = {
|
||||
// inputs
|
||||
{Q, q.data_ptr()},
|
||||
{K, k.data_ptr()},
|
||||
{V, v.data_ptr()},
|
||||
{O, o.data_ptr()},
|
||||
{DO, dO_.data_ptr()},
|
||||
{LSE, softmaxstats.data_ptr()},
|
||||
{Q, q.mutable_data_ptr()},
|
||||
{K, k.mutable_data_ptr()},
|
||||
{V, v.mutable_data_ptr()},
|
||||
{O, o.mutable_data_ptr()},
|
||||
{DO, dO_.mutable_data_ptr()},
|
||||
{LSE, softmaxstats.mutable_data_ptr()},
|
||||
// outputs
|
||||
{DQ, dQ.data_ptr()},
|
||||
{DK, dK.data_ptr()},
|
||||
{DV, dV.data_ptr()},
|
||||
{DQ, dQ.mutable_data_ptr()},
|
||||
{DK, dK.mutable_data_ptr()},
|
||||
{DV, dV.mutable_data_ptr()},
|
||||
{SCALE, &scaling_factor}};
|
||||
if (dropout_probability != 0.0f) {
|
||||
variant_pack[SEED] = _dropoutseed.data_ptr();
|
||||
variant_pack[OFFSET] = _dropoutoffset.data_ptr();
|
||||
variant_pack[SEED] = _dropoutseed.mutable_data_ptr();
|
||||
variant_pack[OFFSET] = _dropoutoffset.mutable_data_ptr();
|
||||
}
|
||||
if (attn_bias.has_value()) {
|
||||
variant_pack[BIAS] = attn_bias.value().data_ptr();
|
||||
variant_pack[BIAS] = attn_bias.value().mutable_data_ptr();
|
||||
}
|
||||
if (use_ragged_in_dense(q, k, v, o, attn_bias.has_value())) {
|
||||
variant_pack[SEQ_LEN_Q] = seqlen_q.data_ptr();
|
||||
variant_pack[SEQ_LEN_KV] = seqlen_kv.data_ptr();
|
||||
variant_pack[RAG_Q_OFF] = rag_off_q.data_ptr();
|
||||
variant_pack[RAG_K_OFF] = rag_off_k.data_ptr();
|
||||
variant_pack[RAG_V_OFF] = rag_off_v.data_ptr();
|
||||
variant_pack[RAG_O_OFF] = rag_off_o.data_ptr();
|
||||
variant_pack[RAG_LSE_OFF] = rag_off_lse.data_ptr();
|
||||
variant_pack[SEQ_LEN_Q] = seqlen_q.mutable_data_ptr();
|
||||
variant_pack[SEQ_LEN_KV] = seqlen_kv.mutable_data_ptr();
|
||||
variant_pack[RAG_Q_OFF] = rag_off_q.mutable_data_ptr();
|
||||
variant_pack[RAG_K_OFF] = rag_off_k.mutable_data_ptr();
|
||||
variant_pack[RAG_V_OFF] = rag_off_v.mutable_data_ptr();
|
||||
variant_pack[RAG_O_OFF] = rag_off_o.mutable_data_ptr();
|
||||
variant_pack[RAG_LSE_OFF] = rag_off_lse.mutable_data_ptr();
|
||||
}
|
||||
|
||||
auto workspace_size = mha_graph->get_workspace_size();
|
||||
auto workspace_size = mha_graph.get_workspace_size();
|
||||
auto workspace_ptr =
|
||||
c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size);
|
||||
TORCH_CHECK(!workspace_size || workspace_ptr.get());
|
||||
TORCH_CHECK(
|
||||
mha_graph->execute(handle, variant_pack, workspace_ptr.get()).is_good());
|
||||
getMHAGraphBackwardCache_().update(key, mha_graph);
|
||||
mha_graph.execute(handle, variant_pack, workspace_ptr.get()).is_good());
|
||||
}
|
||||
|
||||
void run_cudnn_SDP_bprop_nestedtensor(
|
||||
@ -1775,9 +1774,10 @@ void run_cudnn_SDP_bprop_nestedtensor(
|
||||
|
||||
auto seqlen_q = at::diff(cum_seqlen_q, 1, 0);
|
||||
auto seqlen_kv = at::diff(cum_seqlen_kv, 1, 0);
|
||||
auto rag_q_off = cum_seqlen_q.mul(h_q * d_qk);
|
||||
auto rag_k_off = cum_seqlen_kv.mul(h_k * d_v);
|
||||
auto rag_v_off = cum_seqlen_kv.mul(h_v * d_v);
|
||||
auto rag_q_off = cum_seqlen_q.mul(q.stride(-3));
|
||||
auto rag_k_off = cum_seqlen_kv.mul(k.stride(-3));
|
||||
auto rag_v_off = cum_seqlen_kv.mul(v.stride(-3));
|
||||
auto rag_o_off = cum_seqlen_q.mul(o.stride(-3));
|
||||
auto rag_stats_off = cum_seqlen_q.mul(h_q);
|
||||
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
@ -1791,7 +1791,7 @@ void run_cudnn_SDP_bprop_nestedtensor(
|
||||
|
||||
cudnnHandle_t handle = getCudnnHandle();
|
||||
|
||||
auto key = MHACacheKeyWrapper(
|
||||
MHACacheKeyWrapper key(
|
||||
b,
|
||||
h_q,
|
||||
s_q, // max-seqlen-q
|
||||
@ -1806,13 +1806,12 @@ void run_cudnn_SDP_bprop_nestedtensor(
|
||||
is_causal,
|
||||
true,
|
||||
true);
|
||||
auto graph_ptr = getMHAGraphCache_().find(key);
|
||||
std::shared_ptr<fe::graph::Graph> mha_graph;
|
||||
|
||||
if (graph_ptr) {
|
||||
mha_graph = *graph_ptr;
|
||||
} else {
|
||||
mha_graph = build_graph_backward_nestedtensor(
|
||||
MHAGraphCache& cache = getMHAGraphCache_();
|
||||
auto cache_it = cache.find(key);
|
||||
std::unique_ptr<fe::graph::Graph> mha_graph_storage;
|
||||
if (cache_it == cache.end()) {
|
||||
mha_graph_storage = build_graph_backward_nestedtensor(
|
||||
b,
|
||||
h_q,
|
||||
h_k,
|
||||
@ -1840,41 +1839,43 @@ void run_cudnn_SDP_bprop_nestedtensor(
|
||||
dropoutoffset,
|
||||
handle);
|
||||
}
|
||||
const fe::graph::Graph& mha_graph =
|
||||
mha_graph_storage ? *mha_graph_storage : *cache_it->second;
|
||||
|
||||
std::unordered_map<int64_t, void*> variant_pack = {
|
||||
// inputs
|
||||
{Q, q.data_ptr()},
|
||||
{K, k.data_ptr()},
|
||||
{V, v.data_ptr()},
|
||||
{O, o.data_ptr()},
|
||||
{DO, dO_.data_ptr()},
|
||||
{LSE, softmaxstats.data_ptr()},
|
||||
{Q, q.mutable_data_ptr()},
|
||||
{K, k.mutable_data_ptr()},
|
||||
{V, v.mutable_data_ptr()},
|
||||
{O, o.mutable_data_ptr()},
|
||||
{DO, dO_.mutable_data_ptr()},
|
||||
{LSE, softmaxstats.mutable_data_ptr()},
|
||||
// outputs
|
||||
{DQ, dQ.data_ptr()},
|
||||
{DK, dK.data_ptr()},
|
||||
{DV, dV.data_ptr()},
|
||||
{DQ, dQ.mutable_data_ptr()},
|
||||
{DK, dK.mutable_data_ptr()},
|
||||
{DV, dV.mutable_data_ptr()},
|
||||
{SCALE, &scaling_factor},
|
||||
{RAG_Q_OFF, rag_q_off.data_ptr()},
|
||||
{RAG_O_OFF, rag_q_off.data_ptr()},
|
||||
{RAG_K_OFF, rag_k_off.data_ptr()},
|
||||
{RAG_V_OFF, rag_v_off.data_ptr()},
|
||||
{RAG_LSE_OFF, rag_stats_off.data_ptr()},
|
||||
{SEQ_LEN_Q, seqlen_q.data_ptr()},
|
||||
{SEQ_LEN_KV, seqlen_kv.data_ptr()}};
|
||||
{RAG_Q_OFF, rag_q_off.mutable_data_ptr()},
|
||||
{RAG_O_OFF, rag_o_off.mutable_data_ptr()},
|
||||
{RAG_K_OFF, rag_k_off.mutable_data_ptr()},
|
||||
{RAG_V_OFF, rag_v_off.mutable_data_ptr()},
|
||||
{RAG_LSE_OFF, rag_stats_off.mutable_data_ptr()},
|
||||
{SEQ_LEN_Q, seqlen_q.mutable_data_ptr()},
|
||||
{SEQ_LEN_KV, seqlen_kv.mutable_data_ptr()}};
|
||||
if (dropout_probability != 0.0f) {
|
||||
variant_pack[SEED] = _dropoutseed.data_ptr();
|
||||
variant_pack[OFFSET] = _dropoutoffset.data_ptr();
|
||||
variant_pack[SEED] = _dropoutseed.mutable_data_ptr();
|
||||
variant_pack[OFFSET] = _dropoutoffset.mutable_data_ptr();
|
||||
}
|
||||
TORCH_CHECK(
|
||||
!attn_bias.has_value(),
|
||||
"attn_bias not yet supportd with cuDNN Attention and NestedTensor");
|
||||
|
||||
auto workspace_size = mha_graph->get_workspace_size();
|
||||
auto workspace_size = mha_graph.get_workspace_size();
|
||||
auto workspace_ptr =
|
||||
c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size);
|
||||
TORCH_CHECK(!workspace_size || workspace_ptr.get());
|
||||
TORCH_CHECK(
|
||||
mha_graph->execute(handle, variant_pack, workspace_ptr.get()).is_good());
|
||||
mha_graph.execute(handle, variant_pack, workspace_ptr.get()).is_good());
|
||||
}
|
||||
|
||||
} // namespace native
|
||||
|
||||
@@ -165,7 +165,7 @@ REGISTER_AVX2_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_co
REGISTER_AVX512_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_)
REGISTER_ZVECTOR_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_)
REGISTER_VSX_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_)
REGISTER_SVE_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_)
REGISTER_SVE256_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_)

// _out variants can be shared between PocketFFT and MKL
Tensor& _fft_r2c_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization,

||||
@@ -116,6 +116,8 @@ class MetalShaderLibrary {
std::vector<std::string> getFunctionNames();
std::shared_ptr<MetalKernelFunction> getKernelFunction(
const std::string& name);
// Returns a raw pointer to the kernel function for use in C APIs
MetalKernelFunction* getCachedKernelFunctionPtr(const std::string& name);
inline MTLComputePipelineState_t getPipelineStateForFunc(
const std::string& fname) {
return getLibraryPipelineState(getLibrary(), fname).first;
@@ -164,6 +166,9 @@ class MetalShaderLibrary {
std::string,
std::pair<MTLComputePipelineState_t, MTLFunction_t>>
cplMap;
// Cache for kernel functions returned by getCachedKernelFunctionPtr
std::unordered_map<std::string, std::unique_ptr<MetalKernelFunction>>
kernelCache;
};

class DynamicMetalShaderLibrary : public MetalShaderLibrary {

@@ -917,6 +917,22 @@ std::shared_ptr<MetalKernelFunction> MetalShaderLibrary::getKernelFunction(const
return std::make_shared<MetalKernelFunction>(cpl, func);
}

MetalKernelFunction* MetalShaderLibrary::getCachedKernelFunctionPtr(const std::string& name) {
// Check if kernel is already cached
auto it = kernelCache.find(name);
if (it != kernelCache.end()) {
return it->second.get();
}

// Create new kernel function and cache it
auto [cpl, func] = getLibraryPipelineState(getLibrary(), name);
auto kernel = std::make_unique<MetalKernelFunction>(cpl, func);
MetalKernelFunction* raw_ptr = kernel.get();
kernelCache[name] = std::move(kernel);

return raw_ptr;
}

class BundledShaderLibary : public MetalShaderLibrary {
public:
BundledShaderLibary() : MetalShaderLibrary("") {}

@ -5,6 +5,38 @@
|
||||
# representing ScalarType's. They are now superseded by usage of
|
||||
# `aten::to()`. The ops remain here for backward compatibility purposes.
|
||||
|
||||
# DEPRECATED. DO NOT USE
|
||||
- func: _cast_Byte(Tensor self, bool non_blocking=False) -> Tensor
|
||||
variants: function
|
||||
|
||||
# DEPRECATED. DO NOT USE
|
||||
- func: _cast_Char(Tensor self, bool non_blocking=False) -> Tensor
|
||||
variants: function
|
||||
|
||||
# DEPRECATED. DO NOT USE
|
||||
- func: _cast_Double(Tensor self, bool non_blocking=False) -> Tensor
|
||||
variants: function
|
||||
|
||||
# DEPRECATED. DO NOT USE
|
||||
- func: _cast_Float(Tensor self, bool non_blocking=False) -> Tensor
|
||||
variants: function
|
||||
|
||||
# DEPRECATED. DO NOT USE
|
||||
- func: _cast_Int(Tensor self, bool non_blocking=False) -> Tensor
|
||||
variants: function
|
||||
|
||||
# DEPRECATED. DO NOT USE
|
||||
- func: _cast_Long(Tensor self, bool non_blocking=False) -> Tensor
|
||||
variants: function
|
||||
|
||||
# DEPRECATED. DO NOT USE
|
||||
- func: _cast_Short(Tensor self, bool non_blocking=False) -> Tensor
|
||||
variants: function
|
||||
|
||||
# DEPRECATED. DO NOT USE
|
||||
- func: _cast_Half(Tensor self, bool non_blocking=False) -> Tensor
|
||||
variants: function
|
||||
|
||||
# Computes the gradient of current tensor w.r.t. graph leaves.
|
||||
- func: _backward(Tensor self, Tensor[] inputs, Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False) -> ()
|
||||
manual_cpp_binding: True
|
||||
@ -7125,6 +7157,7 @@
|
||||
CUDA: _scaled_mm_cuda
|
||||
tags: needs_exact_strides
|
||||
|
||||
|
||||
- func: _scaled_mm.out(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False, *, Tensor(a!) out) -> Tensor(a!)
|
||||
variants: function
|
||||
dispatch:
|
||||
@ -7132,6 +7165,16 @@
|
||||
CUDA: _scaled_mm_out_cuda
|
||||
tags: needs_exact_strides
|
||||
|
||||
- func: _scaled_mm_v2(Tensor self, Tensor mat2, Tensor[] scale_a, int[] recipe_a, int[] swizzle_a, Tensor[] scale_b, int[] recipe_b, int[] swizzle_b, Tensor? bias, ScalarType? out_dtype, int[] contraction_dim=[], bool use_fast_accum=False) -> Tensor
|
||||
variants: function
|
||||
dispatch:
|
||||
CUDA: _scaled_mm_cuda_v2
|
||||
|
||||
- func: _scaled_mm_v2.out(Tensor self, Tensor mat2, Tensor[] scale_a, int[] recipe_a, int[] swizzle_a, Tensor[] scale_b, int[] recipe_b, int[] swizzle_b, Tensor? bias, ScalarType? out_dtype, int[] contraction_dim=[], bool use_fast_accum=False, *, Tensor(a!) out) -> Tensor(a!)
|
||||
variants: function
|
||||
dispatch:
|
||||
CUDA: _scaled_mm_cuda_v2_out
|
||||
|
||||
|
||||
- func: _scaled_grouped_mm(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? offs=None, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False) -> Tensor
|
||||
variants: function
|
||||
|
||||
@ -48,8 +48,8 @@ std::tuple<Tensor, Tensor> fake_quantize_per_channel_affine_cachemask(
|
||||
int64_t axis,
|
||||
int64_t quant_min,
|
||||
int64_t quant_max) {
|
||||
TORCH_CHECK(scale.scalar_type() == ScalarType::Float,
|
||||
"Scale must be Float, found ", scale.scalar_type());
|
||||
TORCH_CHECK(scale.scalar_type() == ScalarType::Float || scale.scalar_type() == at::kBFloat16,
|
||||
"Scale must be Float or BFloat16, found ", scale.scalar_type());
|
||||
TORCH_CHECK(zero_point.scalar_type() == ScalarType::Int || zero_point.scalar_type() == ScalarType::Float || zero_point.scalar_type() == ScalarType::Half,
|
||||
"Zero-point must be Int32, Float or Half, found ", zero_point.scalar_type());
|
||||
TORCH_CHECK(scale.dim() == 1, "scale should be a 1-D tensor");
|
||||
|
||||
@ -27,6 +27,6 @@ REGISTER_AVX512_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel)
|
||||
REGISTER_AVX2_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel)
|
||||
REGISTER_VSX_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel)
|
||||
REGISTER_ZVECTOR_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel)
|
||||
REGISTER_SVE_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel)
|
||||
REGISTER_SVE256_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel)
|
||||
|
||||
} // namespace at::native
|
||||
|
||||
@ -161,19 +161,19 @@ REGISTER_AVX512_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_
|
||||
REGISTER_AVX2_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel)
|
||||
REGISTER_VSX_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel)
|
||||
REGISTER_ZVECTOR_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel)
|
||||
REGISTER_SVE_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel)
|
||||
REGISTER_SVE256_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel)
|
||||
|
||||
REGISTER_ARCH_DISPATCH(sparse_mask_intersection_out_stub, DEFAULT, &sparse_mask_intersection_out_cpu_kernel)
|
||||
REGISTER_AVX512_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel)
|
||||
REGISTER_AVX2_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel)
|
||||
REGISTER_VSX_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel)
|
||||
REGISTER_ZVECTOR_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel)
|
||||
REGISTER_SVE_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel)
|
||||
REGISTER_SVE256_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel)
|
||||
|
||||
REGISTER_ARCH_DISPATCH(sparse_mask_projection_out_stub, DEFAULT, &sparse_mask_projection_out_cpu_kernel)
|
||||
REGISTER_AVX512_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel)
|
||||
REGISTER_AVX2_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel)
|
||||
REGISTER_VSX_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel)
|
||||
REGISTER_ZVECTOR_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel)
|
||||
REGISTER_SVE_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel)
|
||||
REGISTER_SVE256_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel)
|
||||
}
|
||||
|
||||
@ -448,7 +448,7 @@ REGISTER_AVX2_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp)
|
||||
REGISTER_AVX512_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp)
|
||||
REGISTER_VSX_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp)
|
||||
REGISTER_ZVECTOR_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp)
|
||||
REGISTER_SVE_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp)
|
||||
REGISTER_SVE256_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp)
|
||||
REGISTER_HPU_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_meta)
|
||||
|
||||
int64_t _fused_sdp_choice_meta(
|
||||
|
||||
@ -637,13 +637,7 @@ bool check_for_nested_inputs(sdp_params const& params, bool debug) {
|
||||
TORCH_WARN("Experimental cuDNN SDPA nested tensor support is not enabled.");
|
||||
}
|
||||
return false;
|
||||
} else if (has_for_nested_inputs(params) && (params.query.requires_grad() || params.key.requires_grad() || params.value.requires_grad())) {
|
||||
if (debug) {
|
||||
TORCH_WARN("Experimental cuDNN SDPA nested tensor support does not support backward.");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
const auto dprop = at::cuda::getCurrentDeviceProperties();
|
||||
// Check that the input is nested
|
||||
if (!(dprop->major == 9 || dprop->major == 10) && has_for_nested_inputs(params)) {
|
||||
|
||||
@ -37,7 +37,7 @@ class Benchmark(BenchmarkBase):
|
||||
def f(a, b):
|
||||
xs = b.tolist()
|
||||
for x in xs:
|
||||
torch._check_is_size(x)
|
||||
torch._check(x >= 0)
|
||||
torch._check(x <= self.N)
|
||||
return a.split(xs)
|
||||
|
||||
|
||||
@ -52,7 +52,9 @@ constexpr DispatchKeySet math_dispatch_keyset = backend_dispatch_keyset |
|
||||
// where we would like to support composite implicit kernels but not
|
||||
// explicit kernels therefore we manually add the key to the
|
||||
// math_dispatch_keyset
|
||||
DispatchKeySet{DispatchKey::NestedTensor};
|
||||
DispatchKeySet{DispatchKey::NestedTensor} |
|
||||
// Functionalize should always reuse CompositeImplicit decomps.
|
||||
DispatchKeySet{DispatchKey::Functionalize};
|
||||
|
||||
constexpr DispatchKeySet nested_dispatch_keyset =
|
||||
DispatchKeySet(
|
||||
|
||||
@ -130,14 +130,6 @@ int64_t SymInt::guard_int(const char* file, int64_t line) const {
|
||||
}
|
||||
}
|
||||
|
||||
bool SymInt::expect_size(const char* file, int64_t line) const {
|
||||
if (auto ma = maybe_as_int()) {
|
||||
return *ma >= 0;
|
||||
} else {
|
||||
return toSymNodeImplUnowned()->expect_size(file, line);
|
||||
}
|
||||
}
|
||||
|
||||
SymInt operator-(const SymInt& s) {
|
||||
if (auto ma = s.maybe_as_int()) {
|
||||
const auto val = *ma;
|
||||
|
||||
@ -153,14 +153,6 @@ class C10_API SymInt {
|
||||
// number can be used to diagnose overspecialization.
|
||||
int64_t guard_int(const char* file, int64_t line) const;
|
||||
|
||||
// Insert a guard that this SymInt must be size-like, returning true if
|
||||
// the integer actually is >= 0. Unlike manually performing a >= 0 test,
|
||||
// if the SymInt in question is an unbacked SymInt (or, potentially in the
|
||||
// future, if it contains unbacked SymInts), we will also treat the
|
||||
// unbacked SymInt as statically testing >= 2 (which will prevent us from
|
||||
// choking on, e.g., contiguity checks.)
|
||||
bool expect_size(const char* file, int64_t line) const;
|
||||
|
||||
// Distinguish actual symbolic values from constants stored on the heap
|
||||
bool is_symbolic() const {
|
||||
return is_heap_allocated() &&
|
||||
|
||||
@ -210,11 +210,6 @@ class C10_API SymNodeImpl : public c10::intrusive_ptr_target {
|
||||
// with a better implementation!
|
||||
return guard_bool(file, line);
|
||||
}
|
||||
virtual bool expect_size(const char* file, int64_t line) {
|
||||
// No improvement for unbacked SymInts by default, replace this
|
||||
// with a better implementation!
|
||||
return ge(wrap_int(0))->guard_bool(file, line);
|
||||
}
|
||||
virtual int64_t int_() {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
}
|
||||
|
||||
@ -108,12 +108,15 @@ void* alloc_cpu(size_t nbytes) {
|
||||
"DefaultCPUAllocator: not enough memory: you tried to allocate ",
|
||||
nbytes,
|
||||
" bytes.");
|
||||
#elif defined(_MSC_VER)
|
||||
#ifdef USE_MIMALLOC
|
||||
#elif defined(USE_MIMALLOC)
|
||||
data = mi_malloc_aligned(nbytes, gAlignment);
|
||||
#else
|
||||
CAFFE_ENFORCE(
|
||||
data,
|
||||
"DefaultCPUAllocator: not enough memory: you tried to allocate ",
|
||||
nbytes,
|
||||
" bytes.");
|
||||
#elif defined(_MSC_VER)
|
||||
data = _aligned_malloc(nbytes, gAlignment);
|
||||
#endif
|
||||
CAFFE_ENFORCE(
|
||||
data,
|
||||
"DefaultCPUAllocator: not enough memory: you tried to allocate ",
|
||||
@ -160,12 +163,10 @@ void* alloc_cpu(size_t nbytes) {
|
||||
}
|
||||
|
||||
void free_cpu(void* data) {
|
||||
#ifdef _MSC_VER
|
||||
#ifdef USE_MIMALLOC
|
||||
mi_free(data);
|
||||
#else
|
||||
#elif defined(_MSC_VER)
|
||||
_aligned_free(data);
|
||||
#endif
|
||||
#else
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-no-malloc)
|
||||
free(data);
|
||||
|
||||
@ -638,11 +638,11 @@ struct ExpandableSegment {
|
||||
return *stream_;
|
||||
}
|
||||
|
||||
size_t getMappedSize() {
|
||||
size_t getMappedSize() const {
|
||||
return mapped_size_;
|
||||
}
|
||||
|
||||
size_t getSegmentSize() {
|
||||
size_t getSegmentSize() const {
|
||||
return segment_size_;
|
||||
}
|
||||
|
||||
@ -799,11 +799,11 @@ struct ExpandableSegment {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
size_t getMappedSize() {
|
||||
size_t getMappedSize() const {
|
||||
return 0;
|
||||
}
|
||||
|
||||
size_t getSegmentSize() {
|
||||
size_t getSegmentSize() const {
|
||||
return 0;
|
||||
}
|
||||
void addPeer(c10::DeviceIndex device) {}
|
||||
@ -824,14 +824,14 @@ struct BlockState {
|
||||
// maintain invariant that event_count == 0 ;
|
||||
// history will be left alone in checkpoint
|
||||
|
||||
BlockState(Block* block);
|
||||
explicit BlockState(Block* block);
|
||||
};
|
||||
|
||||
struct SegmentState {
|
||||
std::vector<BlockState> blocks;
|
||||
bool is_small = false;
|
||||
|
||||
SegmentState(Block* head);
|
||||
explicit SegmentState(Block* head);
|
||||
};
|
||||
|
||||
struct PrivatePoolState : AllocatorState {
|
||||
@ -949,7 +949,7 @@ class EventPool {
|
||||
|
||||
// CUDA graphs helper
|
||||
struct PrivatePool {
|
||||
PrivatePool(MempoolId_t id, CUDAAllocator* allocator = nullptr)
|
||||
explicit PrivatePool(MempoolId_t id, CUDAAllocator* allocator = nullptr)
|
||||
: id(std::move(id)),
|
||||
allocator_(allocator),
|
||||
large_blocks(/*small=*/false, this),
|
||||
@ -1078,7 +1078,7 @@ class RingBuffer {
|
||||
}
|
||||
}
|
||||
|
||||
void getEntries(std::vector<T>& result) {
|
||||
void getEntries(std::vector<T>& result) const {
|
||||
std::lock_guard<std::mutex> lk(alloc_trace_lock);
|
||||
result.reserve(alloc_trace->size());
|
||||
result.insert(
|
||||
@ -1106,7 +1106,7 @@ class RingBuffer {
|
||||
|
||||
// Both alloc_trace and alloc_trace_next needs to be used
|
||||
// under alloc_trace_lock.
|
||||
std::mutex alloc_trace_lock;
|
||||
mutable std::mutex alloc_trace_lock;
|
||||
size_t alloc_trace_next = 0;
|
||||
std::vector<T>*
|
||||
alloc_trace; // pointer because we need to intentionally leak this on
|
||||
@ -1299,7 +1299,7 @@ class DeviceCachingAllocator {
|
||||
}
|
||||
}
|
||||
|
||||
bool isHistoryEnabled() {
|
||||
bool isHistoryEnabled() const {
|
||||
return record_history;
|
||||
}
|
||||
|
||||
@ -1315,7 +1315,7 @@ class DeviceCachingAllocator {
|
||||
|
||||
bool checkPoolLiveAllocations(
|
||||
MempoolId_t mempool_id,
|
||||
const std::unordered_set<void*>& expected_live_allocations) {
|
||||
const std::unordered_set<void*>& expected_live_allocations) const {
|
||||
std::unique_lock<std::recursive_mutex> lock(mutex);
|
||||
|
||||
PrivatePool* pool = nullptr;
|
||||
@ -2081,7 +2081,7 @@ class DeviceCachingAllocator {
|
||||
}
|
||||
|
||||
/** Returns a copy of the memory allocator stats **/
|
||||
DeviceStats getStats() {
|
||||
DeviceStats getStats() const {
|
||||
std::lock_guard<std::recursive_mutex> lock(mutex);
|
||||
return stats;
|
||||
}
|
||||
@ -2457,7 +2457,7 @@ class DeviceCachingAllocator {
|
||||
}
|
||||
|
||||
std::vector<TraceEntry> trace(
|
||||
const std::function<time_t(approx_time_t)>& tsc_to_us) {
|
||||
const std::function<time_t(approx_time_t)>& tsc_to_us) const {
|
||||
std::lock_guard<std::recursive_mutex> lock(mutex);
|
||||
std::vector<TraceEntry> result;
|
||||
alloc_buffer.getEntries(result);
|
||||
@ -2593,7 +2593,7 @@ class DeviceCachingAllocator {
|
||||
}
|
||||
}
|
||||
|
||||
int getPoolUseCount(MempoolId_t mempool_id) {
|
||||
int getPoolUseCount(MempoolId_t mempool_id) const {
|
||||
std::lock_guard<std::recursive_mutex> lock(mutex);
|
||||
auto pp = get_private_pool(mempool_id);
|
||||
return pp->use_count;
|
||||
@ -2689,7 +2689,7 @@ class DeviceCachingAllocator {
|
||||
}
|
||||
}
|
||||
|
||||
PrivatePool* get_private_pool(MempoolId_t mempool_id) {
|
||||
PrivatePool* get_private_pool(MempoolId_t mempool_id) const {
|
||||
auto it = graph_pools.find(mempool_id);
|
||||
TORCH_INTERNAL_ASSERT(it != graph_pools.end());
|
||||
return it->second.get();
|
||||
@ -3686,7 +3686,7 @@ class DeviceCachingAllocator {
|
||||
if (!compile_context.empty()) {
|
||||
compile_string = compile_context.top();
|
||||
}
|
||||
auto te = TraceEntry(
|
||||
TraceEntry te(
|
||||
action,
|
||||
device,
|
||||
addr,
|
||||
|
||||
@ -439,10 +439,6 @@ function(torch_compile_options libname)
|
||||
$<$<COMPILE_LANGUAGE:CXX>: -fvisibility=hidden>)
|
||||
endif()
|
||||
|
||||
# Use -O2 for release builds (-O3 doesn't improve perf, and -Os results in perf regression)
|
||||
target_compile_options(${libname} PRIVATE
|
||||
$<$<AND:$<COMPILE_LANGUAGE:CXX>,$<OR:$<CONFIG:Release>,$<CONFIG:RelWithDebInfo>>>:-O2>)
|
||||
|
||||
endfunction()
|
||||
|
||||
##############################################################################
|
||||
@ -530,4 +526,4 @@ function(target_link_options_if_supported tgt flag)
|
||||
else()
|
||||
message(WARNING "Attempted to use unsupported link option : ${flag}.")
|
||||
endif()
|
||||
endfunction()
|
||||
endfunction()
|
||||
|
||||
@ -553,42 +553,6 @@ coverage_ignore_functions = [
|
||||
# torch.distributed.checkpoint.utils
|
||||
"find_state_dict_object",
|
||||
"find_tensor_shard",
|
||||
# torch.distributed.collective_utils
|
||||
"all_gather",
|
||||
"all_gather_object_enforce_type",
|
||||
"broadcast",
|
||||
# torch.distributed.distributed_c10d
|
||||
"all_gather",
|
||||
"all_gather_coalesced",
|
||||
"all_gather_into_tensor",
|
||||
"all_gather_object",
|
||||
"all_reduce",
|
||||
"all_reduce_coalesced",
|
||||
"all_to_all",
|
||||
"all_to_all_single",
|
||||
"barrier",
|
||||
"batch_isend_irecv",
|
||||
"broadcast",
|
||||
"broadcast_object_list",
|
||||
"destroy_process_group",
|
||||
"gather",
|
||||
"gather_object",
|
||||
"get_backend",
|
||||
"get_backend_config",
|
||||
"get_global_rank",
|
||||
"get_group_rank",
|
||||
"get_process_group_ranks",
|
||||
"get_rank",
|
||||
"get_world_size",
|
||||
"init_process_group",
|
||||
"irecv",
|
||||
"is_backend_available",
|
||||
"is_gloo_available",
|
||||
"is_initialized",
|
||||
"is_mpi_available",
|
||||
"is_nccl_available",
|
||||
"is_torchelastic_launched",
|
||||
"is_ucc_available",
|
||||
"isend",
|
||||
"monitored_barrier",
|
||||
"new_group",
|
||||
@ -662,15 +626,8 @@ coverage_ignore_functions = [
|
||||
"transformer_auto_wrap_policy",
|
||||
"wrap",
|
||||
# torch.distributed.nn.functional
|
||||
"all_gather",
|
||||
"all_reduce",
|
||||
"all_to_all",
|
||||
"all_to_all_single",
|
||||
"broadcast",
|
||||
"gather",
|
||||
"reduce",
|
||||
"reduce_scatter",
|
||||
"scatter",
|
||||
# torch.distributed.nn.jit.instantiator
|
||||
"get_arg_return_types_from_interface",
|
||||
"instantiate_non_scriptable_remote_module_template",
|
||||
|
||||
@@ -10,6 +10,7 @@ torch.cpu
current_device
current_stream
is_available
is_initialized
synchronize
stream
set_device

@@ -221,6 +221,16 @@ inconsistent 'UUID' assignment across ranks, and to prevent races during initial

```{eval-rst}
.. autofunction:: torch.distributed.distributed_c10d.is_xccl_available
.. autofunction:: torch.distributed.distributed_c10d.batch_isend_irecv
.. autofunction:: torch.distributed.distributed_c10d.destroy_process_group
.. autofunction:: torch.distributed.distributed_c10d.is_backend_available
.. autofunction:: torch.distributed.distributed_c10d.irecv
.. autofunction:: torch.distributed.distributed_c10d.is_gloo_available
.. autofunction:: torch.distributed.distributed_c10d.is_initialized
.. autofunction:: torch.distributed.distributed_c10d.is_mpi_available
.. autofunction:: torch.distributed.distributed_c10d.is_nccl_available
.. autofunction:: torch.distributed.distributed_c10d.is_torchelastic_launched
.. autofunction:: torch.distributed.distributed_c10d.is_ucc_available
```

```{eval-rst}

@@ -218,3 +218,13 @@ DataParallel functions (multi-GPU, distributed)
:nosignatures:

torch.nn.parallel.data_parallel

Low-Precision functions
-----------------------
.. autosummary::
:toctree: generated
:nosignatures:

ScalingType
SwizzleType
scaled_mm

@@ -1,6 +1,65 @@
# LibTorch Stable ABI

This note will eventually contain more details on how to use the APIs in torch/csrc/stable. For the moment, it contains a table of internal representations:
## Overview

The LibTorch Stable ABI (Application Binary Interface) provides an interface for extending PyTorch functionality without being tightly coupled to specific PyTorch versions. This enables the development of custom operators and extensions that remain compatible across PyTorch releases.

The stable ABI consists of three main components:

1. **Stable C headers** - Low-level C API implemented by libtorch (primarily `torch/csrc/inductor/aoti_torch/c/shim.h`)
2. **Header-only C++ library** - Standalone utilities implemented only in headers, so that there is no dependence on libtorch (`torch/headeronly/*`)
3. **Stable C++ wrappers** - High-level C++ convenience wrappers (`torch/csrc/stable/*`)

We discuss each of these in detail below.

### `torch/headeronly`

This is a set of inlined C++ headers that are completely decoupled from libtorch. The headers consist of certain utilities that might be familiar to custom extension writers. For example, the
`c10::ScalarType` enum lives here as `torch::headeronly::ScalarType`.
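
As a small illustration (not part of the diff above), a header-only helper might use that enum as follows. This is a hedged sketch: the include path and the exact set of enumerators available in a given release are assumptions rather than something this note guarantees.

```cpp
// Sketch: a utility that depends only on torch/headeronly, never on libtorch.
// The include path below is an assumption and may differ between releases.
#include <torch/headeronly/core/ScalarType.h>

#include <cstddef>

// Map a header-only ScalarType to its element size in bytes.
inline std::size_t element_size_bytes(torch::headeronly::ScalarType dtype) {
  switch (dtype) {
    case torch::headeronly::ScalarType::Double:
      return 8;
    case torch::headeronly::ScalarType::Float:
      return 4;
    case torch::headeronly::ScalarType::Half:
    case torch::headeronly::ScalarType::BFloat16:
      return 2;
    default:
      return 1; // narrow integer/bool types; extend as needed
  }
}
```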

### `torch/csrc/stable`

This is a set of inlined C++ headers that provide wrappers around the C API and handle the rough edges
discussed below.

It consists of:

- torch/csrc/stable/library.h: Provides a stable version of TORCH_LIBRARY and similar macros.
- torch/csrc/stable/tensor_struct.h: Provides torch::stable::Tensor, a stable version of at::Tensor.
- torch/csrc/stable/ops.h: Provides a stable interface for calling ATen ops from `native_functions.yaml`.
- torch/csrc/stable/accelerator.h: Provides a stable interface for device-generic objects and APIs
(e.g. `getCurrentStream`, `DeviceGuard`).

We are continuing to improve coverage in our `torch/csrc/stable` APIs. Please file an issue if you'd like to see support for particular APIs in your custom extension.

### Stable C headers

The stable C headers used by AOTInductor form the foundation of the stable ABI. However, they are **use at your own risk**. For example, users must handle the memory lifecycle of objects returned by certain APIs.
Further, the stack-based APIs discussed below, which allow the user to call the PyTorch dispatcher, don't provide strong guarantees on forward and backward compatibility.

Unless absolutely necessary, we recommend the high-level C++ API in `torch/csrc/stable`,
which will handle all the rough edges of the C API for the user.


## How are objects passed across the ABI boundary when interacting with the dispatcher?

When interacting with the dispatcher via the stable APIs (`STABLE_TORCH_LIBRARY` etc.), we use a boxed convention. Arguments and returns are represented as a stack of `StableIValue`, which correlates with a `torch::jit::stack` of IValues. We discuss the following below:
1. StableIValue conversions
2. StableIValue stack conventions
3. Stable APIs that interact with the dispatcher

### StableIValue Conversions

We provide utilities for users to convert objects to and from StableIValues with the corresponding
`to` and `from` APIs in `torch/csrc/stable/stableivalue_conversions.h`. We document the stable custom extension representation, libtorch representation and StableIValue
representations below. Our confidently supported types are the ones in the table that have completed
rows. You can rely on this subset for proper ABI stability, meaning that you can call `to<T_custom_ext>(arg/ret)` or `from(T)` on these types.
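
Concretely, the conversion utilities look like the following in a custom kernel. This is only a sketch based on the APIs named above; it assumes `Tensor` is one of the fully supported table rows and that the two includes shown are sufficient to bring `StableIValue` into scope.

```cpp
// Sketch: boxing and unboxing a supported type with the `to`/`from` utilities.
#include <torch/csrc/stable/stableivalue_conversions.h>
#include <torch/csrc/stable/tensor_struct.h>

using torch::stable::Tensor;

void unbox_and_rebox(StableIValue* stack) {
  // Unbox the argument at index 0 into a stable Tensor...
  Tensor t = to<Tensor>(stack[0]);
  // ...and box a value back onto the stack when producing a return.
  stack[0] = from(t);
}
```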

For a limited set of use cases, we also implicitly support any literal type that is representable within 64 bits as StableIValues, as the default reinterpret_cast will succeed. (For example: c10::Device.) These types are currently ABI-stable on a best-effort basis but might break in the future and thus should be used for short-term testing only.

You can always work with StableIValue abstractions in your custom kernel for types such as c10::Device, even if there is no standard defined representation of device in custom extensions, by not introspecting into the StableIValue. For example, a custom operator can take as argument a StableIValue device and directly pass it through to an aten operator with `aoti_torch_call_dispatcher`.


1. type in custom extension: the type used within the end user's custom library.
2. StableIValue representation: a stable conversion of the type that liaises between the user's extension and libtorch.so in an ABI-stable manner.
3. type in libtorch: the type used within libtorch.so (or any code binary-locked with libtorch).
@@ -31,16 +90,10 @@ This note will eventually contain more details on how to use the APIs in torch/c
| ? | ? | c10::SymBool | SymBool |
| ? | ? | at::QScheme | QScheme |

Our confidently supported types are the ones in the table that have completed rows. You can rely on this subset for proper ABI stability.

For a limited set of use cases, we also implicitly support any literal type that is representable within 64 bits as StableIValues, as the default reinterpret_cast will succeed. (For example: c10::Device.) These types are currently ABI-stable on best effort but might break in the future and thus should be used for short term testing only.
### Stack Conventions

You can always work with StableIValue abstractions in your custom kernel for types such as c10::Device even if there is no standard defined representation of device in custom extensions by not introspecting into the StableIValue. For example, a custom operator can take as argument a StableIValue device and directly pass it through to an aten operator with `aoti_torch_call_dispatcher`.


## How to use stack-based APIs

`aoti_torch_call_dispatcher` is what we consider a stack-based API because it takes as input a stack of StableIValues, which correlates with a `torch::jit::stack` of IValues. Working with the dispatcher will likely bring you into proximity with stack-based APIs, so we are documenting some invariants:
There are two invariants for the stack:

1. The stack is populated left to right.
a. For example, a stack representing arguments `arg0`, `arg1`, and `arg2` will have `arg0` at index 0, `arg1` at index 1, and `arg2` at index 2.
@@ -49,3 +102,32 @@ You can always work with StableIValue abstractions in your custom kernel for typ
2. The stack always has ownership of the objects it holds.
a. When calling a stack-based API, you must give owning references to the calling stack and steal references from the returned stack.
b. When registering your function to be called with a stack, you must steal references from your argument stack and push onto the stack new references.
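
To make the indices and ownership concrete, here is a sketch for a hypothetical two-argument op. The schema, the names, and the assumption that `int64_t` is a supported table type are illustrative only; the registered example in the next section shows the real pattern used by this note.

```cpp
// Hypothetical schema: my_ns::scale_rows(Tensor self, int64_t k) -> Tensor
// (uses the same includes as the conversion sketch above)
// Invariant 1: arguments arrive left to right, so self is stack[0] and k is stack[1].
// Invariant 2: the callee steals the argument references and pushes an owning
// reference for each return, with ret0 at index 0.
void boxed_scale_rows(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
  auto self = to<torch::stable::Tensor>(stack[0]);
  auto k = to<int64_t>(stack[1]);
  torch::stable::Tensor result = self; // placeholder for the real computation using k
  stack[0] = from(result);             // ret0 replaces the consumed arguments
}
```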

### Stack-based APIs

The above is relevant in two places:

1. `STABLE_TORCH_LIBRARY`
Unlike `TORCH_LIBRARY`, the dispatcher expects kernels registered via `STABLE_TORCH_LIBRARY` to be boxed. This means they must have the signature `(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) -> void`. We plan to eventually abstract away the need for manual boxing, but, for the time being, please use `from` and `to`.

```cpp
Tensor my_amax_vec(Tensor t) {
  std::vector<int64_t> v = {0,1};
  return amax(t, v, false);
}

void boxed_my_amax_vec(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
  auto res = my_amax_vec(to<Tensor>(stack[0]));
  stack[0] = from(res);
}
```

2. `aoti_torch_call_dispatcher`
This API allows you to call the PyTorch dispatcher from C/C++ code. It has the following signature:
```cpp
aoti_torch_call_dispatcher(const char* opName, const char* overloadName, StableIValue* stack);
```

`aoti_torch_call_dispatcher` will call the op overload defined by a given `opName`, `overloadName`, and a stack of
StableIValues. This call will populate any return values of the op into the stack in their StableIValue form,
with `ret0` at index 0, `ret1` at index 1, and so on.

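From the caller's side, the same stack conventions apply. The sketch below is not from this note: the choice of `aten::abs` (default overload `""`), the single-slot stack, and the include paths are assumptions, and error handling on the returned status is omitted.

```cpp
// Sketch: calling a one-argument, one-return aten op through the stack-based C API.
#include <torch/csrc/inductor/aoti_torch/c/shim.h>       // assumed home of aoti_torch_call_dispatcher
#include <torch/csrc/stable/stableivalue_conversions.h>  // to / from
#include <torch/csrc/stable/tensor_struct.h>             // torch::stable::Tensor

using torch::stable::Tensor;

Tensor call_abs(const Tensor& t) {
  // abs takes one argument and returns one value, so one slot suffices.
  StableIValue stack[1];
  // Give the stack an owning reference to the argument (populated left to right).
  stack[0] = from(t);
  // The dispatcher consumes the arguments and writes ret0 back at index 0.
  aoti_torch_call_dispatcher("aten::abs", "", stack);
  // Steal the returned reference off the stack.
  return to<Tensor>(stack[0]);
}
```
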
@@ -35,7 +35,6 @@ and supported quantized modules and functions.

quantization-support


.. torch.ao is missing documentation. Since part of it is mentioned here, adding them here for now.
.. They are here for tracking purposes until they are more permanently fixed.
.. py:module:: torch.ao

@ -20,8 +20,10 @@ project-includes = [
|
||||
project-excludes = [
|
||||
# ==== below will be enabled directory by directory ====
|
||||
# ==== to test Pyrefly on a specific directory, simply comment it out ====
|
||||
"torch/_inductor/**",
|
||||
# formatting issues
|
||||
"torch/_inductor/runtime",
|
||||
"torch/_inductor/codegen",
|
||||
# formatting issues, will turn on after adjusting where suppressions can be
|
||||
# in import statements
|
||||
"torch/linalg/__init__.py",
|
||||
"torch/package/importer.py",
|
||||
"torch/package/_package_pickler.py",
|
||||
@ -31,6 +33,9 @@ project-excludes = [
|
||||
"torch/_export/utils.py",
|
||||
"torch/fx/experimental/unification/multipledispatch/__init__.py",
|
||||
"torch/nn/modules/__init__.py",
|
||||
"torch/nn/modules/rnn.py", # only remove when parsing errors are fixed
|
||||
"torch/_inductor/codecache.py",
|
||||
"torch/distributed/elastic/metrics/__init__.py",
|
||||
# ====
|
||||
"benchmarks/instruction_counts/main.py",
|
||||
"benchmarks/instruction_counts/definitions/setup.py",
|
||||
|
||||
@ -89,7 +89,7 @@ if venv_dir.exists():
|
||||
print("Removing existing hook venv...")
|
||||
shutil.rmtree(venv_dir)
|
||||
|
||||
run(["uv", "venv", str(venv_dir), "--python", "3.9"])
|
||||
run(["uv", "venv", str(venv_dir), "--python", "3.10"])
|
||||
|
||||
# Install lintrunner in the isolated environment
|
||||
print("Installing lintrunner in isolated environment...")
|
||||
|
||||
setup.py
@ -225,7 +225,7 @@
|
||||
#
|
||||
# USE_MIMALLOC
|
||||
# Static link mimalloc into C10, and use mimalloc in alloc_cpu & alloc_free.
|
||||
# By default, It is only enabled on Windows.
|
||||
# By default, It is only enabled on Windows and AArch64.
|
||||
#
|
||||
# BUILD_LIBTORCH_WHL
|
||||
# Builds libtorch.so and its dependencies as a wheel
|
||||
|
||||
@ -1111,6 +1111,14 @@
|
||||
"_amp_update_scale_",
|
||||
"_assert_async",
|
||||
"_batch_norm_impl_index",
|
||||
"_cast_Byte",
|
||||
"_cast_Char",
|
||||
"_cast_Double",
|
||||
"_cast_Float",
|
||||
"_cast_Half",
|
||||
"_cast_Int",
|
||||
"_cast_Long",
|
||||
"_cast_Short",
|
||||
"_choose_qparams_per_tensor",
|
||||
"_coalesce",
|
||||
"_compute_linear_combination",
|
||||
|
||||
@ -1292,12 +1292,6 @@ torch::Tensor view_op(const torch::Tensor& self) {
|
||||
return self.alias();
|
||||
}
|
||||
|
||||
torch::Tensor view_op_with_extra_arg(
|
||||
const torch::Tensor& self,
|
||||
const torch::Tensor& other) {
|
||||
return self.alias();
|
||||
}
|
||||
|
||||
std::vector<torch::Tensor> ret_tensor_vector_view(
|
||||
const torch::Tensor& self,
|
||||
const torch::Tensor& other) {
|
||||
@ -1534,35 +1528,9 @@ TEST(TestAutogradNotImplementedFallback, ViewOp) {
|
||||
// Test inplace on view
|
||||
auto t = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
|
||||
|
||||
// raise on rebase_history when it refreshes grad_fn
|
||||
ASSERT_THROWS_WITH(
|
||||
v1.add_(t), "which does not have a derivative implemented is forbidden");
|
||||
// base should not be aware of the views, so this is still okay
|
||||
// this works as we can properly replay the view given by the user
|
||||
v1.add_(t);
|
||||
b1.add_(t);
|
||||
ASSERT_THROWS_WITH(
|
||||
v1.grad_fn(),
|
||||
"which does not have a derivative implemented is forbidden");
|
||||
}
|
||||
|
||||
TEST(TestAutogradNotImplementedFallback, ViewOpWithExtraArg) {
|
||||
REGISTER_TEST_OP(
|
||||
"view_op_with_extra_arg",
|
||||
"_test::view_op_with_extra_arg(Tensor(a) self, Tensor other) -> Tensor(a)",
|
||||
view_op_with_extra_arg);
|
||||
auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
|
||||
"_test::view_op_with_extra_arg", "");
|
||||
auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
|
||||
return callOpUnboxed<
|
||||
torch::Tensor,
|
||||
const torch::Tensor&,
|
||||
const torch::Tensor&>(opHandle, _1, _2);
|
||||
};
|
||||
assertBasicChecks(op);
|
||||
auto a = torch::tensor({1.}, {torch::kFloat32});
|
||||
auto b = torch::tensor({2.}, {torch::kFloat32});
|
||||
auto out1 = op(a, b);
|
||||
ASSERT_TRUE(out1.is_view());
|
||||
ASSERT_EQ(out1._base().unsafeGetTensorImpl(), a.unsafeGetTensorImpl());
|
||||
}
|
||||
|
||||
TEST(TestAutogradNotImplementedFallback, RetTensorVectorView) {
|
||||
|
||||
@ -564,508 +564,3 @@ TEST(OptimTest, CheckLRChange_ReduceLROnPlateau_Adam) {
|
||||
check_lr_change_for_reduce_on_plateau(
|
||||
optimizer, reduce_lr_on_plateau_scheduler, expected_epoch_lrs);
|
||||
}
|
||||
// Tests for Issue 141884: Parameter group inheritance functionality
|
||||
// Validates that partial options in parameter groups correctly inherit
|
||||
// defaults from the optimizer while preserving explicitly set values
|
||||
TEST(OptimTest, MergeWithDefaultOptions_Adam) {
|
||||
// Create tensors for parameter groups
|
||||
auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
|
||||
auto tensor2 = torch::randn({3, 3}).requires_grad_(true);
|
||||
|
||||
// Create param groups with partial options
|
||||
std::vector<OptimizerParamGroup> param_groups;
|
||||
|
||||
// Group 1: Only weight_decay specified, should inherit lr, betas, eps,
|
||||
// amsgrad
|
||||
param_groups.emplace_back(
|
||||
std::vector<torch::Tensor>{tensor1},
|
||||
std::make_unique<AdamOptions>(AdamOptions().weight_decay(0.11)));
|
||||
|
||||
// Group 2: Only eps specified, should inherit others
|
||||
param_groups.emplace_back(
|
||||
std::vector<torch::Tensor>{tensor2},
|
||||
std::make_unique<AdamOptions>(AdamOptions().eps(1e-6)));
|
||||
|
||||
// Create optimizer with specific defaults
|
||||
AdamOptions defaults;
|
||||
defaults.lr(0.002)
|
||||
.betas(std::make_tuple(0.8, 0.88))
|
||||
.eps(1e-12)
|
||||
.weight_decay(0.05)
|
||||
.amsgrad(true);
|
||||
|
||||
Adam optimizer(param_groups, defaults);
|
||||
|
||||
// Check Group 1: weight_decay preserved, others inherited
|
||||
auto& group1_opts =
|
||||
static_cast<AdamOptions&>(optimizer.param_groups()[0].options());
|
||||
ASSERT_EQ(group1_opts.lr(), 0.002); // Inherited
|
||||
ASSERT_EQ(group1_opts.betas(), std::make_tuple(0.8, 0.88)); // Inherited
|
||||
ASSERT_EQ(group1_opts.eps(), 1e-12); // Inherited
|
||||
ASSERT_EQ(group1_opts.weight_decay(), 0.11); // Preserved
|
||||
ASSERT_TRUE(group1_opts.amsgrad()); // Inherited
|
||||
|
||||
// Check Group 2: eps preserved, others inherited
|
||||
auto& group2_opts =
|
||||
static_cast<AdamOptions&>(optimizer.param_groups()[1].options());
|
||||
ASSERT_EQ(group2_opts.lr(), 0.002); // Inherited
|
||||
ASSERT_EQ(group2_opts.betas(), std::make_tuple(0.8, 0.88)); // Inherited
|
||||
ASSERT_EQ(group2_opts.eps(), 1e-6); // Preserved
|
||||
ASSERT_EQ(group2_opts.weight_decay(), 0.05); // Inherited
|
||||
ASSERT_TRUE(group2_opts.amsgrad()); // Inherited
|
||||
}
|
||||
|
||||
TEST(OptimTest, MergeWithDefaultOptions_SGD) {
|
||||
// Create tensors for parameter groups
|
||||
auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
|
||||
auto tensor2 = torch::randn({3, 3}).requires_grad_(true);
|
||||
|
||||
// Create param groups with partial options
|
||||
std::vector<OptimizerParamGroup> param_groups;
|
||||
|
||||
// Group 1: Only lr and weight_decay specified, should inherit momentum,
|
||||
// dampening, nesterov
|
||||
param_groups.emplace_back(
|
||||
std::vector<torch::Tensor>{tensor1},
|
||||
std::make_unique<SGDOptions>(SGDOptions(0.01).weight_decay(0.22)));
|
||||
|
||||
// Group 2: Only lr specified, should inherit all others
|
||||
param_groups.emplace_back(
|
||||
std::vector<torch::Tensor>{tensor2},
|
||||
std::make_unique<SGDOptions>(SGDOptions(0.02)));
|
||||
|
||||
// Create optimizer with specific defaults
|
||||
SGDOptions defaults(0.001); // lr should be overridden by param groups
|
||||
defaults.momentum(0.9)
|
||||
.dampening(0.0) // Must be 0 for Nesterov
|
||||
.weight_decay(0.05)
|
||||
.nesterov(true);
|
||||
|
||||
SGD optimizer(param_groups, defaults);
|
||||
|
||||
// Check Group 1: lr and weight_decay preserved, others inherited
|
||||
auto& group1_opts =
|
||||
static_cast<SGDOptions&>(optimizer.param_groups()[0].options());
|
||||
ASSERT_EQ(group1_opts.lr(), 0.01); // Preserved
|
||||
ASSERT_EQ(group1_opts.momentum(), 0.9); // Inherited
|
||||
ASSERT_EQ(group1_opts.dampening(), 0.0); // Inherited
|
||||
ASSERT_EQ(group1_opts.weight_decay(), 0.22); // Preserved
|
||||
ASSERT_TRUE(group1_opts.nesterov()); // Inherited
|
||||
|
||||
// Check Group 2: lr preserved, others inherited
|
||||
auto& group2_opts =
|
||||
static_cast<SGDOptions&>(optimizer.param_groups()[1].options());
|
||||
ASSERT_EQ(group2_opts.lr(), 0.02); // Preserved
|
||||
ASSERT_EQ(group2_opts.momentum(), 0.9); // Inherited
|
||||
ASSERT_EQ(group2_opts.dampening(), 0.0); // Inherited
|
||||
ASSERT_EQ(group2_opts.weight_decay(), 0.05); // Inherited
|
||||
ASSERT_TRUE(group2_opts.nesterov()); // Inherited
|
||||
}
|
||||
|
||||
TEST(OptimTest, MergeWithDefaultOptions_AdamW) {
|
||||
// Create tensors for parameter groups
|
||||
auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
|
||||
auto tensor2 = torch::randn({3, 3}).requires_grad_(true);
|
||||
|
||||
// Create param groups with partial options
|
||||
std::vector<OptimizerParamGroup> param_groups;
|
||||
|
||||
// Group 1: Only eps specified, should inherit others
|
||||
param_groups.emplace_back(
|
||||
std::vector<torch::Tensor>{tensor1},
|
||||
std::make_unique<AdamWOptions>(AdamWOptions().eps(1e-6)));
|
||||
|
||||
// Group 2: Only betas specified, should inherit others
|
||||
param_groups.emplace_back(
|
||||
std::vector<torch::Tensor>{tensor2},
|
||||
std::make_unique<AdamWOptions>(
|
||||
AdamWOptions().betas(std::make_tuple(0.95, 0.999))));
|
||||
|
||||
// Create optimizer with specific defaults
|
||||
AdamWOptions defaults;
|
||||
defaults.lr(0.003)
|
||||
.betas(std::make_tuple(0.9, 0.98))
|
||||
.eps(1e-8)
|
||||
.weight_decay(0.02)
|
||||
.amsgrad(false);
|
||||
|
||||
AdamW optimizer(param_groups, defaults);
|
||||
|
||||
// Check Group 1: eps preserved, others inherited
|
||||
auto& group1_opts =
|
||||
static_cast<AdamWOptions&>(optimizer.param_groups()[0].options());
|
||||
ASSERT_EQ(group1_opts.lr(), 0.003); // Inherited
|
||||
ASSERT_EQ(group1_opts.betas(), std::make_tuple(0.9, 0.98)); // Inherited
|
||||
ASSERT_EQ(group1_opts.eps(), 1e-6); // Preserved
|
||||
ASSERT_EQ(group1_opts.weight_decay(), 0.02); // Inherited
|
||||
ASSERT_FALSE(group1_opts.amsgrad()); // Inherited
|
||||
|
||||
// Check Group 2: betas preserved, others inherited
|
||||
auto& group2_opts =
|
||||
static_cast<AdamWOptions&>(optimizer.param_groups()[1].options());
|
||||
ASSERT_EQ(group2_opts.lr(), 0.003); // Inherited
|
||||
ASSERT_EQ(group2_opts.betas(), std::make_tuple(0.95, 0.999)); // Preserved
|
||||
ASSERT_EQ(group2_opts.eps(), 1e-8); // Inherited
|
||||
ASSERT_EQ(group2_opts.weight_decay(), 0.02); // Inherited
|
||||
ASSERT_FALSE(group2_opts.amsgrad()); // Inherited
|
||||
}
|
||||
|
||||
TEST(OptimTest, MergeWithDefaultOptions_Adagrad) {
|
||||
// Create tensors for parameter groups
|
||||
auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
|
||||
auto tensor2 = torch::randn({3, 3}).requires_grad_(true);
|
||||
|
||||
// Create param groups with partial options
|
||||
std::vector<OptimizerParamGroup> param_groups;
|
||||
|
||||
// Group 1: Only lr_decay specified, should inherit others
|
||||
param_groups.emplace_back(
|
||||
std::vector<torch::Tensor>{tensor1},
|
||||
std::make_unique<AdagradOptions>(AdagradOptions().lr_decay(0.001)));
|
||||
|
||||
// Group 2: Only initial_accumulator_value specified, should inherit others
|
||||
param_groups.emplace_back(
|
||||
std::vector<torch::Tensor>{tensor2},
|
||||
std::make_unique<AdagradOptions>(
|
||||
AdagradOptions().initial_accumulator_value(0.5)));
|
||||
|
||||
// Create optimizer with specific defaults
|
||||
AdagradOptions defaults;
|
||||
defaults.lr(0.04)
|
||||
.lr_decay(0.002)
|
||||
.weight_decay(0.03)
|
||||
.initial_accumulator_value(0.1)
|
||||
.eps(1e-11);
|
||||
|
||||
Adagrad optimizer(param_groups, defaults);
|
||||
|
||||
// Check Group 1: lr_decay preserved, others inherited
|
||||
auto& group1_opts =
|
||||
static_cast<AdagradOptions&>(optimizer.param_groups()[0].options());
|
||||
ASSERT_EQ(group1_opts.lr(), 0.04); // Inherited
|
||||
ASSERT_EQ(group1_opts.lr_decay(), 0.001); // Preserved
|
||||
ASSERT_EQ(group1_opts.weight_decay(), 0.03); // Inherited
|
||||
ASSERT_EQ(group1_opts.initial_accumulator_value(), 0.1); // Inherited
|
||||
ASSERT_EQ(group1_opts.eps(), 1e-11); // Inherited
|
||||
|
||||
// Check Group 2: initial_accumulator_value preserved, others inherited
|
||||
auto& group2_opts =
|
||||
static_cast<AdagradOptions&>(optimizer.param_groups()[1].options());
|
||||
ASSERT_EQ(group2_opts.lr(), 0.04); // Inherited
|
||||
ASSERT_EQ(group2_opts.lr_decay(), 0.002); // Inherited
|
||||
ASSERT_EQ(group2_opts.weight_decay(), 0.03); // Inherited
|
||||
ASSERT_EQ(group2_opts.initial_accumulator_value(), 0.5); // Preserved
|
||||
ASSERT_EQ(group2_opts.eps(), 1e-11); // Inherited
|
||||
}
|
||||
|
||||
TEST(OptimTest, MergeWithDefaultOptions_RMSprop) {
|
||||
// Create tensors for parameter groups
|
||||
auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
|
||||
auto tensor2 = torch::randn({3, 3}).requires_grad_(true);
|
||||
|
||||
// Create param groups with partial options
|
||||
std::vector<OptimizerParamGroup> param_groups;
|
||||
|
||||
// Group 1: Only alpha specified, should inherit others
|
||||
param_groups.emplace_back(
|
||||
std::vector<torch::Tensor>{tensor1},
|
||||
std::make_unique<RMSpropOptions>(RMSpropOptions().alpha(0.95)));
|
||||
|
||||
// Group 2: Only momentum and centered specified, should inherit others
|
||||
param_groups.emplace_back(
|
||||
std::vector<torch::Tensor>{tensor2},
|
||||
std::make_unique<RMSpropOptions>(
|
||||
RMSpropOptions().momentum(0.8).centered(true)));
|
||||
|
||||
// Create optimizer with specific defaults
|
||||
RMSpropOptions defaults;
|
||||
defaults.lr(0.015)
|
||||
.alpha(0.98)
|
||||
.eps(1e-9)
|
||||
.weight_decay(0.01)
|
||||
.momentum(0.7)
|
||||
.centered(false);
|
||||
|
||||
RMSprop optimizer(param_groups, defaults);
|
||||
|
||||
// Check Group 1: alpha preserved, others inherited
|
||||
auto& group1_opts =
|
||||
static_cast<RMSpropOptions&>(optimizer.param_groups()[0].options());
|
||||
ASSERT_EQ(group1_opts.lr(), 0.015); // Inherited
|
||||
ASSERT_EQ(group1_opts.alpha(), 0.95); // Preserved
|
||||
ASSERT_EQ(group1_opts.eps(), 1e-9); // Inherited
|
||||
ASSERT_EQ(group1_opts.weight_decay(), 0.01); // Inherited
|
||||
ASSERT_EQ(group1_opts.momentum(), 0.7); // Inherited
|
||||
ASSERT_FALSE(group1_opts.centered()); // Inherited
|
||||
|
||||
// Check Group 2: momentum and centered preserved, others inherited
|
||||
auto& group2_opts =
|
||||
static_cast<RMSpropOptions&>(optimizer.param_groups()[1].options());
|
||||
ASSERT_EQ(group2_opts.lr(), 0.015); // Inherited
|
||||
ASSERT_EQ(group2_opts.alpha(), 0.98); // Inherited
|
||||
ASSERT_EQ(group2_opts.eps(), 1e-9); // Inherited
|
||||
ASSERT_EQ(group2_opts.weight_decay(), 0.01); // Inherited
|
||||
ASSERT_EQ(group2_opts.momentum(), 0.8); // Preserved
|
||||
ASSERT_TRUE(group2_opts.centered()); // Preserved
|
||||
}
|
||||
|
||||
TEST(OptimTest, MergeWithDefaultOptions_LBFGS) {
|
||||
// Create tensors for single parameter group (LBFGS limitation)
|
||||
auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
|
||||
auto tensor2 = torch::randn({3, 3}).requires_grad_(true);
|
||||
|
||||
// Create param group with partial options
|
||||
std::vector<OptimizerParamGroup> param_groups;
|
||||
|
||||
// Single group: Only max_iter specified, should inherit others
|
||||
param_groups.emplace_back(
|
||||
std::vector<torch::Tensor>{
|
||||
tensor1, tensor2}, // Combine tensors in single group
|
||||
std::make_unique<LBFGSOptions>(LBFGSOptions().max_iter(15)));
|
||||
|
||||
// Create optimizer with specific defaults
|
||||
LBFGSOptions defaults;
|
||||
defaults.lr(0.8)
|
||||
.max_iter(25)
|
||||
.max_eval(31) // Use same value that appears to be auto-calculated
|
||||
.tolerance_grad(1e-5)
|
||||
.tolerance_change(1e-8)
|
||||
.history_size(80)
|
||||
.line_search_fn("strong_wolfe");
|
||||
|
||||
LBFGS optimizer(param_groups, defaults);
|
||||
|
||||
// Check Group: max_iter preserved, others inherited
|
||||
auto& group_opts =
|
||||
static_cast<LBFGSOptions&>(optimizer.param_groups()[0].options());
|
||||
ASSERT_EQ(group_opts.lr(), 0.8); // Inherited
|
||||
ASSERT_EQ(group_opts.max_iter(), 15); // Preserved
|
||||
ASSERT_EQ(group_opts.max_eval(), 31); // Inherited
|
||||
ASSERT_EQ(group_opts.tolerance_grad(), 1e-5); // Inherited
|
||||
ASSERT_EQ(group_opts.tolerance_change(), 1e-8); // Inherited
|
||||
ASSERT_EQ(group_opts.history_size(), 80); // Inherited
|
||||
ASSERT_EQ(group_opts.line_search_fn(), "strong_wolfe"); // Inherited
|
||||
}
|
||||
|
||||
TEST(OptimTest, MergeWithDefaultOptions_NoOptionsInheritance) {
|
||||
// Test that param groups without options get full defaults
|
||||
auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
|
||||
auto tensor2 = torch::randn({3, 3}).requires_grad_(true);
|
||||
|
||||
std::vector<OptimizerParamGroup> param_groups;
|
||||
|
||||
// Groups with no options - should inherit everything
|
||||
param_groups.emplace_back(std::vector<torch::Tensor>{tensor1});
|
||||
param_groups.emplace_back(std::vector<torch::Tensor>{tensor2});
|
||||
|
||||
// Create optimizer with specific defaults
|
||||
AdamOptions defaults;
|
||||
defaults.lr(0.005)
|
||||
.betas(std::make_tuple(0.85, 0.95))
|
||||
.eps(1e-7)
|
||||
.weight_decay(0.08)
|
||||
.amsgrad(true);
|
||||
|
||||
Adam optimizer(param_groups, defaults);
|
||||
|
||||
// Both groups should have exactly the default options
|
||||
for (int i = 0; i < 2; i++) {
|
||||
auto& group_opts =
|
||||
static_cast<AdamOptions&>(optimizer.param_groups()[i].options());
|
||||
ASSERT_EQ(group_opts.lr(), 0.005);
|
||||
ASSERT_EQ(group_opts.betas(), std::make_tuple(0.85, 0.95));
|
||||
ASSERT_EQ(group_opts.eps(), 1e-7);
|
||||
ASSERT_EQ(group_opts.weight_decay(), 0.08);
|
||||
ASSERT_TRUE(group_opts.amsgrad());
|
||||
}
|
||||
}
|
||||
|
||||
// Test that field tracking survives serialization/deserialization cycles
|
||||
TEST(OptimTest, SerializationPreservesFieldTracking_Adam) {
|
||||
// Create tensors for parameter groups
|
||||
auto tensor1 = torch::randn({2, 2}).requires_grad_(true);
|
||||
auto tensor2 = torch::randn({3, 3}).requires_grad_(true);
|
||||
|
||||
// Create param groups with partial options using fluent API (marks fields as
|
||||
// explicit)
|
||||
std::vector<OptimizerParamGroup> param_groups;
|
||||
|
||||
// Group 1: Only weight_decay and amsgrad explicitly set via fluent API
|
||||
param_groups.emplace_back(
|
||||
std::vector<torch::Tensor>{tensor1},
|
||||
std::make_unique<AdamOptions>(
|
||||
AdamOptions().weight_decay(0.11).amsgrad(true)));
|
||||
|
||||
// Group 2: Only eps explicitly set via fluent API
|
||||
param_groups.emplace_back(
|
||||
std::vector<torch::Tensor>{tensor2},
|
||||
std::make_unique<AdamOptions>(AdamOptions().eps(1e-6)));
|
||||
|
||||
// Create optimizer with specific defaults
|
||||
AdamOptions defaults;
|
||||
defaults.lr(0.002)
|
||||
.betas(std::make_tuple(0.8, 0.88))
|
||||
.eps(1e-12)
|
||||
.weight_decay(0.05)
|
||||
.amsgrad(false);
|
||||
|
||||
Adam original_optimizer(param_groups, defaults);
|
||||
|
||||
// Capture original state for comparison
|
||||
auto& orig_group1_opts =
|
||||
static_cast<AdamOptions&>(original_optimizer.param_groups()[0].options());
|
||||
auto& orig_group2_opts =
|
||||
static_cast<AdamOptions&>(original_optimizer.param_groups()[1].options());
|
||||
|
||||
// Verify original state (sanity check)
|
||||
ASSERT_NEAR(orig_group1_opts.weight_decay(), 0.11, 1e-6); // Explicitly set
|
||||
ASSERT_TRUE(orig_group1_opts.amsgrad()); // Explicitly set
|
||||
ASSERT_NEAR(orig_group1_opts.lr(), 0.002, 1e-6); // Inherited
|
||||
ASSERT_NEAR(orig_group2_opts.eps(), 1e-6, 1e-9); // Explicitly set
|
||||
ASSERT_NEAR(orig_group2_opts.lr(), 0.002, 1e-6); // Inherited
|
||||
|
||||
// Test serialization of the options objects (where field tracking lives)
|
||||
std::stringstream ss1, ss2;
|
||||
|
||||
// Serialize the parameter group options
|
||||
{
|
||||
torch::serialize::OutputArchive archive;
|
||||
orig_group1_opts.serialize(archive);
|
||||
archive.save_to(ss1);
|
||||
}
|
||||
{
|
||||
torch::serialize::OutputArchive archive;
|
||||
orig_group2_opts.serialize(archive);
|
||||
archive.save_to(ss2);
|
||||
}
|
||||
|
||||
// Create new options objects and deserialize
|
||||
AdamOptions loaded_group1_opts;
|
||||
AdamOptions loaded_group2_opts;
|
||||
|
||||
{
|
||||
torch::serialize::InputArchive archive;
|
||||
archive.load_from(ss1);
|
||||
loaded_group1_opts.serialize(archive);
|
||||
}
|
||||
{
|
||||
torch::serialize::InputArchive archive;
|
||||
archive.load_from(ss2);
|
||||
loaded_group2_opts.serialize(archive);
|
||||
}
|
||||
|
||||
// Verify that all parameter values are preserved after deserialization
|
||||
|
||||
// Group 1: weight_decay and amsgrad should be preserved as explicitly set,
|
||||
// others inherited
|
||||
ASSERT_NEAR(loaded_group1_opts.lr(), 0.002, 1e-6); // Inherited
|
||||
ASSERT_EQ(
|
||||
loaded_group1_opts.betas(), std::make_tuple(0.8, 0.88)); // Inherited
|
||||
ASSERT_NEAR(loaded_group1_opts.eps(), 1e-12, 1e-15); // Inherited
|
||||
ASSERT_NEAR(loaded_group1_opts.weight_decay(), 0.11, 1e-6); // Explicitly set
|
||||
ASSERT_TRUE(loaded_group1_opts.amsgrad()); // Explicitly set
|
||||
|
||||
// Group 2: eps should be preserved as explicitly set, others inherited
|
||||
ASSERT_NEAR(loaded_group2_opts.lr(), 0.002, 1e-6); // Inherited
|
||||
ASSERT_EQ(
|
||||
loaded_group2_opts.betas(), std::make_tuple(0.8, 0.88)); // Inherited
|
||||
ASSERT_NEAR(loaded_group2_opts.eps(), 1e-6, 1e-9); // Explicitly set
|
||||
ASSERT_NEAR(loaded_group2_opts.weight_decay(), 0.05, 1e-6); // Inherited
|
||||
ASSERT_FALSE(loaded_group2_opts.amsgrad()); // Inherited
|
||||
|
||||
// CRITICAL: Test that field tracking is preserved after serialization
|
||||
// Create a new optimizer using the deserialized options to test inheritance
|
||||
auto tensor3 = torch::randn({2, 2}).requires_grad_(true);
|
||||
auto tensor4 = torch::randn({3, 3}).requires_grad_(true);
|
||||
|
||||
std::vector<OptimizerParamGroup> test_param_groups;
|
||||
test_param_groups.emplace_back(
|
||||
std::vector<torch::Tensor>{tensor3},
|
||||
std::make_unique<AdamOptions>(loaded_group1_opts));
|
||||
test_param_groups.emplace_back(
|
||||
std::vector<torch::Tensor>{tensor4},
|
||||
std::make_unique<AdamOptions>(loaded_group2_opts));
|
||||
|
||||
Adam test_optimizer(test_param_groups, defaults);
|
||||
|
  // The field tracking should work correctly for inheritance
  auto& final_group1_opts =
      static_cast<AdamOptions&>(test_optimizer.param_groups()[0].options());
  auto& final_group2_opts =
      static_cast<AdamOptions&>(test_optimizer.param_groups()[1].options());

  // Group 1: weight_decay and amsgrad should still be preserved as explicitly
  // set
  ASSERT_NEAR(
      final_group1_opts.weight_decay(),
      0.11,
      1e-6); // Explicitly set (preserved)
  ASSERT_TRUE(final_group1_opts.amsgrad()); // Explicitly set (preserved)
  ASSERT_NEAR(final_group1_opts.lr(), 0.002, 1e-6); // Inherited from defaults

  // Group 2: eps should still be preserved as explicitly set
  ASSERT_NEAR(
      final_group2_opts.eps(), 1e-6, 1e-9); // Explicitly set (preserved)
  ASSERT_NEAR(final_group2_opts.lr(), 0.002, 1e-6); // Inherited from defaults
}

// Test serialization with SGD (different parameter types)
TEST(OptimTest, SerializationPreservesFieldTracking_SGD) {
  // Create tensors
  auto tensor1 = torch::randn({2, 2}).requires_grad_(true);

  // Create param group with partial options using fluent API
  std::vector<OptimizerParamGroup> param_groups;
  param_groups.emplace_back(
      std::vector<torch::Tensor>{tensor1},
      std::make_unique<SGDOptions>(
          SGDOptions(0.01).weight_decay(0.22).nesterov(true)));

  // Create optimizer with defaults
  SGDOptions defaults(0.001);
  defaults.momentum(0.9).dampening(0.0).weight_decay(0.05).nesterov(false);

  SGD original_optimizer(param_groups, defaults);

  // Test serialization of the SGD options (where field tracking lives)
  auto& original_opts =
      static_cast<SGDOptions&>(original_optimizer.param_groups()[0].options());

  std::stringstream ss;
  {
    torch::serialize::OutputArchive archive;
    original_opts.serialize(archive);
    archive.save_to(ss);
  }

  SGDOptions loaded_opts(0.0); // Dummy initial value
  {
    torch::serialize::InputArchive archive;
    archive.load_from(ss);
    loaded_opts.serialize(archive);
  }
  ASSERT_NEAR(loaded_opts.lr(), 0.01, 1e-6); // Explicitly set
  ASSERT_NEAR(loaded_opts.momentum(), 0.9, 1e-6); // Inherited
  ASSERT_NEAR(loaded_opts.dampening(), 0.0, 1e-6); // Inherited
  ASSERT_NEAR(loaded_opts.weight_decay(), 0.22, 1e-6); // Explicitly set
  ASSERT_TRUE(loaded_opts.nesterov()); // Explicitly set

  // Test that field tracking still works after deserialization by creating new
  // optimizer
  auto tensor2 = torch::randn({3, 3}).requires_grad_(true);
  std::vector<OptimizerParamGroup> test_param_groups;
  test_param_groups.emplace_back(
      std::vector<torch::Tensor>{tensor2},
      std::make_unique<SGDOptions>(loaded_opts));

  SGD test_optimizer(test_param_groups, defaults);

  auto& final_opts =
      static_cast<SGDOptions&>(test_optimizer.param_groups()[0].options());
  ASSERT_NEAR(final_opts.lr(), 0.01, 1e-6); // Explicitly set (preserved)
  ASSERT_NEAR(
      final_opts.weight_decay(), 0.22, 1e-6); // Explicitly set (preserved)
  ASSERT_TRUE(final_opts.nesterov()); // Explicitly set (preserved)
  ASSERT_NEAR(final_opts.momentum(), 0.9, 1e-6); // Inherited from defaults
  ASSERT_NEAR(final_opts.dampening(), 0.0, 1e-6); // Inherited from defaults
}
@@ -135,6 +135,84 @@ TEST_F(LazyOpsTest, TestIsSigned) {
  });
}

TEST_F(LazyOpsTest, TestCastByte) {
  torch::Tensor a =
      torch::rand(
          {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
      100.0;
  torch::Tensor b = torch::_cast_Byte(a);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_a = CopyToDevice(a, device);
    torch::Tensor lazy_b = torch::_cast_Byte(lazy_a);
    AllEqual(b, lazy_b);
  });
}

TEST_F(LazyOpsTest, TestCastChar) {
  torch::Tensor a =
      torch::rand(
          {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
      100.0;
  torch::Tensor b = torch::_cast_Char(a);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_a = CopyToDevice(a, device);
    torch::Tensor lazy_b = torch::_cast_Char(lazy_a);
    AllEqual(b, lazy_b);
  });
}

TEST_F(LazyOpsTest, TestCastShort) {
  torch::Tensor a =
      torch::rand(
          {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
      100.0;
  torch::Tensor b = torch::_cast_Short(a);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_a = CopyToDevice(a, device);
    torch::Tensor lazy_b = torch::_cast_Short(lazy_a);
    AllEqual(b, lazy_b);
  });
}

TEST_F(LazyOpsTest, TestCastInt) {
  torch::Tensor a =
      torch::rand(
          {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
      100.0;
  torch::Tensor b = torch::_cast_Int(a);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_a = CopyToDevice(a, device);
    torch::Tensor lazy_b = torch::_cast_Int(lazy_a);
    AllEqual(b, lazy_b);
  });
}

TEST_F(LazyOpsTest, TestCastLong) {
  torch::Tensor a =
      torch::rand(
          {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
      100.0;
  torch::Tensor b = torch::_cast_Long(a);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_a = CopyToDevice(a, device);
    torch::Tensor lazy_b = torch::_cast_Long(lazy_a);
    AllEqual(b, lazy_b);
  });
}

TEST_F(LazyOpsTest, TestCastFloat) {
  torch::Tensor a =
      torch::rand(
          {2, 2}, torch::TensorOptions(torch::kFloat).device(DefaultDevice())) *
      100.0;
  torch::Tensor b = torch::_cast_Float(a);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor lazy_a = CopyToDevice(a, device);
    torch::Tensor lazy_b = torch::_cast_Float(lazy_a);
    AllEqual(b, lazy_b);
  });
}

TEST_F(LazyOpsTest, TestRetainType) {
  torch::Tensor lazy_a = torch::zeros(
      {2, 2}, torch::TensorOptions(torch::kByte).device(torch::kLazy));
@@ -32,7 +32,7 @@ from torch.testing._internal.common_distributed import (
    sm_is_or_higher_than,
)
from torch.testing._internal.common_fsdp import FSDPTest, get_devtype, MLP
from torch.testing._internal.common_utils import run_tests, skipIfRocm
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    ModelArgs,
    Transformer,
@@ -133,7 +133,11 @@ class TestFullyShardCompile(FSDPTest):
            device_type.type,
            self.rank % torch.get_device_module(device_type).device_count(),
        )
        if device_type.type == "cuda" and not sm_is_or_higher_than(device, 8, 0):
        if (
            device_type.type == "cuda"
            and not torch.version.hip
            and not sm_is_or_higher_than(device, 8, 0)
        ):
            self.skipTest("bf16 requires sm >= 8.0")

    def test_dynamo_trace_use_training_state(self):
@@ -478,7 +482,6 @@ val.shape: {[node.meta["val"].shape for node in aliased_graph_inputs]},
        file_check = file_check.check("torch.ops._c10d_functional.wait_tensor.")
        return file_check

    @skipIfRocm
    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    def test_compiled_autograd_ctx(self):
        self.skipTestForOldSm()
@@ -643,14 +646,12 @@ Unsupported Tensor.backward() call

        return model_init_fn, input_creation_fn

    @skipIfRocm
    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    def test_simple_mlp_fullgraph_backend_aot_eager(self):
        self._test_traceable_fsdp(
            *self._create_simple_mlp_factory_fns(), "aot_eager", fwd_fullgraph=True
        )

    @skipIfRocm
    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    def test_simple_mlp_fullgraph_backend_aot_eager_decomp_partition(self):
        self._test_traceable_fsdp(
@@ -659,7 +660,6 @@ Unsupported Tensor.backward() call
            fwd_fullgraph=True,
        )

    @skipIfRocm
    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    def test_simple_mlp_fullgraph_backend_inductor(self):
        self.skipTestForOldSm()
@@ -731,7 +731,6 @@ Unsupported Tensor.backward() call

        return model_init_fn, input_creation_fn

    @skipIfRocm
    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    def test_nested_fully_shard_backend_aot_eager(self):
        # TODO: fix fwd_fullgraph=False case
@@ -744,7 +743,6 @@ Unsupported Tensor.backward() call
            fwd_fullgraph=fwd_fullgraph,
        )

    @skipIfRocm
    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    def test_nested_fully_shard_backend_aot_eager_decomp_partition(self):
        # TODO: fix fwd_fullgraph=False case
@@ -866,19 +864,16 @@ Unsupported Tensor.backward() call
            pass
        file_check.run(bwd_code)

    @skipIfRocm
    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    def test_nested_fully_shard_backend_inductor_fullgraph_True(self):
        self._test_nested_fully_shard_backend_inductor_fullgraph_True()

    @skipIfRocm
    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @torch._inductor.config.patch("graph_partition", True)
    def test_nested_fully_shard_backend_inductor_fullgraph_True_graph_partition(self):
        self._test_nested_fully_shard_backend_inductor_fullgraph_True()

    @unittest.skip("TODO: fix fwd_fullgraph=False case")
    @skipIfRocm
    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    def test_nested_fully_shard_backend_inductor_fullgraph_False(self):
        self.skipTestForOldSm()
@@ -956,7 +951,6 @@ Unsupported Tensor.backward() call
        else:
            return contextlib.nullcontext()

    @skipIfRocm
    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    def test_transformer_backend_aot_eager(self):
        # TODO: fix fwd_fullgraph=False case
@@ -975,7 +969,6 @@ Unsupported Tensor.backward() call
            fwd_fullgraph=fwd_fullgraph,
        )

    @skipIfRocm
    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    # TODO: native_dropout has worse accuracy after decomp, need to figure out why
    @torch._inductor.config.patch(fallback_random=True)
@@ -1111,7 +1104,6 @@ Unsupported Tensor.backward() call
        file_check.run(bwd_code)

    @unittest.skip('"Traceable FSDP2" is not being maintained anymore.')
    @skipIfRocm
    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    # TODO: native_dropout causes CUDA IMA error, need to figure out why
    @torch._inductor.config.patch(fallback_random=True)
@@ -1119,7 +1111,6 @@ Unsupported Tensor.backward() call
        self._test_transformer_backend_inductor_fullgraph_True()

    @unittest.skip('"Traceable FSDP2" is not being maintained anymore.')
    @skipIfRocm
    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    # TODO: native_dropout causes CUDA IMA error, need to figure out why
    @torch._inductor.config.patch(fallback_random=True)
@@ -1128,7 +1119,6 @@ Unsupported Tensor.backward() call
        self._test_transformer_backend_inductor_fullgraph_True()

    @unittest.skip("TODO: fix fwd_fullgraph=False case")
    @skipIfRocm
    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    # TODO: native_dropout causes CUDA IMA error, need to figure out why
    @torch._inductor.config.patch(fallback_random=True)
@@ -3,7 +3,7 @@
import torch
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import get_state_dict
from torch.distributed.device_mesh import _mesh_resources, init_device_mesh
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.tensor import DTensor
from torch.testing._internal.common_utils import run_tests
@@ -73,8 +73,8 @@ class TestFSDPWithEP(DTensorTestBase, VerifyStateDictMixin):
            self.device_type, (2, 4), mesh_dim_names=("dp", "tp")
        )
        # TODO: we are using an internal API atm. Change to a public API once it is ready.
        mesh_fsdp_ep = _mesh_resources.create_sub_mesh(mesh_fsdp_tp, ("dp",), [(0,)])
        del _mesh_resources.child_to_root_mapping[mesh_fsdp_ep]
        mesh_fsdp_ep = mesh_fsdp_tp["dp"]
        mesh_fsdp_ep._root_mesh = None

        mesh_fsdp = init_device_mesh(self.device_type, (8,))
        for i, l in enumerate(model.second.ep_layers):
@@ -8,6 +8,7 @@ import os
from model_registry import MultiMLP

import torch
from torch._dynamo import OptimizedModule
from torch.distributed.pipelining import (
    Schedule1F1B,
    ScheduleDualPipeV,
@@ -258,7 +259,15 @@ class ScheduleTest(TestCase):
        finally:
            torch.distributed.destroy_process_group()

    def test_zero_bubble_schedule_errors_with_compile(self):
    @parametrize(
        "ScheduleClass",
        [
            ScheduleInterleavedZeroBubble,
            ScheduleZBVZeroBubble,
            ScheduleDualPipeV,
        ],
    )
    def test_zero_bubble_schedule_errors_with_compile(self, ScheduleClass):
        """
        Test that zero bubble schedules raise an error when used with torch.compile.
        """
@@ -271,16 +280,18 @@ class ScheduleTest(TestCase):
        model = MultiMLP(8, n_layers=n_stages)
        # full_mod
        compiled_model = torch.compile(model)
        self.assertTrue(isinstance(compiled_model, OptimizedModule))
        stage = PipelineStage(
            compiled_model,
            0,
            n_stages,
            device,
        )
        with self.assertRaises(RuntimeError):
            ScheduleInterleavedZeroBubble([stage], 2)

        torch.distributed.destroy_process_group()
        try:
            with self.assertRaises(RuntimeError):
                ScheduleClass([stage], 2)
        finally:
            torch.distributed.destroy_process_group()


instantiate_parametrized_tests(ScheduleTest)
@@ -4,6 +4,7 @@ import contextlib

import torch
import torch.distributed as dist
from torch._dynamo.testing import CompileCounterWithBackend
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed.tensor import DeviceMesh, DTensor, Partial, Replicate, Shard
from torch.testing._internal.common_utils import (
@@ -214,6 +215,29 @@ class TestDTensorDebugMode(TestCase):
      aten::_unsafe_view(ft: f32[64, 8], [8, 8, 8])""",
        )

    def test_tensor_attributes(self):
        x = torch.randn(8, 8)
        x.a1 = "x1"
        x.a2 = "x2"
        y = torch.randn(8, 8, 8)
        y.a1 = "y"

        with DebugMode(
            record_torchfunction=True,
            record_faketensor=True,
            record_tensor_attributes=["a1", "a2"],
        ) as debug_mode:
            torch.matmul(y, x)

        self.assertExpectedInline(
            debug_mode.debug_string(),
            """\
    torch.matmul(t: f32[8, 8, 8]{a1=y}, t: f32[8, 8]{a1=x1, a2=x2})
      aten::view(t: f32[8, 8, 8]{a1=y}, [64, 8])
      aten::mm(t: f32[64, 8], t: f32[8, 8]{a1=x1, a2=x2})
      aten::_unsafe_view(t: f32[64, 8], [8, 8, 8])""",
        )

    @parametrize("has_inner_mode", [True, False])
    @parametrize("has_outer_mode", [True, False])
    def test_nested_debug_mode(self, has_inner_mode, has_outer_mode):
@@ -262,14 +286,21 @@ class TestDTensorDebugMode(TestCase):
        self.assertIn("torch.ops.higher_order.cond", debug_mode.debug_string())

    def test_compile(self):
        @torch.compile
        cnt = CompileCounterWithBackend("inductor")

        @torch.compile(backend=cnt)
        def f(x):
            return x.sin().cos()

        x = torch.randn(8)
        with DebugMode() as debug_mode:
            f(x)
        self.assertEqual(len(debug_mode.debug_string()), 0)
        self.assertEqual(len(debug_mode.debug_string()), 0)
            f(x)
            f(x)
        self.assertEqual(
            cnt.frame_count, 1
        )  # check DebugMode doesn't trigger additional recompilations


instantiate_parametrized_tests(TestDTensorDebugMode)
@@ -11,7 +11,8 @@ from torch._functorch.aot_autograd import aot_export_joint_with_descriptors
from torch._functorch.partitioners import min_cut_rematerialization_partition
from torch._guards import tracing, TracingContext
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Replicate
from torch.distributed.tensor import distribute_tensor, Partial, Replicate, Shard
from torch.distributed.tensor._api import DTensor
from torch.distributed.tensor._dtensor_spec import DTensorSpec
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
@@ -39,6 +40,21 @@ class SimpleModel(torch.nn.Module):
        return self.mlp_1(self.mlp_0(input))


class EinsumModel(torch.nn.Module):
    """Simple model that uses einsum with DTensor inputs and returns DTensor."""

    def __init__(self):
        super().__init__()
        self.placement = None

    def forward(self, x, y, z):
        result = torch.einsum("bsh,hd->bsd", x, y)
        self.placement = result.placements[0]
        self.placement_2 = y.placements[0]
        self.placement_3 = z.placements[0]
        return result


class SimpleModelDynamicShapes(torch.nn.Module):
    def __init__(self, device):
        super().__init__()
@@ -334,6 +350,32 @@ class DTensorExportTest(TestCase):
            """[[4, 10], [4], [10, 4], [10], [s22, 10], [s22, 10]]""",
        )

    def test_einsum_dtensor_export(self):
        """Test exporting a model with einsum that has DTensor inputs/outputs with side effects"""
        world_size = 4
        # Create device mesh
        device_mesh = init_device_mesh(self.device_type, mesh_shape=(world_size,))
        model = EinsumModel()

        x = torch.randn(4, 8, 16)
        x_dtensor = distribute_tensor(x, device_mesh, placements=[Shard(0)])

        # y: [16, 16] replicated
        y = torch.randn(16, 16)
        z = torch.randn(16, 16)
        y_dtensor = distribute_tensor(y, device_mesh, placements=[Replicate()])
        z_dtensor = DTensor.from_local(z, device_mesh, placements=[Partial()])

        # Run model to verify it works
        output = model(x_dtensor, y_dtensor, z_dtensor)
        with torch._dynamo.config.patch(install_free_tensors=True):
            # TODO: switch to use the official graph_capture API once it is ready
            gm = _dynamo_graph_capture_for_export(model)(
                x_dtensor, y_dtensor, z_dtensor
            )
        output_gm = gm(x_dtensor, y_dtensor, z_dtensor)
        self.assertEqual(output, output_gm)


instantiate_parametrized_tests(DTensorExportTest)