mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-24 07:27:32 +08:00
Compare commits
1 Commits
csl/test_o
...
test_quant
| Author | SHA1 | Date | |
|---|---|---|---|
| 46eaea0232 |
40
.ci/pytorch/functorch_doc_push_script.sh
Executable file
40
.ci/pytorch/functorch_doc_push_script.sh
Executable file
@ -0,0 +1,40 @@
|
||||
#!/bin/bash
|
||||
|
||||
# This is where the local pytorch install in the docker image is located
|
||||
pt_checkout="/var/lib/jenkins/workspace"
|
||||
source "$pt_checkout/.ci/pytorch/common_utils.sh"
|
||||
echo "functorch_doc_push_script.sh: Invoked with $*"
|
||||
|
||||
set -ex -o pipefail
|
||||
|
||||
version=${DOCS_VERSION:-nightly}
|
||||
echo "version: $version"
|
||||
|
||||
# Build functorch docs
|
||||
pushd $pt_checkout/functorch/docs
|
||||
make html
|
||||
popd
|
||||
|
||||
git clone https://github.com/pytorch/functorch -b gh-pages --depth 1 functorch_ghpages
|
||||
pushd functorch_ghpages
|
||||
|
||||
if [ "$version" == "main" ]; then
|
||||
version=nightly
|
||||
fi
|
||||
|
||||
git rm -rf "$version" || true
|
||||
mv "$pt_checkout/functorch/docs/build/html" "$version"
|
||||
|
||||
git add "$version" || true
|
||||
git status
|
||||
git config user.email "soumith+bot@pytorch.org"
|
||||
git config user.name "pytorchbot"
|
||||
# If there aren't changes, don't make a commit; push is no-op
|
||||
git commit -m "Generate Python docs from pytorch/pytorch@${GITHUB_SHA}" || true
|
||||
git status
|
||||
|
||||
if [[ "${WITH_PUSH:-}" == true ]]; then
|
||||
git push -u origin gh-pages
|
||||
fi
|
||||
|
||||
popd
|
||||
@ -59,7 +59,7 @@ test_python_shard() {
|
||||
|
||||
setup_test_python
|
||||
|
||||
time python test/run_test.py --verbose --exclude-jit-executor --exclude-distributed-tests --shard "$1" "$NUM_TEST_SHARDS"
|
||||
time python test/run_test.py --verbose --exclude-jit-executor --exclude-distributed-tests --exclude-quantization-tests --shard "$1" "$NUM_TEST_SHARDS"
|
||||
|
||||
assert_git_not_dirty
|
||||
}
|
||||
|
||||
@ -1,25 +0,0 @@
|
||||
From 6e08c9d08e9de59c7af28b720289debbbd384764 Mon Sep 17 00:00:00 2001
|
||||
From: Michael Wang <13521008+isVoid@users.noreply.github.com>
|
||||
Date: Tue, 1 Apr 2025 17:28:05 -0700
|
||||
Subject: [PATCH] Avoid bumping certain driver API to avoid future breakage
|
||||
(#185)
|
||||
|
||||
Co-authored-by: isVoid <isVoid@users.noreply.github.com>
|
||||
---
|
||||
numba_cuda/numba/cuda/cudadrv/driver.py | 3 +++
|
||||
1 file changed, 3 insertions(+)
|
||||
|
||||
diff --git a/numba_cuda/numba/cuda/cudadrv/driver.py b/numba_cuda/numba/cuda/cudadrv/driver.py
|
||||
index 1641bf77..233e9ed7 100644
|
||||
--- a/numba_cuda/numba/cuda/cudadrv/driver.py
|
||||
+++ b/numba_cuda/numba/cuda/cudadrv/driver.py
|
||||
@@ -365,6 +365,9 @@ def _find_api(self, fname):
|
||||
else:
|
||||
variants = ('_v2', '')
|
||||
|
||||
+ if fname in ("cuCtxGetDevice", "cuCtxSynchronize"):
|
||||
+ return getattr(self.lib, fname)
|
||||
+
|
||||
for variant in variants:
|
||||
try:
|
||||
return getattr(self.lib, f'{fname}{variant}')
|
||||
@ -32,16 +32,6 @@ if [[ "$BUILD_ENVIRONMENT" != *rocm* && "$BUILD_ENVIRONMENT" != *s390x* && -d /v
|
||||
git config --global --add safe.directory /var/lib/jenkins/workspace
|
||||
fi
|
||||
|
||||
|
||||
# Patch numba to avoid CUDA-13 crash, see https://github.com/pytorch/pytorch/issues/162878
|
||||
NUMBA_CUDA_DIR=$(python -c "import os;import numba.cuda; print(os.path.dirname(numba.cuda.__file__))" 2>/dev/null || true)
|
||||
if [ -n "$NUMBA_CUDA_DIR" ]; then
|
||||
NUMBA_PATCH="$(dirname "$(realpath "${BASH_SOURCE[0]}")")/numba-cuda-13.patch"
|
||||
pushd "$NUMBA_CUDA_DIR"
|
||||
patch -p4 <"$NUMBA_PATCH"
|
||||
popd
|
||||
fi
|
||||
|
||||
echo "Environment variables:"
|
||||
env
|
||||
|
||||
@ -322,14 +312,14 @@ test_python_shard() {
|
||||
|
||||
# modify LD_LIBRARY_PATH to ensure it has the conda env.
|
||||
# This set of tests has been shown to be buggy without it for the split-build
|
||||
time python test/run_test.py --exclude-jit-executor --exclude-distributed-tests $INCLUDE_CLAUSE --shard "$1" "$NUM_TEST_SHARDS" --verbose $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
|
||||
time python test/run_test.py --exclude-jit-executor --exclude-distributed-tests --exclude-quantization-tests $INCLUDE_CLAUSE --shard "$1" "$NUM_TEST_SHARDS" --verbose $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
|
||||
|
||||
assert_git_not_dirty
|
||||
}
|
||||
|
||||
test_python() {
|
||||
# shellcheck disable=SC2086
|
||||
time python test/run_test.py --exclude-jit-executor --exclude-distributed-tests $INCLUDE_CLAUSE --verbose $PYTHON_TEST_EXTRA_OPTION
|
||||
time python test/run_test.py --exclude-jit-executor --exclude-distributed-tests --exclude-quantization-tests $INCLUDE_CLAUSE --verbose $PYTHON_TEST_EXTRA_OPTION
|
||||
assert_git_not_dirty
|
||||
}
|
||||
|
||||
@ -384,6 +374,7 @@ test_dynamo_wrapped_shard() {
|
||||
--exclude-distributed-tests \
|
||||
--exclude-torch-export-tests \
|
||||
--exclude-aot-dispatch-tests \
|
||||
--exclude-quantization-tests \
|
||||
--shard "$1" "$NUM_TEST_SHARDS" \
|
||||
--verbose \
|
||||
--upload-artifacts-while-running
|
||||
@ -1156,6 +1147,12 @@ test_distributed() {
|
||||
fi
|
||||
}
|
||||
|
||||
test_quantization() {
|
||||
echo "Testing quantization"
|
||||
|
||||
python test/test_quantization.py
|
||||
}
|
||||
|
||||
test_rpc() {
|
||||
echo "Testing RPC C++ tests"
|
||||
# NB: the ending test_rpc must match the current function name for the current
|
||||
@ -1582,7 +1579,6 @@ test_linux_aarch64() {
|
||||
python test/run_test.py --include test_modules test_mkldnn test_mkldnn_fusion test_openmp test_torch test_dynamic_shapes \
|
||||
test_transformers test_multiprocessing test_numpy_interop test_autograd test_binary_ufuncs test_complex test_spectral_ops \
|
||||
test_foreach test_reductions test_unary_ufuncs test_tensor_creation_ops test_ops \
|
||||
distributed/elastic/timer/api_test distributed/elastic/timer/local_timer_example distributed/elastic/timer/local_timer_test \
|
||||
--shard "$SHARD_NUMBER" "$NUM_TEST_SHARDS" --verbose
|
||||
|
||||
# Dynamo tests
|
||||
@ -1657,6 +1653,8 @@ elif [[ "${TEST_CONFIG}" == *executorch* ]]; then
|
||||
test_executorch
|
||||
elif [[ "$TEST_CONFIG" == 'jit_legacy' ]]; then
|
||||
test_python_legacy_jit
|
||||
elif [[ "$TEST_CONFIG" == 'quantization' ]]; then
|
||||
test_quantization
|
||||
elif [[ "${BUILD_ENVIRONMENT}" == *libtorch* ]]; then
|
||||
# TODO: run some C++ tests
|
||||
echo "no-op at the moment"
|
||||
|
||||
@ -25,7 +25,7 @@ echo Copying over test times file
|
||||
robocopy /E "%PYTORCH_FINAL_PACKAGE_DIR_WIN%\.additional_ci_files" "%PROJECT_DIR_WIN%\.additional_ci_files"
|
||||
|
||||
echo Run nn tests
|
||||
python run_test.py --exclude-jit-executor --exclude-distributed-tests --shard "%SHARD_NUMBER%" "%NUM_TEST_SHARDS%" --verbose
|
||||
python run_test.py --exclude-jit-executor --exclude-distributed-tests --exclude-quantization-tests --shard "%SHARD_NUMBER%" "%NUM_TEST_SHARDS%" --verbose
|
||||
if ERRORLEVEL 1 goto fail
|
||||
|
||||
popd
|
||||
|
||||
2
.github/ci_commit_pins/vllm.txt
vendored
2
.github/ci_commit_pins/vllm.txt
vendored
@ -1 +1 @@
|
||||
d119fc86140785e7efc8f125c17153544d1e0f20
|
||||
973c9d01da863cac9c51e8a5c0d390fc84b84fbc
|
||||
|
||||
1
.github/pytorch-probot.yml
vendored
1
.github/pytorch-probot.yml
vendored
@ -19,6 +19,7 @@ ciflow_push_tags:
|
||||
- ciflow/nightly
|
||||
- ciflow/periodic
|
||||
- ciflow/periodic-rocm-mi300
|
||||
- ciflow/quantization-periodic
|
||||
- ciflow/rocm
|
||||
- ciflow/rocm-mi300
|
||||
- ciflow/s390
|
||||
|
||||
14
.github/workflows/_docs.yml
vendored
14
.github/workflows/_docs.yml
vendored
@ -75,6 +75,10 @@ jobs:
|
||||
runner: ${{ inputs.runner_prefix }}linux.2xlarge
|
||||
# It takes less than 30m to finish python docs unless there are issues
|
||||
timeout-minutes: 30
|
||||
- docs_type: functorch
|
||||
runner: ${{ inputs.runner_prefix }}linux.2xlarge
|
||||
# It takes less than 15m to finish functorch docs unless there are issues
|
||||
timeout-minutes: 15
|
||||
# Set a fixed name for this job instead of using the current matrix-generated name, i.e. build-docs (cpp, linux.12xlarge, 180)
|
||||
# The current name requires updating the database last docs push query from test-infra every time the matrix is updated
|
||||
name: build-docs-${{ matrix.docs_type }}-${{ inputs.push }}
|
||||
@ -207,6 +211,16 @@ jobs:
|
||||
path: cppdocs/
|
||||
s3-prefix: pytorch/pytorch/${{ github.event.pull_request.number }}/cppdocs
|
||||
|
||||
- name: Upload functorch Docs Preview
|
||||
uses: seemethere/upload-artifact-s3@baba72d0712b404f646cebe0730933554ebce96a # v5.1.0
|
||||
if: ${{ github.event_name == 'pull_request' && matrix.docs_type == 'functorch' && steps.build-docs.outcome == 'success' }}
|
||||
with:
|
||||
retention-days: 14
|
||||
s3-bucket: doc-previews
|
||||
if-no-files-found: error
|
||||
path: functorch_ghpages/nightly/
|
||||
s3-prefix: pytorch/pytorch/${{ github.event.pull_request.number }}/functorchdocs
|
||||
|
||||
- name: Teardown Linux
|
||||
uses: pytorch/test-infra/.github/actions/teardown-linux@main
|
||||
if: always()
|
||||
|
||||
2
.github/workflows/_linux-test.yml
vendored
2
.github/workflows/_linux-test.yml
vendored
@ -169,7 +169,7 @@ jobs:
|
||||
id: install-nvidia-driver
|
||||
uses: pytorch/test-infra/.github/actions/setup-nvidia@main
|
||||
with:
|
||||
driver-version: ${{ matrix.config == 'legacy_nvidia_driver' && '525.105.17' || '580.82.07' }}
|
||||
driver-version: ${{ matrix.config == 'legacy_nvidia_driver' && '525.105.17' || '570.133.07' }}
|
||||
if: ${{ contains(inputs.build-environment, 'cuda') && !contains(matrix.config, 'nogpu') && steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'false' && !contains(matrix.runner, 'b200') }}
|
||||
|
||||
- name: Setup GPU_FLAG for docker run
|
||||
|
||||
33
.github/workflows/_rocm-test.yml
vendored
33
.github/workflows/_rocm-test.yml
vendored
@ -62,11 +62,6 @@ on:
|
||||
required: false
|
||||
type: number
|
||||
default: 1
|
||||
secrets:
|
||||
HUGGING_FACE_HUB_TOKEN:
|
||||
required: false
|
||||
description: |
|
||||
HF Auth token to avoid rate limits when downloading models or datasets from hub
|
||||
env:
|
||||
GIT_DEFAULT_BRANCH: ${{ github.event.repository.default_branch }}
|
||||
|
||||
@ -81,9 +76,10 @@ jobs:
|
||||
strategy:
|
||||
matrix: ${{ fromJSON(inputs.test-matrix) }}
|
||||
fail-fast: false
|
||||
runs-on: ${{ matrix.runner }}
|
||||
timeout-minutes: ${{ matrix.mem_leak_check == 'mem_leak_check' && 600 || inputs.timeout-minutes }}
|
||||
runs-on: ${{ matrix.runner }}
|
||||
steps:
|
||||
# [see note: pytorch repo ref]
|
||||
- name: Checkout PyTorch
|
||||
uses: pytorch/pytorch/.github/actions/checkout-pytorch@main
|
||||
with:
|
||||
@ -135,9 +131,6 @@ jobs:
|
||||
|
||||
- name: Start monitoring script
|
||||
id: monitor-script
|
||||
if: ${{ !inputs.disable-monitor }}
|
||||
shell: bash
|
||||
continue-on-error: true
|
||||
env:
|
||||
JOB_ID: ${{ steps.get-job-id.outputs.job-id }}
|
||||
JOB_NAME: ${{ steps.get-job-id.outputs.job-name }}
|
||||
@ -145,6 +138,9 @@ jobs:
|
||||
WORKFLOW_RUN_ID: ${{github.run_id}}
|
||||
MONITOR_LOG_INTERVAL: ${{ inputs.monitor-log-interval }}
|
||||
MONITOR_DATA_COLLECT_INTERVAL: ${{ inputs.monitor-data-collect-interval }}
|
||||
if: ${{ !inputs.disable-monitor }}
|
||||
shell: bash
|
||||
continue-on-error: true
|
||||
run: |
|
||||
python3 -m pip install psutil==5.9.8 dataclasses_json==0.6.7
|
||||
python3 -m tools.stats.monitor --log-interval "$MONITOR_LOG_INTERVAL" --data-collect-interval "$MONITOR_DATA_COLLECT_INTERVAL" > usage_log.txt 2>&1 &
|
||||
@ -182,12 +178,6 @@ jobs:
|
||||
run: |
|
||||
echo "timeout=$((JOB_TIMEOUT-30))" >> "${GITHUB_OUTPUT}"
|
||||
|
||||
- name: Preserve github env variables for use in docker
|
||||
shell: bash
|
||||
run: |
|
||||
env | grep '^GITHUB' >> "/tmp/github_env_${GITHUB_RUN_ID}"
|
||||
env | grep '^CI' >> "/tmp/github_env_${GITHUB_RUN_ID}"
|
||||
|
||||
- name: Test
|
||||
id: test
|
||||
env:
|
||||
@ -203,22 +193,20 @@ jobs:
|
||||
JOB_NAME: ${{ steps.get-job-id.outputs.job-name }}
|
||||
BRANCH: ${{ steps.parse-ref.outputs.branch }}
|
||||
SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
|
||||
BASE_SHA: ${{ github.event.pull_request.base.sha || github.sha }}
|
||||
TEST_CONFIG: ${{ matrix.config }}
|
||||
SHARD_NUMBER: ${{ matrix.shard }}
|
||||
NUM_TEST_SHARDS: ${{ matrix.num_shards }}
|
||||
REENABLED_ISSUES: ${{ steps.keep-going.outputs.reenabled-issues }}
|
||||
CONTINUE_THROUGH_ERROR: ${{ steps.keep-going.outputs.keep-going }}
|
||||
VERBOSE_TEST_LOGS: ${{ steps.keep-going.outputs.ci-verbose-test-logs }}
|
||||
TEST_SHOWLOCALS: ${{ steps.keep-going.outputs.ci-test-showlocals }}
|
||||
NO_TEST_TIMEOUT: ${{ steps.keep-going.outputs.ci-no-test-timeout }}
|
||||
NO_TD: ${{ steps.keep-going.outputs.ci-no-td }}
|
||||
TEST_CONFIG: ${{ matrix.config }}
|
||||
SHARD_NUMBER: ${{ matrix.shard }}
|
||||
NUM_TEST_SHARDS: ${{ matrix.num_shards }}
|
||||
REENABLED_ISSUES: ${{ steps.keep-going.outputs.reenabled-issues }}
|
||||
DOCKER_IMAGE: ${{ inputs.docker-image }}
|
||||
PYTORCH_TEST_CUDA_MEM_LEAK_CHECK: ${{ matrix.mem_leak_check && '1' || '0' }}
|
||||
PYTORCH_TEST_RERUN_DISABLED_TESTS: ${{ matrix.rerun_disabled_tests && '1' || '0' }}
|
||||
TESTS_TO_INCLUDE: ${{ inputs.tests-to-include }}
|
||||
DASHBOARD_TAG: ${{ inputs.dashboard-tag }}
|
||||
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
|
||||
timeout-minutes: ${{ fromJson(steps.test-timeout.outputs.timeout) }}
|
||||
run: |
|
||||
set -x
|
||||
@ -248,7 +236,6 @@ jobs:
|
||||
-e GITHUB_RUN_ATTEMPT \
|
||||
-e JOB_ID \
|
||||
-e JOB_NAME \
|
||||
-e BASE_SHA \
|
||||
-e BRANCH \
|
||||
-e SHA1 \
|
||||
-e AWS_DEFAULT_REGION \
|
||||
@ -266,12 +253,10 @@ jobs:
|
||||
-e PYTORCH_TEST_CUDA_MEM_LEAK_CHECK \
|
||||
-e PYTORCH_TEST_RERUN_DISABLED_TESTS \
|
||||
-e TESTS_TO_INCLUDE \
|
||||
-e HUGGING_FACE_HUB_TOKEN \
|
||||
-e DASHBOARD_TAG \
|
||||
--env-file="${RUNNER_TEMP}/github_env_${GITHUB_RUN_ID}" \
|
||||
--ulimit stack=10485760:83886080 \
|
||||
--ulimit core=0 \
|
||||
--env-file="/tmp/github_env_${GITHUB_RUN_ID}" \
|
||||
--security-opt seccomp=unconfined \
|
||||
--cap-add=SYS_PTRACE \
|
||||
--shm-size="8g" \
|
||||
|
||||
4
.github/workflows/build-vllm-wheel.yml
vendored
4
.github/workflows/build-vllm-wheel.yml
vendored
@ -178,12 +178,12 @@ jobs:
|
||||
contents: read
|
||||
container:
|
||||
image: continuumio/miniconda3:4.12.0
|
||||
environment: ${{ ((github.event_name == 'push' && github.event.ref == 'refs/heads/main') || github.event_name == 'schedule' || github.event_name == 'workflow_dispatch') && 'nightly-wheel-upload' || '' }}
|
||||
environment: ${{ (github.event_name == 'push' && github.event.ref == 'refs/heads/main') && 'nightly-wheel-upload' || '' }}
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
|
||||
- name: Configure AWS credentials(PyTorch account) for main
|
||||
if: ${{ (github.event_name == 'push' && github.event.ref == 'refs/heads/main') || github.event_name == 'schedule' || github.event_name == 'workflow_dispatch' }}
|
||||
if: ${{ github.event_name == 'push' && github.event.ref == 'refs/heads/main' }}
|
||||
uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0
|
||||
with:
|
||||
role-to-assume: arn:aws:iam::749337293305:role/gha_workflow_nightly_build_wheels
|
||||
|
||||
@ -43,11 +43,6 @@ on:
|
||||
required: false
|
||||
type: boolean
|
||||
default: false
|
||||
freezing:
|
||||
description: Run freezing?
|
||||
required: false
|
||||
type: boolean
|
||||
default: true
|
||||
benchmark_configs:
|
||||
description: The list of configs used the benchmark
|
||||
required: false
|
||||
@ -107,7 +102,7 @@ jobs:
|
||||
if: github.event.schedule == '0 7 * * *'
|
||||
with:
|
||||
build-environment: linux-jammy-py3.10-gcc11-build
|
||||
dashboard-tag: training-false-inference-true-default-true-dynamic-true-cppwrapper-true-aotinductor-true-freezing-true
|
||||
dashboard-tag: training-false-inference-true-default-true-dynamic-true-cppwrapper-true-aotinductor-true
|
||||
docker-image: ${{ needs.inductor-build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.inductor-build.outputs.test-matrix }}
|
||||
timeout-minutes: 720
|
||||
@ -121,9 +116,10 @@ jobs:
|
||||
name: inductor-test
|
||||
uses: ./.github/workflows/_linux-test.yml
|
||||
needs: inductor-build
|
||||
if: github.event_name == 'workflow_dispatch'
|
||||
with:
|
||||
build-environment: linux-jammy-py3.10-gcc11-build
|
||||
dashboard-tag: training-${{ inputs.training || 'false' }}-inference-${{ inputs.inference || 'true' }}-default-${{ inputs.default || 'true' }}-dynamic-${{ inputs.dynamic || 'true' }}-cppwrapper-${{ inputs.cppwrapper || 'true' }}-aotinductor-${{ inputs.aotinductor || 'true' }}-freezing-${{ inputs.freezing || 'true' }}
|
||||
dashboard-tag: training-${{ inputs.training }}-inference-${{ inputs.inference }}-default-${{ inputs.default }}-dynamic-${{ inputs.dynamic }}-cppwrapper-${{ inputs.cppwrapper }}-aotinductor-${{ inputs.aotinductor }}
|
||||
docker-image: ${{ needs.inductor-build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.inductor-build.outputs.test-matrix }}
|
||||
timeout-minutes: 720
|
||||
|
||||
2
.github/workflows/lint.yml
vendored
2
.github/workflows/lint.yml
vendored
@ -105,7 +105,7 @@ jobs:
|
||||
# NB: A shallow checkout won't work here because calculate-docker-image requires a full checkout
|
||||
# to run git rev-parse HEAD~:.ci/docker when a new image is needed
|
||||
fetch-depth: 0
|
||||
submodules: true
|
||||
submodules: false
|
||||
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
|
||||
script: |
|
||||
CHANGED_FILES="${{ needs.get-changed-files.outputs.changed-files }}"
|
||||
|
||||
2
.github/workflows/pull.yml
vendored
2
.github/workflows/pull.yml
vendored
@ -127,8 +127,6 @@ jobs:
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
needs: get-label-type
|
||||
with:
|
||||
# More memory is needed to build with asan
|
||||
runner: linux.2xlarge.memory
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
build-environment: linux-jammy-py3.10-clang18-asan
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-py3-clang18-asan
|
||||
|
||||
54
.github/workflows/quantization-periodic.yml
vendored
Normal file
54
.github/workflows/quantization-periodic.yml
vendored
Normal file
@ -0,0 +1,54 @@
|
||||
name: quantization-periodic
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- ciflow/quantization-periodic/*
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
# run weekly
|
||||
- cron: "45 0 * * 0"
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
|
||||
cancel-in-progress: true
|
||||
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
get-default-label-prefix:
|
||||
name: get-default-label-prefix
|
||||
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
|
||||
if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
|
||||
with:
|
||||
triggering_actor: ${{ github.triggering_actor }}
|
||||
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
|
||||
curr_branch: ${{ github.head_ref || github.ref_name }}
|
||||
curr_ref_type: ${{ github.ref_type }}
|
||||
opt_out_experiments: lf
|
||||
|
||||
periodic-quantization-build:
|
||||
name: periodic-quantization-build
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
needs: get-default-label-prefix
|
||||
with:
|
||||
runner_prefix: "${{ needs.get-default-label-prefix.outputs.label-type }}"
|
||||
build-environment: linux-jammy-cuda12.8-cudnn9-py3-gcc11
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
|
||||
cuda-arch-list: '8.9'
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "quantization", shard: 1, num_shards: 1, runner: "${{ needs.get-default-label-prefix.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
|
||||
]}
|
||||
secrets: inherit
|
||||
periodic-test-quantization:
|
||||
name: periodic-test-quantization
|
||||
uses: ./.github/workflows/_linux-test.yml
|
||||
needs: periodic-quantization-build
|
||||
with:
|
||||
build-environment: linux-jammy-cuda12.8-cudnn9-py3-gcc11
|
||||
docker-image: ${{ needs.periodic-quantization-build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.periodic-quantization-build.outputs.test-matrix }}
|
||||
secrets: inherit
|
||||
2
.github/workflows/slow.yml
vendored
2
.github/workflows/slow.yml
vendored
@ -140,8 +140,6 @@ jobs:
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
needs: get-label-type
|
||||
with:
|
||||
# More memory is needed to build with asan
|
||||
runner: linux.2xlarge.memory
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
build-environment: linux-jammy-py3.10-clang18-asan
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-py3-clang18-asan
|
||||
|
||||
@ -891,7 +891,7 @@ IF(USE_FBGEMM_GENAI AND USE_ROCM AND NOT "gfx942" IN_LIST PYTORCH_ROCM_ARCH)
|
||||
endif()
|
||||
|
||||
# Set USE_FBGEMM_GENAI to ON for CUDA build on SM100.
|
||||
if(USE_CUDA AND "$ENV{TORCH_CUDA_ARCH_LIST}" MATCHES "10.0" AND CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8 AND NOT WIN32)
|
||||
if(USE_CUDA AND "$ENV{TORCH_CUDA_ARCH_LIST}" MATCHES "10.0" AND CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8)
|
||||
message(STATUS "Setting USE_FBGEMM_GENAI to ON, doing CUDA build for SM100a")
|
||||
set(USE_FBGEMM_GENAI ON)
|
||||
endif()
|
||||
|
||||
@ -65,24 +65,14 @@ DLDataType getDLDataType(const Tensor& t) {
|
||||
break;
|
||||
// TODO(#146647): use macro here instead of spelling out each shell dtype
|
||||
case ScalarType::Float8_e5m2:
|
||||
dtype.code = DLDataTypeCode::kDLFloat8_e5m2;
|
||||
break;
|
||||
case ScalarType::Float8_e5m2fnuz:
|
||||
dtype.code = DLDataTypeCode::kDLFloat8_e5m2fnuz;
|
||||
break;
|
||||
case ScalarType::Float8_e4m3fn:
|
||||
dtype.code = DLDataTypeCode::kDLFloat8_e4m3fn;
|
||||
break;
|
||||
case ScalarType::Float8_e4m3fnuz:
|
||||
dtype.code = DLDataTypeCode::kDLFloat8_e4m3fnuz;
|
||||
break;
|
||||
case ScalarType::Float8_e8m0fnu:
|
||||
dtype.code = DLDataTypeCode::kDLFloat8_e8m0fnu;
|
||||
TORCH_CHECK_BUFFER(false, "float8 types are not supported by dlpack");
|
||||
break;
|
||||
case ScalarType::Float4_e2m1fn_x2:
|
||||
dtype.code = DLDataTypeCode::kDLFloat4_e2m1fn;
|
||||
dtype.lanes = 2;
|
||||
dtype.bits = 4;
|
||||
TORCH_CHECK_BUFFER(false, "float4 types are not supported by dlpack");
|
||||
break;
|
||||
case ScalarType::QInt8:
|
||||
case ScalarType::QUInt8:
|
||||
@ -187,11 +177,7 @@ static Device getATenDevice(DLDeviceType type, c10::DeviceIndex index, void* dat
|
||||
|
||||
ScalarType toScalarType(const DLDataType& dtype) {
|
||||
ScalarType stype = ScalarType::Undefined;
|
||||
if (dtype.code != DLDataTypeCode::kDLFloat4_e2m1fn) {
|
||||
TORCH_CHECK_BUFFER(
|
||||
dtype.lanes == 1,
|
||||
"ATen does not support lanes != 1 for dtype code", std::to_string(dtype.code));
|
||||
}
|
||||
TORCH_CHECK_BUFFER(dtype.lanes == 1, "ATen does not support lanes != 1");
|
||||
switch (dtype.code) {
|
||||
case DLDataTypeCode::kDLUInt:
|
||||
switch (dtype.bits) {
|
||||
@ -283,73 +269,6 @@ ScalarType toScalarType(const DLDataType& dtype) {
|
||||
false, "Unsupported kDLBool bits ", std::to_string(dtype.bits));
|
||||
}
|
||||
break;
|
||||
case DLDataTypeCode::kDLFloat8_e5m2:
|
||||
switch (dtype.bits) {
|
||||
case 8:
|
||||
stype = ScalarType::Float8_e5m2;
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK_BUFFER(
|
||||
false, "Unsupported kDLFloat8_e5m2 bits ", std::to_string(dtype.bits));
|
||||
}
|
||||
break;
|
||||
case DLDataTypeCode::kDLFloat8_e5m2fnuz:
|
||||
switch (dtype.bits) {
|
||||
case 8:
|
||||
stype = ScalarType::Float8_e5m2fnuz;
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK_BUFFER(
|
||||
false, "Unsupported kDLFloat8_e5m2fnuz bits ", std::to_string(dtype.bits));
|
||||
}
|
||||
break;
|
||||
case DLDataTypeCode::kDLFloat8_e4m3fn:
|
||||
switch (dtype.bits) {
|
||||
case 8:
|
||||
stype = ScalarType::Float8_e4m3fn;
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK_BUFFER(
|
||||
false, "Unsupported kDLFloat8_e4m3fn bits ", std::to_string(dtype.bits));
|
||||
}
|
||||
break;
|
||||
case DLDataTypeCode::kDLFloat8_e4m3fnuz:
|
||||
switch (dtype.bits) {
|
||||
case 8:
|
||||
stype = ScalarType::Float8_e4m3fnuz;
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK_BUFFER(
|
||||
false, "Unsupported kDLFloat8_e4m3fnuz bits ", std::to_string(dtype.bits));
|
||||
}
|
||||
break;
|
||||
case DLDataTypeCode::kDLFloat8_e8m0fnu:
|
||||
switch (dtype.bits) {
|
||||
case 8:
|
||||
stype = ScalarType::Float8_e8m0fnu;
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK_BUFFER(
|
||||
false, "Unsupported kDLFloat8_e8m0fnu bits ", std::to_string(dtype.bits));
|
||||
}
|
||||
break;
|
||||
case DLDataTypeCode::kDLFloat4_e2m1fn:
|
||||
switch (dtype.bits) {
|
||||
case 4:
|
||||
switch (dtype.lanes) {
|
||||
case 2:
|
||||
stype = ScalarType::Float4_e2m1fn_x2;
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK_BUFFER(
|
||||
false, "Unsupported kDLFloat4_e2m1fn lanes ", std::to_string(dtype.lanes));
|
||||
}
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK_BUFFER(
|
||||
false, "Unsupported kDLFloat4_e2m1fn bits ", std::to_string(dtype.bits));
|
||||
}
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK_BUFFER(false, "Unsupported code ", std::to_string(dtype.code));
|
||||
}
|
||||
@ -435,8 +354,8 @@ T* toDLPackImpl(const Tensor& src) {
|
||||
atDLMTensor->tensor.dl_tensor.device = torchDeviceToDLDevice(src.device());
|
||||
atDLMTensor->tensor.dl_tensor.ndim = static_cast<int32_t>(src.dim());
|
||||
atDLMTensor->tensor.dl_tensor.dtype = getDLDataType(src);
|
||||
atDLMTensor->tensor.dl_tensor.shape = const_cast<int64_t*>(view.sizes().data());
|
||||
atDLMTensor->tensor.dl_tensor.strides = const_cast<int64_t*>(view.strides().data());
|
||||
atDLMTensor->tensor.dl_tensor.shape = view.sizes().data();
|
||||
atDLMTensor->tensor.dl_tensor.strides = view.strides().data();
|
||||
atDLMTensor->tensor.dl_tensor.byte_offset = 0;
|
||||
fillVersion(&atDLMTensor->tensor);
|
||||
|
||||
|
||||
@ -102,7 +102,7 @@ FunctionalStorageImpl::FunctionalStorageImpl(const Tensor& base)
|
||||
// SparseTensorImpl has no storage, so we cannot query its nbytes.
|
||||
// (original_storage_size is only used for storage resizing in fsdp anyway, which does not apply to sparse)
|
||||
// Same for XLA
|
||||
if (base.unsafeGetTensorImpl()->has_storage() && data_ptr().device().type() != c10::DeviceType::XLA) {
|
||||
if (base.unsafeGetTensorImpl()->has_storage() && base.device().type() != c10::DeviceType::XLA) {
|
||||
original_storage_size_ = base.unsafeGetTensorImpl()->unsafe_storage().unsafeGetStorageImpl()->sym_nbytes();
|
||||
} else {
|
||||
original_storage_size_ = -1;
|
||||
|
||||
@ -266,14 +266,11 @@ CUDAGeneratorImpl::CUDAGeneratorImpl(
|
||||
* See Note [Acquire lock when using random generators]
|
||||
*/
|
||||
void CUDAGeneratorImpl::set_current_seed(uint64_t seed) {
|
||||
if (C10_LIKELY(at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None)) {
|
||||
state_->seed_ = seed;
|
||||
state_->philox_offset_per_thread_ = 0;
|
||||
no_reset_rnn_state_.clear();
|
||||
} else {
|
||||
TORCH_CHECK(state_->seed_ == seed, "CUDAGeneratorImpl::set_current_seed can be called during stream capture only if new seed is the same as the original seed.");
|
||||
// no-op case
|
||||
}
|
||||
at::cuda::assertNotCapturing(
|
||||
"Cannot call CUDAGeneratorImpl::set_current_seed");
|
||||
state_->seed_ = seed;
|
||||
state_->philox_offset_per_thread_ = 0;
|
||||
no_reset_rnn_state_.clear();
|
||||
}
|
||||
|
||||
/**
|
||||
@ -302,6 +299,9 @@ uint64_t CUDAGeneratorImpl::get_offset() const {
|
||||
* Gets the current seed of CUDAGeneratorImpl.
|
||||
*/
|
||||
uint64_t CUDAGeneratorImpl::current_seed() const {
|
||||
// Debatable if current_seed() should be allowed in captured regions.
|
||||
// Conservatively disallow it for now.
|
||||
at::cuda::assertNotCapturing("Cannot call CUDAGeneratorImpl::current_seed");
|
||||
return state_->seed_;
|
||||
}
|
||||
|
||||
@ -346,6 +346,8 @@ c10::intrusive_ptr<c10::TensorImpl> CUDAGeneratorImpl::get_state() const {
|
||||
* and size of the internal state.
|
||||
*/
|
||||
void CUDAGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
|
||||
at::cuda::assertNotCapturing(
|
||||
"Please ensure to utilize the CUDAGeneratorImpl::set_state_index method during capturing.");
|
||||
static const size_t seed_size = sizeof(uint64_t);
|
||||
static const size_t offset_size = sizeof(int64_t);
|
||||
static const size_t total_size = seed_size + offset_size;
|
||||
@ -400,27 +402,15 @@ c10::intrusive_ptr<c10::GeneratorImpl> CUDAGeneratorImpl::graphsafe_get_state()
|
||||
*/
|
||||
void CUDAGeneratorImpl::set_philox_offset_per_thread(uint64_t offset) {
|
||||
// see Note [Why enforce RNG offset % 4 == 0?]
|
||||
|
||||
// Note: If you use CUDNN RNN's, calling
|
||||
// set_philox_offset_per_thread instead of set_offset will cause the
|
||||
// cudnn RNN rng state to become stale.
|
||||
TORCH_CHECK(offset % 4 == 0, "offset must be a multiple of 4");
|
||||
if (C10_LIKELY(at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None)) {
|
||||
state_->philox_offset_per_thread_ = offset;
|
||||
} else {
|
||||
state_->offset_intragraph_ = offset;
|
||||
}
|
||||
state_->philox_offset_per_thread_ = offset;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the current philox_offset_per_thread_ of CUDAGeneratorImpl.
|
||||
*/
|
||||
uint64_t CUDAGeneratorImpl::philox_offset_per_thread() const {
|
||||
if (C10_LIKELY(at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None)) {
|
||||
return state_->philox_offset_per_thread_;
|
||||
} else {
|
||||
return state_->offset_intragraph_;
|
||||
}
|
||||
return state_->philox_offset_per_thread_;
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@ -19,7 +19,7 @@
|
||||
#define DLPACK_MAJOR_VERSION 1
|
||||
|
||||
/*! \brief The current minor version of dlpack */
|
||||
#define DLPACK_MINOR_VERSION 1
|
||||
#define DLPACK_MINOR_VERSION 0
|
||||
|
||||
/*! \brief DLPACK_DLL prefix for windows */
|
||||
#ifdef _WIN32
|
||||
@ -32,7 +32,9 @@
|
||||
#define DLPACK_DLL
|
||||
#endif
|
||||
|
||||
// NOLINTNEXTLINE(modernize-deprecated-headers)
|
||||
#include <stdint.h>
|
||||
// NOLINTNEXTLINE(modernize-deprecated-headers)
|
||||
#include <stddef.h>
|
||||
|
||||
#ifdef __cplusplus
|
||||
@ -157,26 +159,6 @@ typedef enum {
|
||||
kDLComplex = 5U,
|
||||
/*! \brief boolean */
|
||||
kDLBool = 6U,
|
||||
/*! \brief FP8 data types */
|
||||
kDLFloat8_e3m4 = 7U,
|
||||
kDLFloat8_e4m3 = 8U,
|
||||
kDLFloat8_e4m3b11fnuz = 9U,
|
||||
kDLFloat8_e4m3fn = 10U,
|
||||
kDLFloat8_e4m3fnuz = 11U,
|
||||
kDLFloat8_e5m2 = 12U,
|
||||
kDLFloat8_e5m2fnuz = 13U,
|
||||
kDLFloat8_e8m0fnu = 14U,
|
||||
/*! \brief FP6 data types
|
||||
* Setting bits != 6 is currently unspecified, and the producer must ensure it is set
|
||||
* while the consumer must stop importing if the value is unexpected.
|
||||
*/
|
||||
kDLFloat6_e2m3fn = 15U,
|
||||
kDLFloat6_e3m2fn = 16U,
|
||||
/*! \brief FP4 data types
|
||||
* Setting bits != 4 is currently unspecified, and the producer must ensure it is set
|
||||
* while the consumer must stop importing if the value is unexpected.
|
||||
*/
|
||||
kDLFloat4_e2m1fn = 17U,
|
||||
} DLDataTypeCode;
|
||||
|
||||
/*!
|
||||
@ -190,12 +172,6 @@ typedef enum {
|
||||
* - int8: type_code = 0, bits = 8, lanes = 1
|
||||
* - std::complex<float>: type_code = 5, bits = 64, lanes = 1
|
||||
* - bool: type_code = 6, bits = 8, lanes = 1 (as per common array library convention, the underlying storage size of bool is 8 bits)
|
||||
* - float8_e4m3: type_code = 8, bits = 8, lanes = 1 (packed in memory)
|
||||
* - float6_e3m2fn: type_code = 16, bits = 6, lanes = 1 (packed in memory)
|
||||
* - float4_e2m1fn: type_code = 17, bits = 4, lanes = 1 (packed in memory)
|
||||
*
|
||||
* When a sub-byte type is packed, DLPack requires the data to be in little bit-endian, i.e.,
|
||||
* for a packed data set D ((D >> (i * bits)) && bit_mask) stores the i-th element.
|
||||
*/
|
||||
typedef struct {
|
||||
/*!
|
||||
@ -253,12 +229,12 @@ typedef struct {
|
||||
/*! \brief The data type of the pointer*/
|
||||
DLDataType dtype;
|
||||
/*! \brief The shape of the tensor */
|
||||
int64_t* shape;
|
||||
const int64_t* shape;
|
||||
/*!
|
||||
* \brief strides of the tensor (in number of elements, not bytes)
|
||||
* can be NULL, indicating tensor is compact and row-majored.
|
||||
*/
|
||||
int64_t* strides;
|
||||
const int64_t* strides;
|
||||
/*! \brief The offset in bytes to the beginning pointer to data */
|
||||
uint64_t byte_offset;
|
||||
} DLTensor;
|
||||
@ -293,7 +269,7 @@ typedef struct DLManagedTensor {
|
||||
void (*deleter)(struct DLManagedTensor * self);
|
||||
} DLManagedTensor;
|
||||
|
||||
// bit masks used in the DLManagedTensorVersioned
|
||||
// bit masks used in in the DLManagedTensorVersioned
|
||||
|
||||
/*! \brief bit mask to indicate that the tensor is read only. */
|
||||
#define DLPACK_FLAG_BITMASK_READ_ONLY (1UL << 0UL)
|
||||
@ -306,14 +282,6 @@ typedef struct DLManagedTensor {
|
||||
*/
|
||||
#define DLPACK_FLAG_BITMASK_IS_COPIED (1UL << 1UL)
|
||||
|
||||
/*
|
||||
* \brief bit mask to indicate that whether a sub-byte type is packed or padded.
|
||||
*
|
||||
* The default for sub-byte types (ex: fp4/fp6) is assumed packed. This flag can
|
||||
* be set by the producer to signal that a tensor of sub-byte type is padded.
|
||||
*/
|
||||
#define DLPACK_FLAG_BITMASK_IS_SUBBYTE_TYPE_PADDED (1UL << 2UL)
|
||||
|
||||
/*!
|
||||
* \brief A versioned and managed C Tensor object, manage memory of DLTensor.
|
||||
*
|
||||
|
||||
@ -171,8 +171,6 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
|
||||
|
||||
POINTWISE_BOXED(fill_.Scalar);
|
||||
POINTWISE_BOXED(zero_);
|
||||
// This is special because this op doesn't return anything
|
||||
m.impl("_assert_tensor_metadata", native::_assert_tensor_metadata);
|
||||
|
||||
#undef UNARY_POINTWISE
|
||||
#undef UNARY_POINTWISE_ALL
|
||||
|
||||
@ -81,7 +81,7 @@ Tensor math_channel_shuffle(const Tensor& self, int64_t groups) {
|
||||
// TODO: contiguous can be made to preserve the memory format
|
||||
// of the input. However since the above reshape clobbers h and w
|
||||
// it may not be safe to do that, since channels_last contiguous
|
||||
// may think oc and the last dim correspond to h,w?
|
||||
// may think oc and and the last dim correspond to h,w?
|
||||
// It is not clear, however from initial looking around it feels that
|
||||
// this may not be correct.
|
||||
// In this case channels last will likely require custom implementation
|
||||
|
||||
@ -67,13 +67,13 @@ TORCH_PRECOMPUTE_META_FUNC(fractional_max_pool3d)(
|
||||
int64_t inputH = input_.size(heightDim);
|
||||
int64_t inputW = input_.size(widthDim);
|
||||
|
||||
TORCH_CHECK((poolSizeT <= inputT) && (outputT + poolSizeT - 1 < inputT),
|
||||
TORCH_CHECK(outputT + poolSizeT - 1 < inputT,
|
||||
"fractional_max_pool3d_out(): pool time ", poolSizeT,
|
||||
" too large relative to input time ", inputT);
|
||||
TORCH_CHECK((poolSizeW <= inputW) && (outputW + poolSizeW - 1 < inputW),
|
||||
TORCH_CHECK(outputW + poolSizeW - 1 < inputW,
|
||||
"fractional_max_pool3d_out(): pool width ", poolSizeW,
|
||||
" too large relative to input width ", inputW);
|
||||
TORCH_CHECK((poolSizeH <= inputH) && (outputH + poolSizeH - 1 < inputH),
|
||||
TORCH_CHECK(outputH + poolSizeH - 1 < inputH,
|
||||
"fractional_max_pool3d_out(): pool height ", poolSizeH,
|
||||
" too large relative to input height ", inputH);
|
||||
|
||||
|
||||
@ -52,7 +52,6 @@ void apply_triu_tril_single(
|
||||
int64_t self_col_stride,
|
||||
bool upper) {
|
||||
constexpr int64_t zero = 0;
|
||||
k = std::clamp(k, -n, m); // Clamp k to [-n, m] to prevent i + k arithmetic overflow, especially if k approaches INT64_MAX/INT64_MIN.
|
||||
|
||||
if (upper) {
|
||||
parallel_for(0, n, 0, [&](int64_t start, int64_t end) {
|
||||
|
||||
@ -85,11 +85,11 @@ void cpu_max_unpool(
|
||||
if constexpr (is_3d) {
|
||||
TORCH_CHECK(false, "Found an invalid max index: ", optional_error_index.value(),
|
||||
" (output volumes are of size ", output_depth,
|
||||
"x", output_height, "x", output_width, ")");
|
||||
"x", output_height, "x", output_width);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Found an invalid max index: ", optional_error_index.value(),
|
||||
" (output volumes are of size ", output_height,
|
||||
"x", output_width, ")");
|
||||
"x", output_width);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -416,7 +416,6 @@ struct ReduceOp {
|
||||
if (config.should_block_y_reduce()) {
|
||||
value = block_y_reduce<output_vec_size>(value, shared_memory);
|
||||
}
|
||||
__syncthreads();
|
||||
if (config.should_block_x_reduce()) {
|
||||
value = block_x_reduce<output_vec_size>(value, shared_memory);
|
||||
}
|
||||
|
||||
@ -17,11 +17,12 @@ __global__ static void compute_cuda_kernel(
|
||||
index_t* result_ptr,
|
||||
int64_t size,
|
||||
int64_t result_size) {
|
||||
CUDA_KERNEL_ASSERT_PRINTF(
|
||||
result_size == cumsum_ptr[size - 1],
|
||||
if (C10_UNLIKELY((result_size != cumsum_ptr[size - 1]))) {
|
||||
printf("%s:%d:%s: block: [%d,%d,%d], thread: [%d,%d,%d] "
|
||||
"Invalid input! In `repeat_interleave`, the `output_size` argument (%ld) must be the same as the sum of the elements in the `repeats` tensor (%ld).\n",
|
||||
result_size,
|
||||
cumsum_ptr[size - 1]);
|
||||
__FILE__, __LINE__, __func__,blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, threadIdx.z, result_size, cumsum_ptr[size - 1 ]);
|
||||
CUDA_KERNEL_ASSERT(result_size == cumsum_ptr[size - 1])
|
||||
}
|
||||
|
||||
int64_t idx = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x;
|
||||
int64_t stride = (blockDim.x * gridDim.x) / C10_WARP_SIZE;
|
||||
|
||||
@ -5,20 +5,12 @@
|
||||
|
||||
namespace at::native {
|
||||
|
||||
__global__ void weight_int8pack_mm_kernel(
|
||||
const float* x,
|
||||
const int8_t* w,
|
||||
const float* scale,
|
||||
float* out,
|
||||
int B,
|
||||
int K,
|
||||
int N) {
|
||||
__global__ void weight_int8pack_mm_kernel(const float* x, const int8_t* w, const float* scale, float* out, int B, int K, int N) {
|
||||
// one thread per output element: [B, N]
|
||||
int b = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
int n = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
if (b >= B || n >= N)
|
||||
return;
|
||||
if (b >= B || n >= N) return;
|
||||
|
||||
float acc = 0.0f;
|
||||
for (int k = 0; k < K; ++k) {
|
||||
@ -28,11 +20,7 @@ __global__ void weight_int8pack_mm_kernel(
|
||||
out[b * N + n] = acc * scale[n];
|
||||
}
|
||||
|
||||
void launch_weight_int8pack_mm_cuda_kernel(
|
||||
const Tensor& x,
|
||||
const Tensor& w_int8,
|
||||
const Tensor& scale,
|
||||
Tensor& out) {
|
||||
void launch_weight_int8pack_mm_cuda_kernel(const Tensor& x, const Tensor& w_int8, const Tensor& scale, Tensor& out) {
|
||||
const int B = x.size(0);
|
||||
const int K = x.size(1);
|
||||
const int N = w_int8.size(0);
|
||||
@ -47,16 +35,12 @@ void launch_weight_int8pack_mm_cuda_kernel(
|
||||
w_int8.data_ptr<int8_t>(),
|
||||
scale.data_ptr<float>(),
|
||||
out.data_ptr<float>(),
|
||||
B,
|
||||
K,
|
||||
N);
|
||||
B, K, N);
|
||||
}
|
||||
|
||||
|
||||
// Main GPU entry point
|
||||
at::Tensor _weight_int8pack_mm_cuda(
|
||||
const at::Tensor& x,
|
||||
const at::Tensor& w_int8,
|
||||
const at::Tensor& scale) {
|
||||
at::Tensor _weight_int8pack_mm_cuda(const at::Tensor& x, const at::Tensor& w_int8, const at::Tensor& scale) {
|
||||
// --- Check inputs ---
|
||||
TORCH_CHECK(x.is_cuda(), "x must be a CUDA tensor");
|
||||
TORCH_CHECK(w_int8.is_cuda(), "w must be a CUDA tensor");
|
||||
@ -66,16 +50,12 @@ at::Tensor _weight_int8pack_mm_cuda(
|
||||
TORCH_CHECK(w_int8.dim() == 2, "w must be 2D");
|
||||
TORCH_CHECK(scale.dim() == 1, "scale must be 1D");
|
||||
|
||||
TORCH_CHECK(
|
||||
x.size(1) == w_int8.size(1),
|
||||
"K dimension mismatch: x.size(1) != w.size(1)");
|
||||
TORCH_CHECK(
|
||||
w_int8.size(0) == scale.size(0),
|
||||
"Output dim mismatch: w.size(0) != scale.size(0)");
|
||||
TORCH_CHECK(x.size(1) == w_int8.size(1), "K dimension mismatch: x.size(1) != w.size(1)");
|
||||
TORCH_CHECK(w_int8.size(0) == scale.size(0), "Output dim mismatch: w.size(0) != scale.size(0)");
|
||||
|
||||
// --- Determine shapes ---
|
||||
auto B = x.size(0); // batch size
|
||||
auto N = w_int8.size(0); // output dim
|
||||
auto B = x.size(0); // batch size
|
||||
auto N = w_int8.size(0); // output dim
|
||||
|
||||
// Ensure inputs are in the correct types for the kernel
|
||||
auto x_f32 = x.to(at::kFloat);
|
||||
@ -83,13 +63,12 @@ at::Tensor _weight_int8pack_mm_cuda(
|
||||
auto scale_f32 = scale.to(at::kFloat);
|
||||
|
||||
// --- Allocate output ---
|
||||
auto out = at::empty({B, N}, x_f32.options());
|
||||
auto out = at::empty({B, N}, x.options().dtype(at::kFloat));
|
||||
|
||||
// --- Launch kernel ---
|
||||
launch_weight_int8pack_mm_cuda_kernel(
|
||||
x_f32, w_int8_contiguous, scale_f32, out);
|
||||
launch_weight_int8pack_mm_cuda_kernel(x_f32, w_int8_contiguous, scale_f32, out);
|
||||
|
||||
return out.to(x.dtype());
|
||||
return out;
|
||||
}
|
||||
|
||||
} // namespace at::native
|
||||
|
||||
@ -482,9 +482,7 @@ auto build_graph(
|
||||
auto scaled_dot_product_flash_attention_options =
|
||||
fe::graph::SDPA_attributes()
|
||||
.set_name("CUDNN_SDPA")
|
||||
.set_is_inference(return_softmaxstats == false)
|
||||
// TODO(eqy): switch to this API once cuDNN FE is upgraded
|
||||
// .set_generate_stats(return_softmaxstats)
|
||||
.set_generate_stats(return_softmaxstats)
|
||||
.set_causal_mask(is_causal)
|
||||
.set_attn_scale(attn_scale);
|
||||
if (use_ragged_in_dense(q, k, v, o, attn_bias.has_value())) {
|
||||
@ -704,9 +702,7 @@ auto build_graph_nestedtensor(
|
||||
auto scaled_dot_product_flash_attention_options =
|
||||
fe::graph::SDPA_attributes()
|
||||
.set_name("CUDNN_SDPA_NESTEDTENSOR")
|
||||
.set_is_inference(return_softmaxstats == false)
|
||||
// TODO(eqy): switch to this API once cuDNN FE is upgraded
|
||||
// .set_generate_stats(return_softmaxstats)
|
||||
.set_generate_stats(return_softmaxstats)
|
||||
.set_causal_mask(is_causal)
|
||||
.set_attn_scale(attn_scale)
|
||||
.set_seq_len_q(SEQ_LEN_Q_)
|
||||
|
||||
@ -1770,12 +1770,10 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> miopen_depthwise_convolution_back
|
||||
// fusions
|
||||
// ---------------------------------------------------------------------
|
||||
|
||||
void raw_miopen_convolution_add_relu_out(
|
||||
void raw_miopen_convolution_relu_out(
|
||||
const Tensor& output,
|
||||
const Tensor& input,
|
||||
const Tensor& weight,
|
||||
const Tensor& z,
|
||||
float alpha,
|
||||
const Tensor& bias,
|
||||
IntArrayRef stride,
|
||||
IntArrayRef padding,
|
||||
@ -1783,20 +1781,68 @@ void raw_miopen_convolution_add_relu_out(
|
||||
int64_t groups,
|
||||
bool benchmark,
|
||||
bool deterministic) {
|
||||
raw_miopen_convolution_forward_out(
|
||||
output,
|
||||
auto dataType = getMiopenDataType(input);
|
||||
miopenConvolutionMode_t c_mode = miopenConvolution;
|
||||
ConvolutionArgs args{ input, output, weight };
|
||||
args.handle = getMiopenHandle();
|
||||
at::MemoryFormat memory_format = miopen_conv_suggest_memory_format(input, weight);
|
||||
setConvolutionParams(
|
||||
&args.params,
|
||||
args.handle,
|
||||
input,
|
||||
weight,
|
||||
padding,
|
||||
stride,
|
||||
dilation,
|
||||
groups,
|
||||
deterministic,
|
||||
memory_format);
|
||||
args.idesc.set(input, memory_format);
|
||||
args.wdesc.set(weight, memory_format, 0);
|
||||
args.odesc.set(output, memory_format);
|
||||
args.cdesc.set(
|
||||
dataType,
|
||||
c_mode,
|
||||
input.dim() - 2,
|
||||
args.params.padding,
|
||||
args.params.stride,
|
||||
args.params.dilation,
|
||||
args.params.groups,
|
||||
benchmark,
|
||||
deterministic);
|
||||
at::Tensor alpha_mul_z_add_bias =
|
||||
at::native::reshape_bias(input.dim(), bias).add(z, alpha);
|
||||
output.add_(alpha_mul_z_add_bias);
|
||||
output.relu_();
|
||||
|
||||
TensorDescriptor bdesc;
|
||||
bdesc.set(bias.expand({1, bias.size(0)}), output.dim());
|
||||
|
||||
// Create the fusion plan
|
||||
miopenFusionPlanDescriptor_t fusePlanDesc;
|
||||
miopenFusionOpDescriptor_t convoOp;
|
||||
miopenFusionOpDescriptor_t biasOp;
|
||||
miopenFusionOpDescriptor_t activOp;
|
||||
MIOPEN_CHECK(miopenCreateFusionPlan(&fusePlanDesc, miopenVerticalFusion, args.idesc.desc()));
|
||||
MIOPEN_CHECK(miopenCreateOpConvForward(fusePlanDesc, &convoOp, args.cdesc.desc(), args.wdesc.desc()));
|
||||
MIOPEN_CHECK(miopenCreateOpBiasForward(fusePlanDesc, &biasOp, bdesc.desc()));
|
||||
MIOPEN_CHECK(miopenCreateOpActivationForward(fusePlanDesc, &activOp, miopenActivationRELU));
|
||||
|
||||
// compile fusion plan
|
||||
MIOPEN_CHECK(miopenCompileFusionPlan(args.handle, fusePlanDesc));
|
||||
|
||||
// Set the Args
|
||||
float alpha = static_cast<float>(1);
|
||||
float beta = static_cast<float>(0);
|
||||
float activ_alpha = static_cast<float>(0);
|
||||
float activ_beta = static_cast<float>(0);
|
||||
float activ_gamma = static_cast<float>(0);
|
||||
miopenOperatorArgs_t fusionArgs;
|
||||
MIOPEN_CHECK(miopenCreateOperatorArgs(&fusionArgs));
|
||||
MIOPEN_CHECK(miopenSetOpArgsConvForward(fusionArgs, convoOp, &alpha, &beta, weight.const_data_ptr()));
|
||||
MIOPEN_CHECK(miopenSetOpArgsBiasForward(fusionArgs, biasOp, &alpha, &beta, bias.const_data_ptr()));
|
||||
MIOPEN_CHECK(miopenSetOpArgsActivForward(fusionArgs, activOp, &alpha, &beta, activ_alpha, activ_beta, activ_gamma));
|
||||
|
||||
miopenExecuteFusionPlan(args.handle, fusePlanDesc, args.idesc.desc(), input.const_data_ptr(), args.odesc.desc(), output.data_ptr(), fusionArgs);
|
||||
|
||||
// Cleanup
|
||||
miopenDestroyFusionPlan(fusePlanDesc);
|
||||
}
|
||||
|
||||
static at::Tensor self_or_new_memory_format(at::Tensor& self, at::MemoryFormat memory_format) {
|
||||
@ -1809,107 +1855,171 @@ static at::Tensor self_or_new_memory_format(at::Tensor& self, at::MemoryFormat m
|
||||
Tensor miopen_convolution_add_relu(
|
||||
const Tensor& input_t,
|
||||
const Tensor& weight_t,
|
||||
const Tensor& z_t,
|
||||
const Tensor& z,
|
||||
const std::optional<Scalar>& alpha,
|
||||
const std::optional<Tensor>& bias_t,
|
||||
const std::optional<Tensor>& bias,
|
||||
IntArrayRef stride,
|
||||
IntArrayRef padding,
|
||||
IntArrayRef dilation,
|
||||
int64_t groups) {
|
||||
auto memory_format = miopen_conv_suggest_memory_format(input_t, weight_t);
|
||||
const Tensor input = input_t.contiguous(memory_format);
|
||||
const Tensor weight = weight_t.contiguous(memory_format);
|
||||
Tensor z = z_t;
|
||||
if (z.suggest_memory_format() != memory_format) {
|
||||
z = z.to(memory_format);
|
||||
}
|
||||
z = z.contiguous(memory_format);
|
||||
|
||||
// FuseFrozenConvAddRelu performs some tensor shape checking
|
||||
Tensor output_t = at::detail::empty_cuda(
|
||||
conv_output_size(
|
||||
input.sizes(), weight.sizes(), padding, stride, dilation),
|
||||
input.options().memory_format(memory_format));
|
||||
if (output_t.numel() == 0) {
|
||||
return output_t;
|
||||
}
|
||||
// MIOpen does not support fusion of add, the alpha2 * z step of the below cuDNN function:
|
||||
// y = act ( alpha1 * conv(x) + alpha2 * z + bias )
|
||||
|
||||
auto memory_format = miopen_conv_suggest_memory_format(input_t, weight_t);
|
||||
|
||||
auto& ctx = at::globalContext();
|
||||
bool benchmark = ctx.benchmarkCuDNN();
|
||||
auto _alpha = alpha.has_value() ? alpha.value().to<float>() : 1.0;
|
||||
auto _bias = bias_t.has_value()
|
||||
? bias_t.value()
|
||||
: at::zeros(
|
||||
{output_t.size(1)},
|
||||
optTypeMetaToScalarType(output_t.options().dtype_opt()),
|
||||
output_t.options().layout_opt(),
|
||||
output_t.options().device_opt(),
|
||||
output_t.options().pinned_memory_opt());
|
||||
|
||||
raw_miopen_convolution_add_relu_out(
|
||||
output_t,
|
||||
TensorArg input { input_t, "input", 1 },
|
||||
weight { weight_t, "weight", 2 };
|
||||
|
||||
Tensor output_t = at::detail::empty_cuda(
|
||||
conv_output_size(
|
||||
input_t.sizes(), weight_t.sizes(), padding, stride, dilation),
|
||||
input_t.options().memory_format(memory_format));
|
||||
if (output_t.numel() == 0){
|
||||
return output_t;
|
||||
}
|
||||
// Avoid ambiguity of "output" when this is being used as backwards
|
||||
TensorArg output{output_t, "result", 0};
|
||||
miopen_convolution_forward_out(
|
||||
output,
|
||||
"miopen_convolution_add_relu",
|
||||
input,
|
||||
weight,
|
||||
z,
|
||||
_alpha,
|
||||
_bias,
|
||||
stride,
|
||||
padding,
|
||||
stride,
|
||||
dilation,
|
||||
groups,
|
||||
benchmark,
|
||||
true); // deterministic
|
||||
false // deterministic
|
||||
);
|
||||
|
||||
return output_t;
|
||||
auto contig_output_t = self_or_new_memory_format(output_t, memory_format);
|
||||
|
||||
if (!output_t.is_same(contig_output_t)) {
|
||||
contig_output_t.copy_(output_t);
|
||||
}
|
||||
|
||||
auto _alpha = alpha.has_value() ? alpha.value().to<float>() : 1.0;
|
||||
auto _bias = bias.has_value()
|
||||
? bias.value()
|
||||
: at::zeros(
|
||||
{contig_output_t.size(1)},
|
||||
optTypeMetaToScalarType(contig_output_t.options().dtype_opt()),
|
||||
contig_output_t.options().layout_opt(),
|
||||
contig_output_t.options().device_opt(),
|
||||
contig_output_t.options().pinned_memory_opt());
|
||||
|
||||
at::Tensor alpha_mul_z_add_bias = at::native::reshape_bias(input_t.dim(), _bias).add(z, _alpha);
|
||||
contig_output_t.add_(alpha_mul_z_add_bias);
|
||||
contig_output_t.relu_();
|
||||
|
||||
return contig_output_t;
|
||||
}
|
||||
|
||||
Tensor miopen_convolution_relu(
|
||||
const Tensor& input_t,
|
||||
const Tensor& weight_t,
|
||||
const std::optional<Tensor>& bias_t,
|
||||
const std::optional<Tensor>& bias,
|
||||
IntArrayRef stride,
|
||||
IntArrayRef padding,
|
||||
IntArrayRef dilation,
|
||||
int64_t groups) {
|
||||
auto memory_format = miopen_conv_suggest_memory_format(input_t, weight_t);
|
||||
const Tensor input = input_t.contiguous(memory_format);
|
||||
const Tensor weight = weight_t.contiguous(memory_format);
|
||||
|
||||
// FuseFrozenConvAddRelu performs some tensor shape checking
|
||||
Tensor output_t = at::detail::empty_cuda(
|
||||
conv_output_size(
|
||||
input.sizes(), weight.sizes(), padding, stride, dilation),
|
||||
input.options().memory_format(memory_format));
|
||||
if (output_t.numel() == 0) {
|
||||
return output_t;
|
||||
}
|
||||
|
||||
auto& ctx = at::globalContext();
|
||||
bool benchmark = ctx.benchmarkCuDNN();
|
||||
auto _bias = bias_t.has_value()
|
||||
? bias_t.value()
|
||||
: at::zeros(
|
||||
{output_t.size(1)},
|
||||
optTypeMetaToScalarType(output_t.options().dtype_opt()),
|
||||
output_t.options().layout_opt(),
|
||||
output_t.options().device_opt(),
|
||||
output_t.options().pinned_memory_opt());
|
||||
|
||||
raw_miopen_convolution_add_relu_out(
|
||||
output_t,
|
||||
input,
|
||||
weight,
|
||||
output_t, // use output_t as z to satisfy MIOpen API
|
||||
0, // alpha
|
||||
_bias,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
groups,
|
||||
benchmark, // benchmark
|
||||
true); // deterministic
|
||||
// MIOpen currently only supports MemoryFormat::Contiguous and fp32 and 2d
|
||||
if (input_t.suggest_memory_format() == at::MemoryFormat::Contiguous
|
||||
&& input_t.scalar_type() == at::kFloat
|
||||
&& input_t.ndimension() == 4) {
|
||||
|
||||
return output_t;
|
||||
// FuseFrozenConvAddRelu performs some tensor shape checking
|
||||
Tensor output_t = at::detail::empty_cuda(
|
||||
conv_output_size(
|
||||
input_t.sizes(), weight_t.sizes(), padding, stride, dilation),
|
||||
input_t.options().memory_format(input_t.suggest_memory_format()));
|
||||
if (output_t.numel() == 0) {
|
||||
return output_t;
|
||||
}
|
||||
|
||||
auto _bias = bias.has_value()
|
||||
? bias.value()
|
||||
: at::zeros(
|
||||
{output_t.size(1)},
|
||||
optTypeMetaToScalarType(output_t.options().dtype_opt()),
|
||||
output_t.options().layout_opt(),
|
||||
output_t.options().device_opt(),
|
||||
output_t.options().pinned_memory_opt());
|
||||
|
||||
raw_miopen_convolution_relu_out(
|
||||
output_t,
|
||||
input_t,
|
||||
weight_t,
|
||||
_bias,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
groups,
|
||||
benchmark, // benchmark
|
||||
false // deterministic
|
||||
);
|
||||
|
||||
return output_t;
|
||||
}
|
||||
else {
|
||||
// fallback
|
||||
|
||||
auto memory_format = miopen_conv_suggest_memory_format(input_t, weight_t);
|
||||
|
||||
TensorArg input { input_t, "input", 1 },
|
||||
weight { weight_t, "weight", 2 };
|
||||
|
||||
Tensor output_t = at::detail::empty_cuda(
|
||||
conv_output_size(
|
||||
input_t.sizes(), weight_t.sizes(), padding, stride, dilation),
|
||||
input->options().memory_format(memory_format));
|
||||
if (output_t.numel() == 0){
|
||||
return output_t;
|
||||
}
|
||||
// Avoid ambiguity of "output" when this is being used as backwards
|
||||
TensorArg output{output_t, "result", 0};
|
||||
miopen_convolution_forward_out(
|
||||
output,
|
||||
"miopen_convolution_relu",
|
||||
input,
|
||||
weight,
|
||||
padding,
|
||||
stride,
|
||||
dilation,
|
||||
groups,
|
||||
benchmark,
|
||||
false // deterministic
|
||||
);
|
||||
|
||||
auto contig_output_t = self_or_new_memory_format(output_t, memory_format);
|
||||
|
||||
if (!output_t.is_same(contig_output_t)) {
|
||||
contig_output_t.copy_(output_t);
|
||||
}
|
||||
|
||||
auto _bias = bias.has_value()
|
||||
? bias.value()
|
||||
: at::zeros(
|
||||
{contig_output_t.size(1)},
|
||||
optTypeMetaToScalarType(contig_output_t.options().dtype_opt()),
|
||||
contig_output_t.options().layout_opt(),
|
||||
contig_output_t.options().device_opt(),
|
||||
contig_output_t.options().pinned_memory_opt());
|
||||
|
||||
at::Tensor reshaped_bias = at::native::reshape_bias(input_t.dim(), _bias);
|
||||
contig_output_t.add_(reshaped_bias);
|
||||
contig_output_t.relu_();
|
||||
|
||||
return contig_output_t;
|
||||
}
|
||||
}
|
||||
|
||||
REGISTER_CUDA_DISPATCH(miopen_convolution_backward_stub, &miopen_convolution_backward)
|
||||
|
||||
@ -568,7 +568,7 @@ Placeholder::Placeholder(MPSGraphTensor* mpsGraphTensor,
|
||||
MPSShape* mpsStrides = getMPSShape(_tensor.strides());
|
||||
check_mps_shape(mpsShape);
|
||||
|
||||
auto storage_numel = src.storage().nbytes() / src.element_size() - src.storage_offset();
|
||||
auto storage_numel = src.storage().nbytes() / src.element_size();
|
||||
TORCH_CHECK(storage_numel <= std::numeric_limits<int32_t>::max(),
|
||||
"MPSGaph does not support tensor dims larger than INT_MAX");
|
||||
MPSNDArrayDescriptor* srcTensorDesc = [MPSNDArrayDescriptor descriptorWithDataType:dataType
|
||||
|
||||
@ -4372,7 +4372,7 @@
|
||||
variants: function, method
|
||||
dispatch:
|
||||
CPU: narrow_copy_dense_cpu
|
||||
SparseCPU, SparseCUDA, SparseMPS: narrow_copy_sparse
|
||||
SparseCPU, SparseCUDA: narrow_copy_sparse
|
||||
CompositeExplicitAutogradNonFunctional: narrow_copy_dense_symint
|
||||
tags: view_copy
|
||||
|
||||
@ -6660,7 +6660,7 @@
|
||||
- func: zeros.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: zeros_out
|
||||
SparseCPU, SparseCUDA, SparseMPS, SparseMeta: zeros_sparse_out
|
||||
SparseCPU, SparseCUDA, SparseMeta: zeros_sparse_out
|
||||
|
||||
- func: zeros_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
|
||||
dispatch:
|
||||
@ -10699,7 +10699,6 @@
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: foreach_tensor_div_list_kernel_slow
|
||||
CUDA: foreach_tensor_div_list_kernel_cuda
|
||||
MTIA: foreach_tensor_div_list_kernel_mtia
|
||||
|
||||
- func: _foreach_div_.List(Tensor(a!)[] self, Tensor[] other) -> ()
|
||||
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
|
||||
@ -10707,7 +10706,6 @@
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: foreach_tensor_div_list_kernel_slow_
|
||||
CUDA: foreach_tensor_div_list_kernel_cuda_
|
||||
MTIA: foreach_tensor_div_list_kernel_mtia_
|
||||
autogen: _foreach_div.List_out
|
||||
|
||||
- func: _foreach_div.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]
|
||||
@ -10731,7 +10729,6 @@
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: foreach_tensor_div_tensor_kernel_slow
|
||||
CUDA: foreach_tensor_div_tensor_kernel_cuda
|
||||
MTIA: foreach_tensor_div_tensor_kernel_mtia
|
||||
|
||||
- func: _foreach_div_.Tensor(Tensor(a!)[] self, Tensor other) -> ()
|
||||
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
|
||||
@ -10739,7 +10736,6 @@
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: foreach_tensor_div_tensor_kernel_slow_
|
||||
CUDA: foreach_tensor_div_tensor_kernel_cuda_
|
||||
MTIA: foreach_tensor_div_tensor_kernel_mtia_
|
||||
autogen: _foreach_div.Tensor_out
|
||||
|
||||
- func: _foreach_clamp_max.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
||||
#include <ATen/core/Tensor.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/ceil_div.h>
|
||||
#include <ATen/native/cuda/Loops.cuh>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
@ -22,11 +21,10 @@
|
||||
namespace at::native {
|
||||
|
||||
namespace {
|
||||
template <typename T>
|
||||
__global__ void ChooseQuantizationParamsKernelImpl(
|
||||
const int64_t* fake_quant_on,
|
||||
const T* x_min,
|
||||
const T* x_max,
|
||||
const float* x_min,
|
||||
const float* x_max,
|
||||
int32_t qmin,
|
||||
int32_t qmax,
|
||||
int size,
|
||||
@ -95,44 +93,34 @@ __global__ void ChooseQuantizationParamsKernelImpl(
|
||||
}
|
||||
}
|
||||
|
||||
__device__ inline bool isinf_device(float v) {
|
||||
return ::isinf(v);
|
||||
}
|
||||
__device__ inline bool isinf_device(c10::BFloat16 v) {
|
||||
return ::isinf(static_cast<float>(v));
|
||||
}
|
||||
|
||||
// CUDA kernel to compute Moving Average Min/Max of the tensor.
|
||||
// It uses the running_min and running_max along with averaging const, c.
|
||||
// The formula used to compute the new min/max is as follows
|
||||
//
|
||||
// running_min = (1 - c) * running_min + c * x_min, if running_min != inf
|
||||
// running_min = x_min, if running_min == inf
|
||||
template <typename T>
|
||||
__global__ void MovingAverageMinMax(
|
||||
const int64_t* observer_on,
|
||||
const T* x_min,
|
||||
const T* x_max,
|
||||
T* running_min,
|
||||
T* running_max,
|
||||
const float* x_min,
|
||||
const float* x_max,
|
||||
float* running_min,
|
||||
float* running_max,
|
||||
const float averaging_const,
|
||||
const int size) {
|
||||
int i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
if (*observer_on == 1) {
|
||||
if (i < size) {
|
||||
T curr_min = x_min[i];
|
||||
T curr_max = x_max[i];
|
||||
float curr_min = x_min[i];
|
||||
float curr_max = x_max[i];
|
||||
|
||||
T averaging_const_t = static_cast<T>(averaging_const);
|
||||
float adjusted_min = ::isinf(running_min[i])
|
||||
? curr_min
|
||||
: (running_min[i]) + averaging_const * (curr_min - (running_min[i]));
|
||||
|
||||
T adjusted_min = isinf_device(running_min[i]) ? curr_min
|
||||
: (running_min[i]) +
|
||||
averaging_const_t * (curr_min - (running_min[i]));
|
||||
|
||||
T adjusted_max = isinf_device(running_max[i]) ? curr_max
|
||||
: (running_max[i]) +
|
||||
averaging_const_t * (curr_max - (running_max[i]));
|
||||
float adjusted_max = ::isinf(running_max[i])
|
||||
? curr_max
|
||||
: (running_max[i]) + averaging_const * (curr_max - (running_max[i]));
|
||||
|
||||
running_min[i] = adjusted_min;
|
||||
running_max[i] = adjusted_max;
|
||||
@ -154,51 +142,40 @@ void _calculate_moving_average(
|
||||
at::Tensor x_min, x_max;
|
||||
|
||||
int64_t* observer_on_data = observer_on.data_ptr<int64_t>();
|
||||
float* running_min_data = running_min.data_ptr<float>();
|
||||
float* running_max_data = running_max.data_ptr<float>();
|
||||
cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
if (per_row_fq) {
|
||||
std::tie(x_min, x_max) = at::aminmax(x, 1);
|
||||
float* x_min_data = x_min.data_ptr<float>();
|
||||
float* x_max_data = x_max.data_ptr<float>();
|
||||
int num_threads = std::min(size, (int64_t)512);
|
||||
const uint64_t num_blocks = ceil_div<uint64_t>(size, num_threads);
|
||||
AT_DISPATCH_FLOATING_TYPES_AND(
|
||||
at::kBFloat16, x.scalar_type(), "aminmax_kernel", [&] {
|
||||
scalar_t* x_min_data = x_min.data_ptr<scalar_t>();
|
||||
scalar_t* x_max_data = x_max.data_ptr<scalar_t>();
|
||||
|
||||
scalar_t* running_min_data = running_min.data_ptr<scalar_t>();
|
||||
scalar_t* running_max_data = running_max.data_ptr<scalar_t>();
|
||||
|
||||
// Moving Average Min/Max observer for activations
|
||||
MovingAverageMinMax<<<num_blocks, num_threads, 0, cuda_stream>>>(
|
||||
observer_on_data,
|
||||
x_min_data,
|
||||
x_max_data,
|
||||
running_min_data,
|
||||
running_max_data,
|
||||
averaging_const,
|
||||
size);
|
||||
});
|
||||
// Moving Average Min/Max observer for activations
|
||||
MovingAverageMinMax<<<num_blocks, num_threads, 0, cuda_stream>>>(
|
||||
observer_on_data,
|
||||
x_min_data,
|
||||
x_max_data,
|
||||
running_min_data,
|
||||
running_max_data,
|
||||
averaging_const,
|
||||
size);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
} else {
|
||||
std::tie(x_min, x_max) = at::aminmax(x);
|
||||
AT_DISPATCH_FLOATING_TYPES_AND(
|
||||
at::kBFloat16, x.scalar_type(), "aminmax_kernel", [&] {
|
||||
scalar_t* x_min_data = x_min.data_ptr<scalar_t>();
|
||||
scalar_t* x_max_data = x_max.data_ptr<scalar_t>();
|
||||
|
||||
scalar_t* running_min_data = running_min.data_ptr<scalar_t>();
|
||||
scalar_t* running_max_data = running_max.data_ptr<scalar_t>();
|
||||
|
||||
// Moving Average Min/Max observer for activations
|
||||
MovingAverageMinMax<<<1, 1, 0, cuda_stream>>>(
|
||||
observer_on_data,
|
||||
x_min_data,
|
||||
x_max_data,
|
||||
running_min_data,
|
||||
running_max_data,
|
||||
averaging_const,
|
||||
1 /*size*/);
|
||||
});
|
||||
float* x_min_data = x_min.data_ptr<float>();
|
||||
float* x_max_data = x_max.data_ptr<float>();
|
||||
// Moving Average Min/Max observer for activations
|
||||
MovingAverageMinMax<<<1, 1, 0, cuda_stream>>>(
|
||||
observer_on_data,
|
||||
x_min_data,
|
||||
x_max_data,
|
||||
running_min_data,
|
||||
running_max_data,
|
||||
averaging_const,
|
||||
1 /*size*/);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
}
|
||||
}
|
||||
@ -221,44 +198,34 @@ void _calc_moving_avg_qparams_helper(
|
||||
cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream();
|
||||
int64_t* fake_quant_on_data = fake_quant_on.data_ptr<int64_t>();
|
||||
if (per_row_fq) {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND(
|
||||
at::kBFloat16, x.scalar_type(), "aminmax_kernel", [&] {
|
||||
scalar_t* running_min_data = running_min.data_ptr<scalar_t>();
|
||||
scalar_t* running_max_data = running_max.data_ptr<scalar_t>();
|
||||
int num_threads = std::min(size, (int64_t)512);
|
||||
const uint64_t num_blocks = ceil_div<uint64_t>(size, num_threads);
|
||||
ChooseQuantizationParamsKernelImpl<<<
|
||||
num_blocks,
|
||||
num_threads,
|
||||
0,
|
||||
cuda_stream>>>(
|
||||
fake_quant_on_data,
|
||||
running_min_data,
|
||||
running_max_data,
|
||||
qmin,
|
||||
qmax,
|
||||
size,
|
||||
symmetric_quant,
|
||||
scale_ptr,
|
||||
zp_ptr);
|
||||
});
|
||||
float* running_min_data = running_min.data_ptr<float>();
|
||||
float* running_max_data = running_max.data_ptr<float>();
|
||||
int num_threads = std::min(size, (int64_t)512);
|
||||
const uint64_t num_blocks = ceil_div<uint64_t>(size, num_threads);
|
||||
ChooseQuantizationParamsKernelImpl<<<num_blocks, num_threads, 0, cuda_stream>>>(
|
||||
fake_quant_on_data,
|
||||
running_min_data,
|
||||
running_max_data,
|
||||
qmin,
|
||||
qmax,
|
||||
size,
|
||||
symmetric_quant,
|
||||
scale_ptr,
|
||||
zp_ptr);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
} else {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND(
|
||||
at::kBFloat16, x.scalar_type(), "aminmax_kernel", [&] {
|
||||
scalar_t* running_min_data = running_min.data_ptr<scalar_t>();
|
||||
scalar_t* running_max_data = running_max.data_ptr<scalar_t>();
|
||||
ChooseQuantizationParamsKernelImpl<<<1, 1, 0, cuda_stream>>>(
|
||||
fake_quant_on_data,
|
||||
running_min_data,
|
||||
running_max_data,
|
||||
qmin,
|
||||
qmax,
|
||||
1, // size
|
||||
symmetric_quant, // preserve_sparsity
|
||||
scale_ptr,
|
||||
zp_ptr);
|
||||
});
|
||||
float* running_min_data = running_min.data_ptr<float>();
|
||||
float* running_max_data = running_max.data_ptr<float>();
|
||||
ChooseQuantizationParamsKernelImpl<<<1, 1, 0, cuda_stream>>>(
|
||||
fake_quant_on_data,
|
||||
running_min_data,
|
||||
running_max_data,
|
||||
qmin,
|
||||
qmax,
|
||||
1, // size
|
||||
symmetric_quant, // preserve_sparsity
|
||||
scale_ptr,
|
||||
zp_ptr);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
}
|
||||
}
|
||||
|
||||
@ -42,7 +42,7 @@ TEST(MPSObjCInterfaceTest, MPSCustomKernel) {
|
||||
id<MTLLibrary> customKernelLibrary = [device newLibraryWithSource: [NSString stringWithUTF8String:CUSTOM_KERNEL]
|
||||
options: nil
|
||||
error: &error];
|
||||
TORCH_CHECK(customKernelLibrary, "Failed to create custom kernel library, error: ", error.localizedDescription.UTF8String);
|
||||
TORCH_CHECK(customKernelLibrary, "Failed to to create custom kernel library, error: ", error.localizedDescription.UTF8String);
|
||||
|
||||
id<MTLFunction> customFunction = [customKernelLibrary newFunctionWithName: @"add_arrays"];
|
||||
TORCH_CHECK(customFunction, "Failed to create function state object for the kernel");
|
||||
|
||||
@ -76,23 +76,4 @@ int32_t getGlobalIdxFromDevice(DeviceIndex device) {
|
||||
return device_global_idxs[device];
|
||||
}
|
||||
|
||||
// Check if a device can access the memory of a peer device directly.
|
||||
bool canDeviceAccessPeer(DeviceIndex device, DeviceIndex peer) {
|
||||
if (device == -1) {
|
||||
device = c10::xpu::current_device();
|
||||
}
|
||||
if (peer == -1) {
|
||||
peer = c10::xpu::current_device();
|
||||
}
|
||||
check_device_index(device);
|
||||
check_device_index(peer);
|
||||
// A device can always access itself
|
||||
if (device == peer) {
|
||||
return true;
|
||||
}
|
||||
return c10::xpu::get_raw_device(device).ext_oneapi_can_access_peer(
|
||||
c10::xpu::get_raw_device(peer),
|
||||
sycl::ext::oneapi::peer_access::access_supported);
|
||||
}
|
||||
|
||||
} // namespace at::xpu
|
||||
|
||||
@ -17,6 +17,4 @@ TORCH_XPU_API DeviceProp* getDeviceProperties(DeviceIndex device);
|
||||
|
||||
TORCH_XPU_API int32_t getGlobalIdxFromDevice(DeviceIndex device);
|
||||
|
||||
TORCH_XPU_API bool canDeviceAccessPeer(DeviceIndex device, DeviceIndex peer);
|
||||
|
||||
} // namespace at::xpu
|
||||
|
||||
@ -171,23 +171,3 @@ XLNetLMHeadModel,pass,0
|
||||
|
||||
|
||||
YituTechConvBert,pass,0
|
||||
|
||||
|
||||
|
||||
meta-llama/Llama-3.2-1B,pass,5
|
||||
|
||||
|
||||
|
||||
google/gemma-2-2b,pass,5
|
||||
|
||||
|
||||
|
||||
google/gemma-3-4b-it,pass_due_to_skip,0
|
||||
|
||||
|
||||
|
||||
openai/whisper-tiny,pass,6
|
||||
|
||||
|
||||
|
||||
Qwen/Qwen3-0.6B,pass,5
|
||||
|
||||
|
@ -171,23 +171,3 @@ XLNetLMHeadModel,pass,5
|
||||
|
||||
|
||||
YituTechConvBert,pass,5
|
||||
|
||||
|
||||
|
||||
meta-llama/Llama-3.2-1B,eager_fail_to_run,0
|
||||
|
||||
|
||||
|
||||
google/gemma-2-2b,eager_fail_to_run,0
|
||||
|
||||
|
||||
|
||||
google/gemma-3-4b-it,eager_fail_to_run,0
|
||||
|
||||
|
||||
|
||||
openai/whisper-tiny,eager_fail_to_run,0
|
||||
|
||||
|
||||
|
||||
Qwen/Qwen3-0.6B,eager_fail_to_run,0
|
||||
|
||||
|
@ -167,23 +167,3 @@ XLNetLMHeadModel,pass,0
|
||||
|
||||
|
||||
YituTechConvBert,pass,0
|
||||
|
||||
|
||||
|
||||
meta-llama/Llama-3.2-1B,fail_accuracy,0
|
||||
|
||||
|
||||
|
||||
google/gemma-2-2b,fail_accuracy,0
|
||||
|
||||
|
||||
|
||||
google/gemma-3-4b-it,fail_accuracy,0
|
||||
|
||||
|
||||
|
||||
openai/whisper-tiny,fail_to_run,0
|
||||
|
||||
|
||||
|
||||
Qwen/Qwen3-0.6B,fail_accuracy,0
|
||||
|
||||
|
@ -171,23 +171,3 @@ XLNetLMHeadModel,pass,0
|
||||
|
||||
|
||||
YituTechConvBert,pass,0
|
||||
|
||||
|
||||
|
||||
meta-llama/Llama-3.2-1B,pass,5
|
||||
|
||||
|
||||
|
||||
google/gemma-2-2b,pass,5
|
||||
|
||||
|
||||
|
||||
google/gemma-3-4b-it,pass_due_to_skip,0
|
||||
|
||||
|
||||
|
||||
openai/whisper-tiny,pass,6
|
||||
|
||||
|
||||
|
||||
Qwen/Qwen3-0.6B,pass,5
|
||||
|
||||
|
@ -171,23 +171,3 @@ XLNetLMHeadModel,pass,5
|
||||
|
||||
|
||||
YituTechConvBert,pass,5
|
||||
|
||||
|
||||
|
||||
meta-llama/Llama-3.2-1B,eager_fail_to_run,0
|
||||
|
||||
|
||||
|
||||
google/gemma-2-2b,eager_fail_to_run,0
|
||||
|
||||
|
||||
|
||||
google/gemma-3-4b-it,eager_fail_to_run,0
|
||||
|
||||
|
||||
|
||||
openai/whisper-tiny,eager_fail_to_run,0
|
||||
|
||||
|
||||
|
||||
Qwen/Qwen3-0.6B,eager_fail_to_run,0
|
||||
|
||||
|
@ -205,7 +205,7 @@ llama,pass,0
|
||||
|
||||
|
||||
|
||||
llama_v2_7b_16h,pass_due_to_skip,0
|
||||
llama_v2_7b_16h,model_fail_to_load,0
|
||||
|
||||
|
||||
|
||||
|
||||
|
@ -171,23 +171,3 @@ XLNetLMHeadModel,pass,0
|
||||
|
||||
|
||||
YituTechConvBert,pass,0
|
||||
|
||||
|
||||
|
||||
meta-llama/Llama-3.2-1B,pass,5
|
||||
|
||||
|
||||
|
||||
google/gemma-2-2b,pass,5
|
||||
|
||||
|
||||
|
||||
google/gemma-3-4b-it,pass,0
|
||||
|
||||
|
||||
|
||||
openai/whisper-tiny,pass,6
|
||||
|
||||
|
||||
|
||||
Qwen/Qwen3-0.6B,pass,5
|
||||
|
||||
|
@ -171,23 +171,3 @@ XLNetLMHeadModel,pass,5
|
||||
|
||||
|
||||
YituTechConvBert,pass,5
|
||||
|
||||
|
||||
|
||||
meta-llama/Llama-3.2-1B,eager_fail_to_run,0
|
||||
|
||||
|
||||
|
||||
google/gemma-2-2b,eager_fail_to_run,0
|
||||
|
||||
|
||||
|
||||
google/gemma-3-4b-it,eager_fail_to_run,0
|
||||
|
||||
|
||||
|
||||
openai/whisper-tiny,eager_fail_to_run,0
|
||||
|
||||
|
||||
|
||||
Qwen/Qwen3-0.6B,eager_fail_to_run,0
|
||||
|
||||
|
@ -171,23 +171,3 @@ XLNetLMHeadModel,pass,0
|
||||
|
||||
|
||||
YituTechConvBert,pass,0
|
||||
|
||||
|
||||
|
||||
meta-llama/Llama-3.2-1B,pass,5
|
||||
|
||||
|
||||
|
||||
google/gemma-2-2b,pass,5
|
||||
|
||||
|
||||
|
||||
google/gemma-3-4b-it,pass_due_to_skip,0
|
||||
|
||||
|
||||
|
||||
openai/whisper-tiny,pass,6
|
||||
|
||||
|
||||
|
||||
Qwen/Qwen3-0.6B,pass,5
|
||||
|
||||
|
@ -205,7 +205,7 @@ llama,pass,0
|
||||
|
||||
|
||||
|
||||
llama_v2_7b_16h,pass_due_to_skip,0
|
||||
llama_v2_7b_16h,model_fail_to_load,0
|
||||
|
||||
|
||||
|
||||
|
||||
|
@ -171,23 +171,3 @@ XLNetLMHeadModel,pass,0
|
||||
|
||||
|
||||
YituTechConvBert,pass,0
|
||||
|
||||
|
||||
|
||||
meta-llama/Llama-3.2-1B,pass,5
|
||||
|
||||
|
||||
|
||||
google/gemma-2-2b,pass,5
|
||||
|
||||
|
||||
|
||||
google/gemma-3-4b-it,pass_due_to_skip,0
|
||||
|
||||
|
||||
|
||||
openai/whisper-tiny,pass,6
|
||||
|
||||
|
||||
|
||||
Qwen/Qwen3-0.6B,pass,5
|
||||
|
||||
|
@ -171,23 +171,3 @@ XLNetLMHeadModel,pass,5
|
||||
|
||||
|
||||
YituTechConvBert,pass,5
|
||||
|
||||
|
||||
|
||||
meta-llama/Llama-3.2-1B,eager_fail_to_run,0
|
||||
|
||||
|
||||
|
||||
google/gemma-2-2b,eager_fail_to_run,0
|
||||
|
||||
|
||||
|
||||
google/gemma-3-4b-it,eager_fail_to_run,0
|
||||
|
||||
|
||||
|
||||
openai/whisper-tiny,eager_fail_to_run,0
|
||||
|
||||
|
||||
|
||||
Qwen/Qwen3-0.6B,eager_fail_to_run,0
|
||||
|
||||
|
@ -61,22 +61,6 @@ struct C10_API Storage {
|
||||
allocator,
|
||||
resizable)) {}
|
||||
|
||||
// Creates storage with pre-allocated memory buffer. Allocator is given for
|
||||
// potential future reallocations, however it can be nullptr if the storage
|
||||
// is non-resizable
|
||||
Storage(
|
||||
use_byte_size_t /*use_byte_size*/,
|
||||
SymInt size_bytes,
|
||||
at::DataPtr data_ptr,
|
||||
at::Allocator* allocator = nullptr,
|
||||
bool resizable = false)
|
||||
: storage_impl_(c10::make_intrusive<StorageImpl>(
|
||||
StorageImpl::use_byte_size_t(),
|
||||
std::move(size_bytes),
|
||||
std::move(data_ptr),
|
||||
allocator,
|
||||
resizable)) {}
|
||||
|
||||
protected:
|
||||
explicit Storage(unsafe_borrow_t, const Storage& rhs)
|
||||
: storage_impl_(c10::intrusive_ptr<c10::StorageImpl>::reclaim(
|
||||
|
||||
@ -3269,7 +3269,7 @@ class C10_TensorImpl_Size_Check_Dummy_Class : private TensorImpl {
|
||||
is_le<sizeof(autograd_meta_), 16, FieldNameEnum::autograd_meta_>();
|
||||
is_le<sizeof(extra_meta_), 16, FieldNameEnum::extra_meta_>();
|
||||
are_equal<sizeof(version_counter_), 8, FieldNameEnum::version_counter_>();
|
||||
are_equal<sizeof(pyobj_slot_), 8, FieldNameEnum::pyobj_slot_>();
|
||||
are_equal<sizeof(pyobj_slot_), 16, FieldNameEnum::pyobj_slot_>();
|
||||
are_equal<sizeof(sizes_and_strides_), 88, FieldNameEnum::sizes_and_strides_>();
|
||||
are_equal<sizeof(storage_offset_), 8, FieldNameEnum::storage_offset_>();
|
||||
are_equal<sizeof(numel_), 8, FieldNameEnum::numel_>();
|
||||
|
||||
@ -13,10 +13,11 @@ struct C10_API PyInterpreterHooksInterface {
|
||||
|
||||
// Get the PyInterpreter instance
|
||||
// Stub implementation throws error when Python is not available
|
||||
// We return nullptr rather than throwing an error since there are bits of c10
|
||||
// that expect an empty PyObjectSlot when python is not available.
|
||||
virtual PyInterpreter* getPyInterpreter() const {
|
||||
return nullptr;
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"PyTorch was compiled without Python support. "
|
||||
"Cannot access Python interpreter from C++.");
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
|
||||
namespace c10::impl {
|
||||
|
||||
PyObjectSlot::PyObjectSlot() : pyobj_(nullptr) {}
|
||||
PyObjectSlot::PyObjectSlot() : pyobj_interpreter_(nullptr), pyobj_(nullptr) {}
|
||||
|
||||
PyObjectSlot::~PyObjectSlot() {
|
||||
maybe_destroy_pyobj();
|
||||
@ -10,9 +10,9 @@ PyObjectSlot::~PyObjectSlot() {
|
||||
|
||||
void PyObjectSlot::maybe_destroy_pyobj() {
|
||||
if (owns_pyobj()) {
|
||||
TORCH_INTERNAL_ASSERT(getGlobalPyInterpreter() != nullptr);
|
||||
TORCH_INTERNAL_ASSERT(pyobj_interpreter_ != nullptr);
|
||||
TORCH_INTERNAL_ASSERT(pyobj_ != nullptr);
|
||||
(*getGlobalPyInterpreter())
|
||||
(*pyobj_interpreter_.load(std::memory_order_acquire))
|
||||
->decref(_unchecked_untagged_pyobj(), /*has_pyobj_slot*/ true);
|
||||
// NB: this destructor can only be entered when there are no
|
||||
// references to this C++ object (obviously), NOR any references
|
||||
@ -25,7 +25,7 @@ void PyObjectSlot::maybe_destroy_pyobj() {
|
||||
}
|
||||
|
||||
PyInterpreter* PyObjectSlot::pyobj_interpreter() {
|
||||
return getGlobalPyInterpreter();
|
||||
return pyobj_interpreter_.load(std::memory_order_acquire);
|
||||
}
|
||||
|
||||
PyObject* PyObjectSlot::_unchecked_untagged_pyobj() const {
|
||||
@ -35,7 +35,7 @@ PyObject* PyObjectSlot::_unchecked_untagged_pyobj() const {
|
||||
}
|
||||
|
||||
PyInterpreter& PyObjectSlot::load_pyobj_interpreter() const {
|
||||
auto interpreter = getGlobalPyInterpreter();
|
||||
auto interpreter = pyobj_interpreter_.load(std::memory_order_acquire);
|
||||
if (interpreter) {
|
||||
return *interpreter;
|
||||
}
|
||||
|
||||
@ -6,17 +6,10 @@
|
||||
#include <c10/util/python_stub.h>
|
||||
#include <optional>
|
||||
|
||||
#include <atomic>
|
||||
|
||||
namespace c10::impl {
|
||||
|
||||
// Function pointer type for getting the global interpreter
|
||||
using GetPyInterpreterFn = PyInterpreter* (*)();
|
||||
|
||||
// Global function pointer (set by csrc initialization)
|
||||
C10_API extern GetPyInterpreterFn g_get_pyinterpreter_fn;
|
||||
|
||||
// Helper function to get the global interpreter
|
||||
C10_API PyInterpreter* getGlobalPyInterpreter();
|
||||
|
||||
struct C10_API PyObjectSlot {
|
||||
public:
|
||||
PyObjectSlot();
|
||||
@ -33,6 +26,8 @@ struct C10_API PyObjectSlot {
|
||||
// NB: THIS FUNCTION CAN RAISE AN EXCEPTION. Make sure to clean up after
|
||||
// PyObject if necessary!
|
||||
void init_pyobj(PyObject* pyobj) {
|
||||
pyobj_interpreter_.store(
|
||||
getGlobalPyInterpreter(), std::memory_order_relaxed);
|
||||
pyobj_ = pyobj;
|
||||
}
|
||||
|
||||
@ -60,15 +55,18 @@ struct C10_API PyObjectSlot {
|
||||
|
||||
// @todo alban: I'm not too sure what's going on here, we can probably delete
|
||||
// it but it's worthwhile making sure
|
||||
std::optional<PyObject*> check_pyobj() const {
|
||||
impl::PyInterpreter* interpreter = getGlobalPyInterpreter();
|
||||
if (interpreter == nullptr || pyobj_ == nullptr) {
|
||||
std::optional<PyObject*> check_pyobj(bool ignore_hermetic_tls = false) const {
|
||||
impl::PyInterpreter* interpreter =
|
||||
pyobj_interpreter_.load(std::memory_order_acquire);
|
||||
if (interpreter == nullptr) {
|
||||
return std::nullopt;
|
||||
}
|
||||
if (c10::impl::HermeticPyObjectTLS::get_state()) {
|
||||
|
||||
if (!ignore_hermetic_tls && c10::impl::HermeticPyObjectTLS::get_state()) {
|
||||
return std::nullopt;
|
||||
} else {
|
||||
return _unchecked_untagged_pyobj();
|
||||
}
|
||||
return _unchecked_untagged_pyobj();
|
||||
}
|
||||
|
||||
PyInterpreter& load_pyobj_interpreter() const;
|
||||
@ -78,6 +76,30 @@ struct C10_API PyObjectSlot {
|
||||
void set_owns_pyobj(bool b);
|
||||
|
||||
private:
|
||||
// This field contains the interpreter tag for this object. See
|
||||
// Note [Python interpreter tag] for general context
|
||||
//
|
||||
// Note [Memory ordering on Python interpreter tag]
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
// What memory_order do we need when accessing this atomic? We don't
|
||||
// need a single total modification order (as provided by
|
||||
// memory_order_seq_cst) as pyobj_interpreter_ is monotonic: it can only
|
||||
// transition from -1 to some positive integer and never changes afterwards.
|
||||
// Because there is only one modification, it trivially already has a total
|
||||
// modification order (e.g., we don't need fences or locked instructions on
|
||||
// x86)
|
||||
//
|
||||
// In fact, one could make a reasonable argument that relaxed reads are OK,
|
||||
// due to the presence of external locking (GIL) to ensure that interactions
|
||||
// with other data structures are still correctly synchronized, so that
|
||||
// we fall in the "Single-Location Data Structures" case as described in
|
||||
// http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p2055r0.pdf
|
||||
// However, on x86, it doesn't matter if I use acquire or relaxed on the load
|
||||
// as I get the same assembly in both cases. So I just use the more
|
||||
// conservative acquire (which will impede compiler optimizations but I don't
|
||||
// care)
|
||||
std::atomic<PyInterpreter*> pyobj_interpreter_;
|
||||
|
||||
// This field contains a reference to a PyObject representing this Tensor.
|
||||
// If pyobj is nullptr, when we transfer Tensor to Python, we allocate a new
|
||||
// PyObject for it and set this field. This field does not have to be
|
||||
|
||||
@ -52,7 +52,7 @@ struct maybe_bool {
|
||||
template <typename src_t>
|
||||
struct maybe_bool<true, src_t> {
|
||||
C10_HOST_DEVICE static inline decltype(auto) apply(src_t src) {
|
||||
// Don't use bool operator so as to also compile for ComplexHalf.
|
||||
// Don't use bool operator so as to to also compile for ComplexHalf.
|
||||
return src.real() || src.imag();
|
||||
}
|
||||
};
|
||||
|
||||
@ -1,243 +0,0 @@
|
||||
#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE) && defined(CAFFE2_PERF_WITH_SVE128)
|
||||
#include <arm_neon.h>
|
||||
#include <arm_neon_sve_bridge.h>
|
||||
#include <arm_sve.h>
|
||||
|
||||
#include "c10/macros/Macros.h"
|
||||
|
||||
// Log and exp approximations inspired from ACL implementation
|
||||
|
||||
inline float32x4_t vtaylor_polyq_for_log_f32(float32x4_t x) {
|
||||
const float32x4_t log_tab_1 = vdupq_n_f32(-2.29561495781f);
|
||||
const float32x4_t log_tab_2 = vdupq_n_f32(-2.47071170807f);
|
||||
const float32x4_t log_tab_3 = vdupq_n_f32(-5.68692588806f);
|
||||
const float32x4_t log_tab_4 = vdupq_n_f32(-0.165253549814f);
|
||||
const float32x4_t log_tab_5 = vdupq_n_f32(5.17591238022f);
|
||||
const float32x4_t log_tab_6 = vdupq_n_f32(0.844007015228f);
|
||||
const float32x4_t log_tab_7 = vdupq_n_f32(4.58445882797f);
|
||||
const float32x4_t log_tab_8 = vdupq_n_f32(0.0141278216615f);
|
||||
|
||||
float32x4_t A = vmlaq_f32(log_tab_1, log_tab_5, x);
|
||||
float32x4_t B = vmlaq_f32(log_tab_3, log_tab_7, x);
|
||||
float32x4_t C = vmlaq_f32(log_tab_2, log_tab_6, x);
|
||||
float32x4_t x2 = vmulq_f32(x, x);
|
||||
float32x4_t D = svget_neonq(svmad_f32_x(
|
||||
svptrue_b8(),
|
||||
svset_neonq(svundef_f32(), x),
|
||||
svset_neonq(svundef_f32(), log_tab_8),
|
||||
svset_neonq(svundef_f32(), log_tab_4)));
|
||||
float32x4_t x4 = vmulq_f32(x2, x2);
|
||||
float32x4_t res = vmlaq_f32(vmlaq_f32(A, B, x2), vmlaq_f32(C, D, x2), x4);
|
||||
return res;
|
||||
}
|
||||
|
||||
inline float32x4_t vlogq_f32(float32x4_t x) {
|
||||
const float32x4_t CONST_LN2 = vdupq_n_f32(0.6931471805f); // ln(2)
|
||||
|
||||
// Extract exponent
|
||||
int32x4_t m = svget_neonq(svsub_n_s32_x(
|
||||
svptrue_b8(),
|
||||
svset_neonq(
|
||||
svundef_s32(),
|
||||
vreinterpretq_s32_u32(vshrq_n_u32(vreinterpretq_u32_f32(x), 23))),
|
||||
127));
|
||||
float32x4_t val = vreinterpretq_f32_s32(
|
||||
vsubq_s32(vreinterpretq_s32_f32(x), vshlq_n_s32(m, 23)));
|
||||
|
||||
// Polynomial Approximation
|
||||
float32x4_t poly = vtaylor_polyq_for_log_f32(val);
|
||||
|
||||
// Reconstruct
|
||||
poly = vmlaq_f32(poly, vcvtq_f32_s32(m), CONST_LN2);
|
||||
|
||||
return poly;
|
||||
}
|
||||
|
||||
inline float32x4_t vexpq_f32(float32x4_t x) {
|
||||
const auto c1 = vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(0x3f7ffff6)));
|
||||
const auto c2 = vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(0x3efffedb)));
|
||||
const auto c3 = vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(0x3e2aaf33)));
|
||||
const auto c4 = vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(0x3d2b9f17)));
|
||||
const auto c5 = vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(0x3c072010)));
|
||||
|
||||
const auto shift = vreinterpretq_f32_u32(
|
||||
svget_neonq(svdup_n_u32(0x4b00007f))); // 2^23 + 127 = 0x1.0000fep23f
|
||||
const auto inv_ln2 = vreinterpretq_f32_u32(
|
||||
svget_neonq(svdup_n_u32(0x3fb8aa3b))); // 1 / ln(2) = 0x1.715476p+0f
|
||||
const auto neg_ln2_hi = vreinterpretq_f32_u32(svget_neonq(
|
||||
svdup_n_u32(0xbf317200))); // -ln(2) from bits -1 to -19: -0x1.62e400p-1f
|
||||
const auto neg_ln2_lo = vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(
|
||||
0xb5bfbe8e))); // -ln(2) from bits -20 to -42: -0x1.7f7d1cp-20f
|
||||
|
||||
const auto inf = svdup_n_f32(std::numeric_limits<float>::infinity());
|
||||
const auto max_input = svdup_n_f32(88.37f); // Approximately ln(2^127.5)
|
||||
const auto zero = svdup_n_f32(0.f);
|
||||
const auto min_input = svdup_n_f32(-86.64f); // Approximately ln(2^-125)
|
||||
|
||||
// Range reduction:
|
||||
// e^x = 2^n * e^r
|
||||
// where:
|
||||
// n = floor(x / ln(2))
|
||||
// r = x - n * ln(2)
|
||||
//
|
||||
// By adding x / ln(2) with 2^23 + 127 (shift):
|
||||
// * As FP32 fraction part only has 23-bits, the addition of 2^23 + 127
|
||||
// forces decimal part
|
||||
// of x / ln(2) out of the result. The integer part of x / ln(2) (i.e. n)
|
||||
// + 127 will occupy the whole fraction part of z in FP32 format.
|
||||
// Subtracting 2^23 + 127 (shift) from z will result in the integer part
|
||||
// of x / ln(2) (i.e. n) because the decimal part has been pushed out and
|
||||
// lost.
|
||||
// * The addition of 127 makes the FP32 fraction part of z ready to be used
|
||||
// as the exponent
|
||||
// in FP32 format. Left shifting z by 23 bits will result in 2^n.
|
||||
const auto z = vfmaq_f32(shift, x, inv_ln2);
|
||||
const auto n = z - shift;
|
||||
const auto scale =
|
||||
vreinterpretq_f32_u32(vreinterpretq_u32_f32(z) << 23); // 2^n
|
||||
|
||||
// The calculation of n * ln(2) is done using 2 steps to achieve accuracy
|
||||
// beyond FP32. This outperforms longer Taylor series (3-4 tabs) both in term
|
||||
// of accuracy and performance.
|
||||
const auto r_hi = vfmaq_f32(x, n, neg_ln2_hi);
|
||||
const auto r = vfmaq_f32(r_hi, n, neg_ln2_lo);
|
||||
|
||||
// Compute the truncated Taylor series of e^r.
|
||||
// poly = scale * (1 + c1 * r + c2 * r^2 + c3 * r^3 + c4 * r^4 + c5 * r^5)
|
||||
const auto r2 = r * r;
|
||||
|
||||
const auto p1 = c1 * r;
|
||||
const auto p23 = vfmaq_f32(c2, c3, r);
|
||||
const auto p45 = vfmaq_f32(c4, c5, r);
|
||||
const auto p2345 = vfmaq_f32(p23, p45, r2);
|
||||
const auto p12345 = vfmaq_f32(p1, p2345, r2);
|
||||
|
||||
auto poly = svset_neonq(svundef_f32(), vfmaq_f32(scale, p12345, scale));
|
||||
|
||||
// Handle underflow and overflow.
|
||||
poly = svsel_f32(
|
||||
svcmplt_f32(svptrue_b8(), svset_neonq(svundef_f32(), x), min_input),
|
||||
zero,
|
||||
poly);
|
||||
poly = svsel_f32(
|
||||
svcmpgt_f32(svptrue_b8(), svset_neonq(svundef_f32(), x), max_input),
|
||||
inf,
|
||||
poly);
|
||||
|
||||
return svget_neonq(poly);
|
||||
}
|
||||
|
||||
// ln(x) = log2(x) * ln(2)
|
||||
// pow(x, n) = exp(n * ln(x))
|
||||
inline float32x4_t compute_batch_box_cox_vec_sve128_float(
|
||||
svfloat32_t lambda1_v,
|
||||
svfloat32_t lambda2_v,
|
||||
svfloat32_t data_v,
|
||||
svfloat32_t k_eps) {
|
||||
// sum_v = lambda2_v + data_v
|
||||
float32x4_t sum_v = vaddq_f32(svget_neonq(data_v), svget_neonq(lambda2_v));
|
||||
|
||||
// test lambda1_v: predNZ == 1 iff lambda1_v != 0
|
||||
svbool_t predNZ = svcmpne_n_f32(svptrue_b8(), lambda1_v, 0.0f);
|
||||
|
||||
// clamp sum_v: sum_v = max(sum_v, k_eps)
|
||||
sum_v = vmaxq_f32(sum_v, svget_neonq(k_eps));
|
||||
|
||||
// lnData = log(sum_v)
|
||||
svfloat32_t lnData = svset_neonq(svundef_f32(), vlogq_f32(sum_v));
|
||||
|
||||
// if any lambda1 != 0, compute pow(sum_v, lambda1) using lnData
|
||||
// pow(sum_v, lambda1) == exp(lambda1 * ln(sum_v))
|
||||
if (C10_LIKELY(svptest_any(predNZ, predNZ))) {
|
||||
// mult = lambda1 * ln(sum_v)
|
||||
float32x4_t mult = vmulq_f32(svget_neonq(lnData), svget_neonq(lambda1_v));
|
||||
|
||||
// lambda1_r = 1 / lambda1
|
||||
svfloat32_t lambda1_r = svdivr_f32_m(predNZ, lambda1_v, svdup_n_f32(1.0f));
|
||||
|
||||
// pow = exp(mult)
|
||||
float32x4_t pow = vexpq_f32(mult);
|
||||
|
||||
// merge results
|
||||
// lnData if lambda1 == 0, (lambda1_r * pow - lambda1_r) if lambda1 != 0
|
||||
lnData = svsel_f32(predNZ, lambda1_r, lnData);
|
||||
lnData =
|
||||
svnmsb_f32_m(predNZ, lnData, svset_neonq(svundef_f32(), pow), lnData);
|
||||
}
|
||||
return svget_neonq(lnData);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void compute_batch_box_cox_vec_sve128(
|
||||
std::size_t N,
|
||||
std::size_t D,
|
||||
const T* data_ptr,
|
||||
const T* __restrict lambda1_ptr,
|
||||
const T* __restrict lambda2_ptr,
|
||||
T* output_ptr);
|
||||
|
||||
template <>
|
||||
void compute_batch_box_cox_vec_sve128(
|
||||
std::size_t N,
|
||||
std::size_t D,
|
||||
const float* data_ptr,
|
||||
const float* __restrict lambda1_ptr,
|
||||
const float* __restrict lambda2_ptr,
|
||||
float* output_ptr) {
|
||||
svfloat32_t k_eps = svdup_n_f32(static_cast<float>(1e-6));
|
||||
|
||||
std::size_t remainder = D % 4;
|
||||
std::size_t loopBound = D - remainder;
|
||||
svbool_t remainderPred = svwhilelt_b32_u64(0, remainder);
|
||||
|
||||
for (; C10_LIKELY(N > 0); --N) {
|
||||
for (std::size_t j = 0; C10_LIKELY(j != loopBound);
|
||||
j += 4, data_ptr += 4, output_ptr += 4) {
|
||||
svfloat32_t lambda1_v =
|
||||
svset_neonq(svundef_f32(), vld1q_f32(lambda1_ptr + j));
|
||||
svfloat32_t lambda2_v =
|
||||
svset_neonq(svundef_f32(), vld1q_f32(lambda2_ptr + j));
|
||||
svfloat32_t data_v = svset_neonq(svundef_f32(), vld1q_f32(data_ptr));
|
||||
float32x4_t result = compute_batch_box_cox_vec_sve128_float(
|
||||
lambda1_v, lambda2_v, data_v, k_eps);
|
||||
vst1q_f32(output_ptr, result);
|
||||
}
|
||||
if (C10_LIKELY(remainder > 0)) {
|
||||
svfloat32_t lambda1_v = svld1_f32(remainderPred, lambda1_ptr + loopBound);
|
||||
svfloat32_t lambda2_v = svld1_f32(remainderPred, lambda2_ptr + loopBound);
|
||||
svfloat32_t data_v = svld1_f32(remainderPred, data_ptr);
|
||||
float32x4_t result = compute_batch_box_cox_vec_sve128_float(
|
||||
lambda1_v, lambda2_v, data_v, k_eps);
|
||||
svst1_f32(remainderPred, output_ptr, svset_neonq(svundef_f32(), result));
|
||||
data_ptr += remainder;
|
||||
output_ptr += remainder;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
namespace caffe2::details {
|
||||
|
||||
template <typename T>
|
||||
void compute_batch_box_cox__sve128(
|
||||
std::size_t N,
|
||||
std::size_t D,
|
||||
const T* self_data,
|
||||
const T* __restrict lambda1_data,
|
||||
const T* __restrict lambda2_data,
|
||||
T* output_data) {
|
||||
compute_batch_box_cox_vec_sve128<T>(
|
||||
N, D, self_data, lambda1_data, lambda2_data, output_data);
|
||||
}
|
||||
|
||||
// Vectorized version specializations for float and double
|
||||
template void compute_batch_box_cox__sve128<float>(
|
||||
std::size_t N,
|
||||
std::size_t D,
|
||||
const float* self_data,
|
||||
const float* __restrict lambda1_data,
|
||||
const float* __restrict lambda2_data,
|
||||
float* output_data);
|
||||
|
||||
} // namespace caffe2::details
|
||||
|
||||
#endif // __aarch64__ && __ARM_FEATURE_SVE && CAFFE2_PERF_WITH_SVE128
|
||||
@ -107,12 +107,6 @@ if(INTERN_BUILD_ATEN_OPS)
|
||||
list(APPEND _file_compile_flags "-gencode;arch=compute_100a,code=sm_100a")
|
||||
endif()
|
||||
endif()
|
||||
# We will need to gate against CUDA version, because sm_103a is available on CUDA 12.9+
|
||||
if("${_arch}" STREQUAL "103a" AND CUDA_VERSION VERSION_GREATER_EQUAL 12.9)
|
||||
if(_existing_arch_flags MATCHES ".*compute_100.*")
|
||||
list(APPEND _file_compile_flags "-gencode;arch=compute_103a,code=sm_103a")
|
||||
endif()
|
||||
endif()
|
||||
if("${_arch}" STREQUAL "120a")
|
||||
if(_existing_arch_flags MATCHES ".*compute_120.*")
|
||||
list(APPEND _file_compile_flags "-gencode;arch=compute_120a,code=sm_120a")
|
||||
@ -126,13 +120,13 @@ if(INTERN_BUILD_ATEN_OPS)
|
||||
|
||||
_BUILD_FOR_ADDITIONAL_ARCHS(
|
||||
"${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/RowwiseScaledMM.cu"
|
||||
"89;90a;100a;103a;120a")
|
||||
"89;90a;100a;120a")
|
||||
_BUILD_FOR_ADDITIONAL_ARCHS(
|
||||
"${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/ScaledGroupMM.cu"
|
||||
"90a")
|
||||
_BUILD_FOR_ADDITIONAL_ARCHS(
|
||||
"${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/GroupMM.cu"
|
||||
"90a;100a;103a")
|
||||
"90a;100a")
|
||||
|
||||
endif()
|
||||
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 563 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 281 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 348 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 747 KiB |
@ -210,6 +210,10 @@ templates_path = [
|
||||
coverage_ignore_functions = [
|
||||
# torch
|
||||
"typename",
|
||||
# torch.cuda
|
||||
"check_error",
|
||||
"cudart",
|
||||
"is_bf16_supported",
|
||||
# torch.cuda._sanitizer
|
||||
"zip_arguments",
|
||||
"zip_by_key",
|
||||
@ -3176,8 +3180,6 @@ coverage_ignore_classes = [
|
||||
"WeakIdKeyDictionary",
|
||||
"WeakIdRef",
|
||||
"WeakTensorKeyDictionary",
|
||||
# torch.utils.debug_mode
|
||||
"DebugMode",
|
||||
]
|
||||
|
||||
# The suffix(es) of source filenames.
|
||||
|
||||
@ -15,7 +15,6 @@
|
||||
|
||||
StreamContext
|
||||
can_device_access_peer
|
||||
check_error
|
||||
current_blas_handle
|
||||
current_device
|
||||
current_stream
|
||||
@ -35,7 +34,6 @@
|
||||
init
|
||||
ipc_collect
|
||||
is_available
|
||||
is_bf16_supported
|
||||
is_initialized
|
||||
is_tf32_supported
|
||||
memory_usage
|
||||
|
||||
@ -22,7 +22,6 @@ The MTIA backend is implemented out of the tree, only interfaces are be defined
|
||||
device_count
|
||||
init
|
||||
is_available
|
||||
is_bf16_supported
|
||||
is_initialized
|
||||
memory_stats
|
||||
get_device_capability
|
||||
|
||||
@ -30,6 +30,9 @@
|
||||
.. autofunction:: create_mask
|
||||
```
|
||||
```{eval-rst}
|
||||
.. autofunction:: create_nested_block_mask
|
||||
```
|
||||
```{eval-rst}
|
||||
.. autofunction:: and_masks
|
||||
```
|
||||
```{eval-rst}
|
||||
|
||||
@ -3,6 +3,12 @@
|
||||
TorchInductor and AOTInductor Provenance Tracking
|
||||
=================================================
|
||||
|
||||
.. warning::
|
||||
This feature is a prototype under active development and there will be
|
||||
breaking change in future releases.
|
||||
The current compatibility of this tool is limited to the latest nightly build of PyTorch.
|
||||
|
||||
|
||||
This section describes how to use the provenance tracking feature for TorchInductor and AOTInductor in ``tlparse``.
|
||||
Provenance tracking helps you visualize the relationships between the input GraphModule to (AOT)Inductor and the optimized code generated. This feature allows you to trace how your original operations are transformed during compilation.
|
||||
|
||||
@ -31,7 +37,7 @@ Follow these steps to enable and use provenance tracking in your PyTorch project
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
TORCH_TRACE=~/my_trace_log_dir INDUCTOR_PROVENANCE=1 python your_program.py
|
||||
TORCH_TRACE=~/my_trace_log_dir TORCH_LOGS="+inductor" TORCH_COMPILE_DEBUG=1 python your_program.py
|
||||
|
||||
This will generate a log file in ``~/my_trace_log_dir``. The log file will be used by tlparse to generate the provenance tracking highlighter.
|
||||
3. Run ``tlparse`` on the log with ``--inductor-provenance`` flag. For example:
|
||||
@ -56,24 +62,6 @@ For a demo, see: https://github.com/pytorch/tlparse/pull/93
|
||||
.. image:: _static/img/inductor_provenance/index.png
|
||||
|
||||
|
||||
Source code corresponding to each Inductor kernel
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
With ``INDUCTOR_PROVENANCE=1``, you can also view the source code corresponding to each Inductor kernel in tlparse. To access it, click the "readable_html" link next to "inductor_provenance_tracking_kernel_stack_traces.json" in the tlparse output.
|
||||
|
||||
.. image:: _static/img/inductor_provenance/index_2.png
|
||||
|
||||
|
||||
Below are some example screenshots. The ``:1`` and ``:467`` suffixes at the end of the kernel names are used to distinguish different calls to the same kernel. We refer to these suffixes as debug handles.
|
||||
|
||||
.. image:: _static/img/inductor_provenance/kernel_source_1.png
|
||||
.. image:: _static/img/inductor_provenance/kernel_source_2.png
|
||||
|
||||
You can also find the debug handle in the comments within the kernel source code.
|
||||
|
||||
.. image:: _static/img/inductor_provenance/kernel_source_3.png
|
||||
|
||||
|
||||
See Also
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
||||
@ -78,7 +78,6 @@ for tracking purposes -->
|
||||
.. py:module:: torch.utils.data.graph
|
||||
.. py:module:: torch.utils.data.graph_settings
|
||||
.. py:module:: torch.utils.data.sampler
|
||||
.. py:module:: torch.utils.debug_mode
|
||||
.. py:module:: torch.utils.dlpack
|
||||
.. py:module:: torch.utils.file_baton
|
||||
.. py:module:: torch.utils.flop_counter
|
||||
|
||||
@ -12,7 +12,6 @@
|
||||
:nosignatures:
|
||||
|
||||
StreamContext
|
||||
can_device_access_peer
|
||||
current_device
|
||||
current_stream
|
||||
device
|
||||
@ -26,7 +25,6 @@
|
||||
get_stream_from_external
|
||||
init
|
||||
is_available
|
||||
is_bf16_supported
|
||||
is_initialized
|
||||
set_device
|
||||
set_stream
|
||||
|
||||
@ -1187,7 +1187,8 @@ int64_t _Tensor_ndim(mpy::handle h) {
|
||||
mpy::handle handle_from_tensor(Arena& A, TensorRef t) {
|
||||
// fast case: tensor is live in python
|
||||
std::optional<PyObject*> mb_obj =
|
||||
t->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj();
|
||||
t->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
|
||||
/*ignore_hermetic_tls=*/false);
|
||||
if (mb_obj.has_value() &&
|
||||
!t->unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()) {
|
||||
return *mb_obj;
|
||||
|
||||
@ -879,15 +879,12 @@ void test_cuda_alloc_test() {
|
||||
if (cudaStatus != cudaSuccess || device_idx == -1) {
|
||||
throw std::runtime_error("cudaGetDevice failed!");
|
||||
}
|
||||
|
||||
c10::cuda::CUDACachingAllocator::emptyCache();
|
||||
c10::cuda::CUDACachingAllocator::DeviceStats stats =
|
||||
c10::cuda::CUDACachingAllocator::getDeviceStats(device_idx);
|
||||
size_t initTorchActive = stats.allocated_bytes[0].current;
|
||||
size_t initTorchActive = stats.active_bytes[0].current;
|
||||
auto runner = std::make_unique<torch::inductor::AOTIModelContainerRunnerCuda>(
|
||||
model_so_path);
|
||||
stats = c10::cuda::CUDACachingAllocator::getDeviceStats(device_idx);
|
||||
size_t torchActive = stats.allocated_bytes[0].current;
|
||||
size_t torchActive = stats.active_bytes[0].current;
|
||||
|
||||
ASSERT_EQ(initTorchActive + DATASIZE, torchActive);
|
||||
|
||||
@ -1116,7 +1113,8 @@ TEST(AotInductorTest, MultiStreamTestCuda) {
|
||||
test_multi_cuda_streams("cuda");
|
||||
}
|
||||
|
||||
TEST(AotInductorTest, CudaAllocTestCuda) {
|
||||
// TODO: ENABLE CUDACachingAllocator Test
|
||||
TEST(DISABLED_AotInductorTest, CudaAllocTestCuda) {
|
||||
test_cuda_alloc_test();
|
||||
}
|
||||
#endif
|
||||
|
||||
@ -1490,8 +1490,8 @@ class TestFullyShardWorldSize1(FSDPTest):
|
||||
@skip_if_lt_x_gpu(1)
|
||||
def test_train_parity_single_worldsize1(self):
|
||||
"""
|
||||
Tests train parity with DDP for a single FSDP group
|
||||
when sharding parameters on dim-0.
|
||||
Tests train parity with DDP for a single FSDP group when sharding
|
||||
parameters on dim-0.
|
||||
"""
|
||||
self.run_subtests(
|
||||
{
|
||||
@ -1539,7 +1539,9 @@ class TestFullyShardWorldSize1(FSDPTest):
|
||||
losses.append(model(*inp).sum())
|
||||
losses[-1].backward()
|
||||
|
||||
self.assertEqual(comm_mode.get_total_counts(), 0)
|
||||
# Before there was 1 all-gather and 1 reduce-scatter
|
||||
# Now therre is 1 reduce-scatter
|
||||
self.assertEqual(comm_mode.get_total_counts(), 1)
|
||||
optim.step()
|
||||
|
||||
self.assertEqual(losses[0], losses[1])
|
||||
|
||||
@ -294,11 +294,11 @@ class TestFullyShard2DTraining(FSDPTest):
|
||||
with CommDebugMode() as bwd_comm_mode:
|
||||
loss.backward()
|
||||
bwd_comm_counts = bwd_comm_mode.get_comm_counts()
|
||||
self.assertEqual(len(bwd_comm_counts), 1)
|
||||
self.assertEqual(len(bwd_comm_counts), 2)
|
||||
# First MLP's input gradient does not need to be all-reduced
|
||||
self.assertEqual(bwd_comm_counts[funcol.all_reduce], num_mlps - 1)
|
||||
self.assertEqual(bwd_comm_counts[c10d_ops._allgather_base_], 0)
|
||||
self.assertEqual(bwd_comm_counts[c10d_ops._reduce_scatter_base_], 0)
|
||||
self.assertEqual(bwd_comm_counts[c10d_ops._reduce_scatter_base_], num_mlps)
|
||||
ref_loss.backward()
|
||||
|
||||
optim.step()
|
||||
|
||||
@ -1,173 +0,0 @@
|
||||
# Owner(s): ["oncall: distributed"]
|
||||
|
||||
import copy
|
||||
from collections.abc import Iterable
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.distributed._composable.replicate_with_fsdp import replicate
|
||||
from torch.distributed.fsdp import FSDPModule
|
||||
from torch.distributed.tensor import DTensor
|
||||
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
|
||||
from torch.testing._internal.common_fsdp import FSDPTestMultiThread, get_devtype, MLP
|
||||
from torch.testing._internal.common_utils import run_tests
|
||||
|
||||
|
||||
c10d_ops = torch.ops.c10d
|
||||
funcol = torch.ops.c10d_functional
|
||||
|
||||
|
||||
device_type = torch.device(get_devtype())
|
||||
|
||||
|
||||
class TestReplicateForwardInputs(FSDPTestMultiThread):
|
||||
@property
|
||||
def world_size(self) -> int:
|
||||
return 2
|
||||
|
||||
@skip_if_lt_x_gpu(1)
|
||||
def test_root_move_forward_input_to_device(self):
|
||||
device = torch.device(device_type.type, 0)
|
||||
|
||||
class ParamlessModule(nn.Module):
|
||||
def forward(self, x: torch.Tensor, ys: tuple[torch.Tensor, ...]):
|
||||
# Check that Replicate moved the inputs to GPU, including recursing
|
||||
# into the tuple data structure
|
||||
assert x.device == device, f"Expects {device} but got {x.device}"
|
||||
assert ys[0].device == device, (
|
||||
f"Expects {device} but got {ys[0].device}"
|
||||
)
|
||||
assert ys[1].device == device, (
|
||||
f"Expects {device} but got {ys[1].device}"
|
||||
)
|
||||
y = ys[0] + ys[1]
|
||||
return x + y + 1
|
||||
|
||||
model = ParamlessModule().to(device)
|
||||
replicate(model).to(device)
|
||||
x = torch.randn((3,))
|
||||
ys = (torch.randn((3,)), torch.randn((3,)))
|
||||
self.assertEqual(x.device, torch.device("cpu"))
|
||||
self.assertEqual(ys[0].device, torch.device("cpu"))
|
||||
self.assertEqual(ys[1].device, torch.device("cpu"))
|
||||
model(x, ys)
|
||||
|
||||
|
||||
class TestReplicateRegisteredParams(FSDPTestMultiThread):
|
||||
@property
|
||||
def world_size(self) -> int:
|
||||
return 4
|
||||
|
||||
@skip_if_lt_x_gpu(1)
|
||||
def test_param_registration_after_forward(self):
|
||||
"""Tests the parameter registration after forward."""
|
||||
device = torch.device(device_type.type, 0)
|
||||
# Single Replicate group
|
||||
for reshard_after_forward in (True, False, None):
|
||||
torch.manual_seed(42)
|
||||
model = MLP(3, device)
|
||||
# Since seed is per process, not per thread, we broadcast to ensure
|
||||
# the same parameters across ranks
|
||||
for param in model.parameters():
|
||||
dist.broadcast(param, src=0)
|
||||
ref_model = copy.deepcopy(model)
|
||||
replicate(model, reshard_after_forward=reshard_after_forward) # root only
|
||||
inp = torch.randn((2, 3), device=device_type.type)
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
self._assert_same_params(model.parameters(), ref_model.parameters())
|
||||
model(inp)
|
||||
if reshard_after_forward:
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
else:
|
||||
self._assert_tensor_params(model.parameters())
|
||||
self._assert_same_params(model.parameters(), ref_model.parameters())
|
||||
model.reshard() # however, we can manually reshard
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
self._assert_same_params(model.parameters(), ref_model.parameters())
|
||||
|
||||
# Multiple Replicate groups
|
||||
for reshard_after_forward in (True, False, None):
|
||||
torch.manual_seed(42)
|
||||
model = nn.Sequential(MLP(3, device), MLP(3, device))
|
||||
for param in model.parameters():
|
||||
dist.broadcast(param, src=0)
|
||||
ref_model = copy.deepcopy(model)
|
||||
replicate(model[0].in_proj, reshard_after_forward=reshard_after_forward)
|
||||
replicate(model[0].out_proj, reshard_after_forward=reshard_after_forward)
|
||||
replicate(model, reshard_after_forward=reshard_after_forward)
|
||||
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
self._assert_same_params(model.parameters(), ref_model.parameters())
|
||||
model(inp)
|
||||
non_root_params = list(model[0].in_proj.parameters()) + list(
|
||||
model[0].out_proj.parameters()
|
||||
)
|
||||
root_params = list(set(model.parameters()) - set(non_root_params))
|
||||
if reshard_after_forward is None:
|
||||
self._assert_dtensor_params(non_root_params)
|
||||
self._assert_tensor_params(root_params)
|
||||
elif reshard_after_forward:
|
||||
self._assert_dtensor_params(non_root_params)
|
||||
self._assert_dtensor_params(root_params)
|
||||
else:
|
||||
self._assert_tensor_params(non_root_params)
|
||||
self._assert_tensor_params(root_params)
|
||||
self._assert_same_params(model.parameters(), ref_model.parameters())
|
||||
for module in model.modules():
|
||||
if isinstance(module, FSDPModule):
|
||||
module.reshard() # however, we can manually reshard
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
self._assert_same_params(model.parameters(), ref_model.parameters())
|
||||
|
||||
@skip_if_lt_x_gpu(1)
|
||||
def test_param_registration_after_backward(self):
|
||||
"""Tests the parameter registration after backward."""
|
||||
device = torch.device(device_type.type, 0)
|
||||
# Single Replicate group
|
||||
for reshard_after_forward in (True, False):
|
||||
model = MLP(8, device)
|
||||
replicate(model, reshard_after_forward=reshard_after_forward) # root only
|
||||
inp = torch.randn((2, 8), device=device_type.type)
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
model(inp).sum().backward()
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
|
||||
# Multiple Replicate groups
|
||||
for reshard_after_forward in (True, False):
|
||||
model = MLP(8, device)
|
||||
replicate(model.in_proj, reshard_after_forward=reshard_after_forward)
|
||||
replicate(model.out_proj, reshard_after_forward=reshard_after_forward)
|
||||
replicate(model, reshard_after_forward=reshard_after_forward)
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
model(inp).sum().backward()
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
|
||||
def _assert_tensor_params(self, params: Iterable[nn.Parameter]):
|
||||
# need to iterate over the list multiple times
|
||||
params = list(params)
|
||||
self.assertGreater(len(params), 0)
|
||||
for param in params:
|
||||
self.assertNotIsInstance(param, DTensor)
|
||||
self.assertIsInstance(param, torch.Tensor)
|
||||
|
||||
def _assert_dtensor_params(self, params: Iterable[nn.Parameter]):
|
||||
params = list(params)
|
||||
self.assertGreater(len(params), 0)
|
||||
for param in params:
|
||||
self.assertIsInstance(param, DTensor)
|
||||
|
||||
def _assert_same_params(
|
||||
self, params: Iterable[nn.Parameter], ref_params: Iterable[nn.Parameter]
|
||||
):
|
||||
params, ref_params = list(params), list(ref_params)
|
||||
self.assertEqual(len(params), len(ref_params))
|
||||
for param, ref_param in zip(params, ref_params):
|
||||
if isinstance(param, DTensor):
|
||||
param = param.full_tensor()
|
||||
self.assertEqual(param.shape, ref_param.shape)
|
||||
self.assertEqual(param, ref_param)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
@ -47,14 +47,11 @@ _LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TestCoalesce(TestCase):
|
||||
def helper_test_coalesce(self, layout, coalesced_layout=None):
|
||||
def helper_test_coalesce(self, layout):
|
||||
layoutR = coalesce(layout)
|
||||
|
||||
_LOGGER.debug(f"{layout} => {layoutR}")
|
||||
|
||||
if coalesced_layout:
|
||||
self.assertEqual(coalesced_layout.shape, layoutR.shape)
|
||||
self.assertEqual(coalesced_layout.stride, layoutR.stride)
|
||||
self.assertEqual(size(layoutR), size(layout))
|
||||
|
||||
for i in range(size(layout)):
|
||||
@ -85,17 +82,11 @@ class TestCoalesce(TestCase):
|
||||
layout = Layout((2, (4, 6)))
|
||||
self.helper_test_coalesce(layout)
|
||||
|
||||
layout = Layout((1, 2), (8, 1))
|
||||
coalesced_layout = Layout(2, 1)
|
||||
self.helper_test_coalesce(layout, coalesced_layout)
|
||||
|
||||
layout = Layout((2, 4), (4, 1))
|
||||
coalesced_layout = Layout(8, 1)
|
||||
self.helper_test_coalesce(layout, coalesced_layout)
|
||||
self.helper_test_coalesce(layout)
|
||||
|
||||
layout = Layout((2, 4, 6), (24, 6, 1))
|
||||
coalesced_layout = Layout(48, 1)
|
||||
self.helper_test_coalesce(layout, coalesced_layout)
|
||||
self.helper_test_coalesce(layout)
|
||||
|
||||
layout = Layout((2, 1, 3), (2, 4, 4))
|
||||
self.helper_test_coalesce(layout)
|
||||
@ -103,10 +94,6 @@ class TestCoalesce(TestCase):
|
||||
layout = Layout(((2, 2), (2, 2)), ((1, 4), (8, 32)))
|
||||
self.helper_test_coalesce(layout)
|
||||
|
||||
layout = Layout(((2, 2), (2, 2)), ((32, 8), (4, 1)))
|
||||
coalesced_layout = Layout((2, 4, 2), (32, 4, 1))
|
||||
self.helper_test_coalesce(layout, coalesced_layout)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
||||
@ -208,26 +208,11 @@ class TestComposition(TestCase):
|
||||
layoutB = Layout((6), (1))
|
||||
self.helper_test_composition(layoutA, layoutB)
|
||||
|
||||
# Pre-coalesced RHS
|
||||
layoutA = Layout((8, 6, 4), (7, 4, 1))
|
||||
layoutB = Layout((6), (1))
|
||||
self.helper_test_composition(layoutA, layoutB)
|
||||
|
||||
# Case when not meet stride divisibility condition
|
||||
with self.assertRaises(AssertionError):
|
||||
layoutA = Layout((4, 6, 8, 10), (2, 3, 5, 7))
|
||||
layoutB = Layout(6, 12)
|
||||
self.helper_test_composition(layoutA, layoutB)
|
||||
|
||||
# Mid-layout truncation
|
||||
layoutA = Layout((10, 8, 6, 4), (7, 5, 3, 2))
|
||||
layoutA = Layout((4, 6, 8, 10), (2, 3, 5, 7))
|
||||
layoutB = Layout(6, 12)
|
||||
self.helper_test_composition(layoutA, layoutB)
|
||||
|
||||
layoutA = Layout((4,), (3,))
|
||||
layoutB = Layout((6,), (2,))
|
||||
self.helper_test_composition(layoutA, layoutB)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
||||
@ -67,159 +67,20 @@ class TestIntTuple(TestCase):
|
||||
|
||||
self.assertEqual(shape_div((6, (3, 4)), 36), (1, (1, 2)))
|
||||
|
||||
def test_suffix_product(self):
|
||||
self.assertEqual(suffix_product(2), 1)
|
||||
def test_prefix_product(self):
|
||||
self.assertEqual(prefix_product(2), 1)
|
||||
|
||||
self.assertEqual(suffix_product((3, 2)), (2, 1))
|
||||
self.assertEqual(prefix_product((3, 2)), (1, 3))
|
||||
|
||||
self.assertEqual(suffix_product((3, 2, 4)), (8, 4, 1))
|
||||
self.assertEqual(prefix_product((3, 2, 4)), (1, 3, 6))
|
||||
|
||||
self.assertEqual(suffix_product(((2, 3), 4)), ((12, 4), 1))
|
||||
self.assertEqual(prefix_product(((2, 3), 4)), ((1, 2), 6))
|
||||
|
||||
self.assertEqual(
|
||||
suffix_product(((2, 3), (2, 1, 2), (5, 2, 1))),
|
||||
((120, 40), (20, 20, 10), (2, 1, 1)),
|
||||
prefix_product(((2, 3), (2, 1, 2), (5, 2, 1))),
|
||||
((1, 2), (6, 12, 12), (24, 120, 240)),
|
||||
)
|
||||
|
||||
def test_crd2idx_basic(self):
|
||||
# Test basic int/int case
|
||||
self.assertEqual(crd2idx(2, 5, 1), 2)
|
||||
self.assertEqual(crd2idx(0, 5, 1), 0)
|
||||
self.assertEqual(crd2idx(4, 5, 1), 4)
|
||||
|
||||
# Test with custom stride
|
||||
self.assertEqual(crd2idx(2, 5, 3), 6)
|
||||
self.assertEqual(crd2idx(1, 5, 3), 3)
|
||||
|
||||
def test_crd2idx_tuple(self):
|
||||
# Test tuple coordinates with default stride
|
||||
self.assertEqual(crd2idx((1, 2), (3, 4)), 6) # 1*4 + 2*1 = 6
|
||||
self.assertEqual(crd2idx((0, 0), (3, 4)), 0)
|
||||
self.assertEqual(crd2idx((2, 3), (3, 4)), 11) # 2*4 + 3*1 = 11
|
||||
|
||||
# Test with custom stride
|
||||
self.assertEqual(crd2idx((1, 2), (3, 4), (8, 2)), 12) # 1*8 + 2*2 = 12
|
||||
|
||||
# Test 3D case
|
||||
self.assertEqual(crd2idx((1, 0, 2), (2, 3, 4)), 14) # 1*12 + 0*4 + 2*1 = 14
|
||||
|
||||
def test_crd2idx_none(self):
|
||||
# Test None coordinate (should default to 0)
|
||||
self.assertEqual(crd2idx(None, 5), 0)
|
||||
self.assertEqual(crd2idx(None, (3, 4)), 0)
|
||||
|
||||
def test_crd2idx_int_with_tuple_shape(self):
|
||||
# Test single integer coordinate with multi-dimensional shape and stride
|
||||
# When crd is int and shape is tuple, it converts the int to multi-dim coordinate first
|
||||
self.assertEqual(crd2idx(0, (2, 2), (2, 1)), 0) # 0 -> (0,0) -> 0*2 + 0*1 = 0
|
||||
self.assertEqual(crd2idx(1, (2, 2), (2, 1)), 1) # 1 -> (0,1) -> 0*2 + 1*1 = 1
|
||||
self.assertEqual(crd2idx(2, (2, 2), (2, 1)), 2) # 2 -> (1,0) -> 1*2 + 0*1 = 2
|
||||
self.assertEqual(crd2idx(3, (2, 2), (2, 1)), 3) # 3 -> (1,1) -> 1*2 + 1*1 = 3
|
||||
|
||||
# Test with non-trivial strides
|
||||
self.assertEqual(crd2idx(0, (2, 3), (6, 2)), 0) # 0 -> (0,0) -> 0*6 + 0*2 = 0
|
||||
self.assertEqual(crd2idx(1, (2, 3), (6, 2)), 2) # 1 -> (0,1) -> 0*6 + 1*2 = 2
|
||||
self.assertEqual(crd2idx(2, (2, 3), (6, 2)), 4) # 2 -> (0,2) -> 0*6 + 2*2 = 4
|
||||
self.assertEqual(crd2idx(3, (2, 3), (6, 2)), 6) # 3 -> (1,0) -> 1*6 + 0*2 = 6
|
||||
self.assertEqual(crd2idx(4, (2, 3), (6, 2)), 8) # 4 -> (1,1) -> 1*6 + 1*2 = 8
|
||||
self.assertEqual(crd2idx(5, (2, 3), (6, 2)), 10) # 5 -> (1,2) -> 1*6 + 2*2 = 10
|
||||
|
||||
# Test with larger strides
|
||||
self.assertEqual(crd2idx(0, (3, 2), (10, 5)), 0) # 0 -> (0,0) -> 0*10 + 0*5 = 0
|
||||
self.assertEqual(crd2idx(1, (3, 2), (10, 5)), 5) # 1 -> (0,1) -> 0*10 + 1*5 = 5
|
||||
self.assertEqual(
|
||||
crd2idx(2, (3, 2), (10, 5)), 10
|
||||
) # 2 -> (1,0) -> 1*10 + 0*5 = 10
|
||||
self.assertEqual(
|
||||
crd2idx(3, (3, 2), (10, 5)), 15
|
||||
) # 3 -> (1,1) -> 1*10 + 1*5 = 15
|
||||
self.assertEqual(
|
||||
crd2idx(4, (3, 2), (10, 5)), 20
|
||||
) # 4 -> (2,0) -> 2*10 + 0*5 = 20
|
||||
self.assertEqual(
|
||||
crd2idx(5, (3, 2), (10, 5)), 25
|
||||
) # 5 -> (2,1) -> 2*10 + 1*5 = 25
|
||||
|
||||
# Test with 3D shape and various strides
|
||||
self.assertEqual(
|
||||
crd2idx(0, (2, 2, 2), (8, 4, 2)), 0
|
||||
) # 0 -> (0,0,0) -> 0*8 + 0*4 + 0*2 = 0
|
||||
self.assertEqual(
|
||||
crd2idx(1, (2, 2, 2), (8, 4, 2)), 2
|
||||
) # 1 -> (0,0,1) -> 0*8 + 0*4 + 1*2 = 2
|
||||
self.assertEqual(
|
||||
crd2idx(2, (2, 2, 2), (8, 4, 2)), 4
|
||||
) # 2 -> (0,1,0) -> 0*8 + 1*4 + 0*2 = 4
|
||||
self.assertEqual(
|
||||
crd2idx(3, (2, 2, 2), (8, 4, 2)), 6
|
||||
) # 3 -> (0,1,1) -> 0*8 + 1*4 + 1*2 = 6
|
||||
self.assertEqual(
|
||||
crd2idx(4, (2, 2, 2), (8, 4, 2)), 8
|
||||
) # 4 -> (1,0,0) -> 1*8 + 0*4 + 0*2 = 8
|
||||
self.assertEqual(
|
||||
crd2idx(7, (2, 2, 2), (8, 4, 2)), 14
|
||||
) # 7 -> (1,1,1) -> 1*8 + 1*4 + 1*2 = 14
|
||||
|
||||
self.assertEqual(
|
||||
crd2idx(4, ((2, 2, 2), (2, 2, 2)), ((1, 16, 4), (8, 2, 32))), 8
|
||||
) # 4 -> (1,0,0) -> 1*8 = 8
|
||||
|
||||
def test_idx2crd_basic(self):
|
||||
# Test basic int/int case
|
||||
self.assertEqual(idx2crd(2, 5, 1), 2)
|
||||
self.assertEqual(idx2crd(0, 5, 1), 0)
|
||||
self.assertEqual(idx2crd(4, 5, 1), 4)
|
||||
|
||||
# Test with custom stride
|
||||
self.assertEqual(idx2crd(6, 5, 3), 2) # (6 // 3) % 5 = 2
|
||||
self.assertEqual(idx2crd(3, 5, 3), 1) # (3 // 3) % 5 = 1
|
||||
|
||||
def test_idx2crd_tuple(self):
|
||||
# Test tuple shape with default stride
|
||||
self.assertEqual(idx2crd(6, (3, 4)), (1, 2)) # 6 -> (1, 2)
|
||||
self.assertEqual(idx2crd(0, (3, 4)), (0, 0))
|
||||
self.assertEqual(idx2crd(11, (3, 4)), (2, 3))
|
||||
|
||||
# Test 3D case
|
||||
self.assertEqual(idx2crd(14, (2, 3, 4)), (1, 0, 2))
|
||||
|
||||
def test_crd2idx_idx2crd_roundtrip(self):
|
||||
# Test that crd2idx and idx2crd are inverse operations
|
||||
shapes = [
|
||||
5,
|
||||
(3, 4),
|
||||
(2, 3, 4),
|
||||
(2, 2, 2, 2),
|
||||
]
|
||||
|
||||
for shape in shapes:
|
||||
size = product(shape)
|
||||
for idx in range(size):
|
||||
crd = idx2crd(idx, shape)
|
||||
recovered_idx = crd2idx(crd, shape)
|
||||
self.assertEqual(
|
||||
recovered_idx, idx, f"Failed roundtrip for shape {shape}, idx {idx}"
|
||||
)
|
||||
|
||||
def test_idx2crd_crd2idx_roundtrip(self):
|
||||
# Test roundtrip starting from coordinates
|
||||
test_cases = [
|
||||
(0, 5),
|
||||
(4, 5),
|
||||
((0, 0), (3, 4)),
|
||||
((1, 2), (3, 4)),
|
||||
((2, 3), (3, 4)),
|
||||
((0, 0, 0), (2, 3, 4)),
|
||||
((1, 2, 3), (2, 3, 4)),
|
||||
]
|
||||
|
||||
for crd, shape in test_cases:
|
||||
idx = crd2idx(crd, shape)
|
||||
recovered_crd = idx2crd(idx, shape)
|
||||
self.assertEqual(
|
||||
recovered_crd, crd, f"Failed roundtrip for crd {crd}, shape {shape}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
# Owner(s): ["oncall: distributed"]
|
||||
# Owner(s): ["module: unknown"]
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
# Owner(s): ["module: fsdp"]
|
||||
# Owner(s): ["module: unknown"]
|
||||
import functools
|
||||
import gc
|
||||
from typing import Union
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
# Owner(s): ["oncall: distributed"]
|
||||
# Owner(s): ["module: unknown"]
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
# Owner(s): ["oncall: distributed"]
|
||||
# Owner(s): ["module: unknown"]
|
||||
|
||||
from copy import copy
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
# Owner(s): ["oncall: distributed"]
|
||||
# Owner(s): ["module: unknown"]
|
||||
import unittest
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, cast, Union
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
# Owner(s): ["oncall: distributed"]
|
||||
# Owner(s): ["module: unknown"]
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
# Owner(s): ["oncall: distributed"]
|
||||
# Owner(s): ["module: unknown"]
|
||||
import copy
|
||||
import unittest
|
||||
|
||||
|
||||
@ -1,21 +1,21 @@
|
||||
#!/usr/bin/env python3
|
||||
# Owner(s): ["oncall: r2p"]
|
||||
|
||||
import functools
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
import json
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import functools
|
||||
import os
|
||||
import signal
|
||||
import unittest
|
||||
import uuid
|
||||
from multiprocessing.pool import ThreadPool
|
||||
from typing import Any
|
||||
from unittest.mock import call, MagicMock, patch
|
||||
from unittest.mock import call, patch
|
||||
|
||||
import torch.distributed as dist
|
||||
import torch.distributed.elastic.rendezvous.registry as rdzv_registry
|
||||
@ -29,7 +29,6 @@ from torch.distributed.elastic.agent.server.api import (
|
||||
WorkerSpec,
|
||||
WorkerState,
|
||||
)
|
||||
from torch.distributed.elastic.events import EventSource
|
||||
from torch.distributed.elastic.multiprocessing import SignalException
|
||||
from torch.distributed.elastic.multiprocessing.errors import ProcessFailure
|
||||
from torch.distributed.elastic.rendezvous import RendezvousHandler, RendezvousParameters
|
||||
@ -158,243 +157,6 @@ def monres(state: WorkerState):
|
||||
return RunResult(state=state)
|
||||
|
||||
|
||||
class RecordWorkerEventsTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.spec = MagicMock()
|
||||
self.spec.role = "test_role"
|
||||
self.spec.get_entrypoint_name.return_value = "test_entrypoint"
|
||||
self.spec.rdzv_handler.get_run_id.return_value = "test_run_id"
|
||||
self.spec.rdzv_handler.get_backend.return_value = "test_backend"
|
||||
self.spec.max_restarts = 3
|
||||
|
||||
self.agent = TestAgent(self.spec)
|
||||
|
||||
# Create a mock worker spec and agent
|
||||
self.agent._worker_group = MagicMock()
|
||||
self.agent._worker_group.spec = MagicMock()
|
||||
self.agent._worker_group.spec.event_log_handler = "test_handler"
|
||||
|
||||
# Setup worker group
|
||||
self.worker_group = WorkerGroup(self.spec)
|
||||
self.worker_group.group_world_size = 2
|
||||
self.worker_group.group_rank = 1
|
||||
self.agent._worker_group = self.worker_group
|
||||
|
||||
# Create a test worker
|
||||
|
||||
self.workers = [
|
||||
Worker(
|
||||
local_rank=0,
|
||||
global_rank=0,
|
||||
role_rank=0,
|
||||
world_size=2,
|
||||
role_world_size=2,
|
||||
),
|
||||
Worker(
|
||||
local_rank=1,
|
||||
global_rank=1,
|
||||
role_rank=1,
|
||||
world_size=2,
|
||||
role_world_size=2,
|
||||
),
|
||||
]
|
||||
self.workers[0].id = 0
|
||||
self.workers[1].id = 1
|
||||
self.agent._worker_group.workers = self.workers
|
||||
|
||||
@patch("torch.distributed.elastic.agent.server.api.record")
|
||||
def test_record_worker_events_success(self, mock_record):
|
||||
# Create a RunResult with successful workers
|
||||
result = RunResult(
|
||||
state=WorkerState.SUCCEEDED,
|
||||
return_values={0: "result0", 1: "result1"},
|
||||
failures={},
|
||||
)
|
||||
|
||||
# Call the method under test
|
||||
self.agent._record_worker_events(result)
|
||||
|
||||
# Verify record was called twice (once for each worker)
|
||||
self.assertEqual(mock_record.call_count, 2)
|
||||
|
||||
# Check that both calls were for SUCCEEDED events
|
||||
for call_args in mock_record.call_args_list:
|
||||
event = call_args[0][0]
|
||||
|
||||
self.assertEqual(event.source, EventSource.WORKER)
|
||||
self.assertEqual(event.metadata["state"], "SUCCEEDED")
|
||||
self.assertIsNone(event.metadata["raw_error"])
|
||||
md = json.loads(event.metadata["metadata"])
|
||||
self.assertEqual(md["exit_code"], [None])
|
||||
self.assertEqual(md["worker_pid"], [None])
|
||||
|
||||
@patch("torch.distributed.elastic.agent.server.api.record")
|
||||
def test_record_worker_events_failure(self, mock_record):
|
||||
# Create failures with error data
|
||||
failure0 = ProcessFailure(
|
||||
local_rank=0, pid=1000, exitcode=1, error_file="error0.json"
|
||||
)
|
||||
|
||||
# Create a RunResult with one failed worker and one terminated worker
|
||||
result = RunResult(
|
||||
state=WorkerState.FAILED,
|
||||
return_values={},
|
||||
failures={0: failure0}, # Only worker 0 has a specific failure
|
||||
)
|
||||
|
||||
# Call the method under test
|
||||
self.agent._record_worker_events(result)
|
||||
|
||||
# Verify record was called twice (once for each worker)
|
||||
self.assertEqual(mock_record.call_count, 2)
|
||||
|
||||
# Get the calls
|
||||
calls = mock_record.call_args_list
|
||||
|
||||
# Check first call for the failed worker (global_rank=0)
|
||||
failed_event = calls[0][0][0]
|
||||
self.assertEqual(failed_event.source, EventSource.WORKER)
|
||||
self.assertEqual(failed_event.metadata["state"], "FAILED")
|
||||
self.assertEqual(failed_event.metadata["global_rank"], 0)
|
||||
md = json.loads(failed_event.metadata["metadata"])
|
||||
self.assertEqual(failed_event.metadata["raw_error"], '{"message": "<NONE>"}')
|
||||
self.assertEqual(md["exit_code"], [1])
|
||||
self.assertEqual(md["worker_pid"], [1000])
|
||||
|
||||
# Check second call for the terminated worker (global_rank=1)
|
||||
terminated_event = calls[1][0][0]
|
||||
self.assertEqual(terminated_event.source, EventSource.WORKER)
|
||||
self.assertEqual(terminated_event.metadata["state"], "TERMINATED")
|
||||
self.assertEqual(terminated_event.metadata["global_rank"], 1)
|
||||
self.assertIsNone(terminated_event.metadata["raw_error"])
|
||||
md = json.loads(terminated_event.metadata["metadata"])
|
||||
self.assertEqual(md["exit_code"], [None])
|
||||
self.assertEqual(md["worker_pid"], [None])
|
||||
|
||||
|
||||
class ConstructEventTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# Create minimal spec and agent for testing
|
||||
self.spec = MagicMock()
|
||||
self.spec.role = "test_role"
|
||||
self.spec.get_entrypoint_name.return_value = "test_entrypoint"
|
||||
self.spec.rdzv_handler.get_run_id.return_value = "test_run_id"
|
||||
self.spec.rdzv_handler.get_backend.return_value = "test_backend"
|
||||
self.spec.max_restarts = 3
|
||||
|
||||
self.agent = TestAgent(self.spec)
|
||||
self.agent._remaining_restarts = 2
|
||||
self.agent._total_execution_time = 42
|
||||
|
||||
# Setup worker group
|
||||
self.worker_group = WorkerGroup(self.spec)
|
||||
self.worker_group.group_world_size = 2
|
||||
self.worker_group.group_rank = 1
|
||||
self.agent._worker_group = self.worker_group
|
||||
|
||||
# Create a test worker
|
||||
self.worker = Worker(
|
||||
local_rank=0, global_rank=5, role_rank=3, world_size=8, role_world_size=4
|
||||
)
|
||||
self.worker.id = 12345
|
||||
|
||||
def test_construct_event_agent_success(self):
|
||||
# Test constructing an agent success event
|
||||
event = self.agent._construct_event(state="SUCCEEDED", source=EventSource.AGENT)
|
||||
|
||||
# Verify basic event properties
|
||||
self.assertEqual(event.name, "torchelastic.worker.status.SUCCEEDED")
|
||||
self.assertEqual(event.source, EventSource.AGENT)
|
||||
|
||||
# Verify metadata
|
||||
metadata = event.metadata
|
||||
self.assertEqual(metadata["run_id"], "test_run_id")
|
||||
self.assertIsNone(metadata["global_rank"])
|
||||
self.assertEqual(metadata["group_rank"], 1)
|
||||
self.assertIsNone(metadata["worker_id"])
|
||||
self.assertEqual(metadata["role"], "test_role")
|
||||
self.assertEqual(metadata["state"], "SUCCEEDED")
|
||||
self.assertEqual(metadata["total_run_time"], 42)
|
||||
self.assertEqual(metadata["rdzv_backend"], "test_backend")
|
||||
self.assertIsNone(metadata["raw_error"])
|
||||
self.assertEqual(
|
||||
metadata["agent_restarts"], 1
|
||||
) # max_restarts - remaining_restarts
|
||||
self.assertIsNone(metadata["duration_ms"])
|
||||
|
||||
# Verify JSON metadata
|
||||
md_dict = json.loads(metadata["metadata"])
|
||||
self.assertEqual(md_dict["group_world_size"], 2)
|
||||
self.assertEqual(md_dict["entry_point"], "test_entrypoint")
|
||||
|
||||
def test_construct_event_worker_failure(self):
|
||||
# Test constructing a worker failure event with raw error
|
||||
raw_error = json.dumps(
|
||||
{"error_message": "Test error", "traceback": "stack trace"}
|
||||
)
|
||||
event = self.agent._construct_event(
|
||||
state="FAILED",
|
||||
source=EventSource.WORKER,
|
||||
worker=self.worker,
|
||||
raw_error=raw_error,
|
||||
exit_code=1,
|
||||
)
|
||||
|
||||
# Verify basic event properties
|
||||
self.assertEqual(event.name, "torchelastic.worker.status.FAILED")
|
||||
self.assertEqual(event.source, EventSource.WORKER)
|
||||
|
||||
# Verify metadata
|
||||
metadata = event.metadata
|
||||
self.assertEqual(metadata["run_id"], "test_run_id")
|
||||
self.assertEqual(metadata["global_rank"], 5)
|
||||
self.assertEqual(metadata["group_rank"], 1)
|
||||
self.assertEqual(metadata["worker_id"], "12345")
|
||||
self.assertEqual(metadata["role"], "test_role")
|
||||
self.assertEqual(metadata["state"], "FAILED")
|
||||
self.assertEqual(metadata["total_run_time"], 42)
|
||||
self.assertEqual(metadata["rdzv_backend"], "test_backend")
|
||||
self.assertEqual(metadata["raw_error"], raw_error)
|
||||
self.assertEqual(metadata["agent_restarts"], 1)
|
||||
|
||||
# Verify worker-specific metadata
|
||||
md_dict = json.loads(metadata["metadata"])
|
||||
self.assertEqual(md_dict["local_rank"], [0])
|
||||
self.assertEqual(md_dict["role_rank"], [3])
|
||||
self.assertEqual(md_dict["role_world_size"], [4])
|
||||
self.assertEqual(md_dict["exit_code"], [1])
|
||||
|
||||
def test_construct_event_with_duration(self):
|
||||
# Test constructing an event with duration_ms
|
||||
event = self.agent._construct_event(
|
||||
state="RENDEZVOUS", source=EventSource.AGENT, duration_ms=123.45
|
||||
)
|
||||
|
||||
# Verify duration is set correctly
|
||||
self.assertEqual(event.metadata["duration_ms"], 123.45)
|
||||
|
||||
def test_construct_event_worker_no_error(self):
|
||||
# Test constructing a worker event without error info
|
||||
event = self.agent._construct_event(
|
||||
state="HEALTHY", source=EventSource.WORKER, worker=self.worker
|
||||
)
|
||||
|
||||
# Verify error fields are None
|
||||
metadata = event.metadata
|
||||
self.assertIsNone(metadata["raw_error"])
|
||||
|
||||
# Check worker info is set
|
||||
self.assertEqual(metadata["global_rank"], 5)
|
||||
self.assertEqual(metadata["worker_id"], "12345")
|
||||
|
||||
# Check metadata JSON
|
||||
md_dict = json.loads(metadata["metadata"])
|
||||
self.assertEqual(md_dict["local_rank"], [0])
|
||||
self.assertEqual(md_dict["role_rank"], [3])
|
||||
self.assertEqual(md_dict["role_world_size"], [4])
|
||||
self.assertNotIn("exit_code", [None])
|
||||
|
||||
|
||||
class SimpleElasticAgentTest(unittest.TestCase):
|
||||
def _get_worker_spec(
|
||||
self,
|
||||
|
||||
@ -568,8 +568,9 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
|
||||
)
|
||||
|
||||
results = pc.wait(period=0.1)
|
||||
|
||||
self.assertTrue(results.is_failed())
|
||||
self.assertEqual(2, len(results.failures))
|
||||
self.assertEqual(1, len(results.failures))
|
||||
|
||||
failure = results.failures[0]
|
||||
self.assertEqual(138, failure.exitcode)
|
||||
@ -582,13 +583,6 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
|
||||
self.assertTrue(pc._stderr_tail.stopped())
|
||||
self.assertTrue(pc._stdout_tail.stopped())
|
||||
|
||||
failure = results.failures[1]
|
||||
self.assertEqual(-15, failure.exitcode)
|
||||
self.assertEqual("SIGTERM", failure.signal_name())
|
||||
self.assertEqual("<NONE>", failure.error_file_data["message"])
|
||||
# Assert that the failure message contains expected substrings
|
||||
self.assertIn("Signal 15 (SIGTERM) received by PID", failure.message)
|
||||
|
||||
def test_binary_raises(self):
|
||||
pc = start_processes(
|
||||
name="echo",
|
||||
|
||||
@ -9,7 +9,6 @@
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@ -24,6 +23,5 @@ if __name__ == "__main__":
|
||||
print(f"exit {exitcode} from {rank}", file=sys.stderr)
|
||||
sys.exit(exitcode)
|
||||
else:
|
||||
time.sleep(1000)
|
||||
print(f"{args.msg} stdout from {rank}")
|
||||
print(f"{args.msg} stderr from {rank}", file=sys.stderr)
|
||||
|
||||
@ -15,7 +15,6 @@ import uuid
|
||||
|
||||
import torch.distributed.elastic.timer as timer
|
||||
from torch.testing._internal.common_utils import (
|
||||
IS_ARM64,
|
||||
IS_MACOS,
|
||||
IS_WINDOWS,
|
||||
run_tests,
|
||||
@ -24,8 +23,8 @@ from torch.testing._internal.common_utils import (
|
||||
)
|
||||
|
||||
|
||||
# timer is not supported on these platforms
|
||||
if not (IS_WINDOWS or IS_MACOS or IS_ARM64):
|
||||
# timer is not supported on windows or macos
|
||||
if not (IS_WINDOWS or IS_MACOS):
|
||||
# func2 should time out
|
||||
def func2(n, file_path):
|
||||
if file_path is not None:
|
||||
|
||||
@ -14,7 +14,6 @@ import time
|
||||
import torch.distributed.elastic.timer as timer
|
||||
import torch.multiprocessing as torch_mp
|
||||
from torch.testing._internal.common_utils import (
|
||||
IS_ARM64,
|
||||
IS_MACOS,
|
||||
IS_WINDOWS,
|
||||
run_tests,
|
||||
@ -41,8 +40,8 @@ def _stuck_function(rank, mp_queue):
|
||||
time.sleep(5)
|
||||
|
||||
|
||||
# timer is not supported on these platforms
|
||||
if not (IS_WINDOWS or IS_MACOS or IS_ARM64):
|
||||
# timer is not supported on macos or windows
|
||||
if not (IS_WINDOWS or IS_MACOS):
|
||||
|
||||
class LocalTimerExample(TestCase):
|
||||
"""
|
||||
|
||||
@ -15,7 +15,6 @@ import torch.distributed.elastic.timer as timer
|
||||
from torch.distributed.elastic.timer.api import TimerRequest
|
||||
from torch.distributed.elastic.timer.local_timer import MultiprocessingRequestQueue
|
||||
from torch.testing._internal.common_utils import (
|
||||
IS_ARM64,
|
||||
IS_MACOS,
|
||||
IS_WINDOWS,
|
||||
run_tests,
|
||||
@ -25,10 +24,8 @@ from torch.testing._internal.common_utils import (
|
||||
)
|
||||
|
||||
|
||||
# timer is not supported on these platforms
|
||||
INVALID_PLATFORMS = IS_WINDOWS or IS_MACOS or TEST_WITH_DEV_DBG_ASAN or IS_ARM64
|
||||
|
||||
if not INVALID_PLATFORMS:
|
||||
# timer is not supported on windows or macos
|
||||
if not (IS_WINDOWS or IS_MACOS or TEST_WITH_DEV_DBG_ASAN):
|
||||
# func2 should time out
|
||||
def func2(n, mp_queue):
|
||||
if mp_queue is not None:
|
||||
@ -132,7 +129,8 @@ if not INVALID_PLATFORMS:
|
||||
time.sleep(interval)
|
||||
|
||||
|
||||
if not INVALID_PLATFORMS:
|
||||
# timer is not supported on windows or macos
|
||||
if not (IS_WINDOWS or IS_MACOS or TEST_WITH_DEV_DBG_ASAN):
|
||||
|
||||
class MultiprocessingRequestQueueTest(TestCase):
|
||||
def test_get(self):
|
||||
@ -199,7 +197,8 @@ if not INVALID_PLATFORMS:
|
||||
self.assertLessEqual(n / 2, len(requests))
|
||||
|
||||
|
||||
if not INVALID_PLATFORMS:
|
||||
# timer is not supported on windows or macos
|
||||
if not (IS_WINDOWS or IS_MACOS or TEST_WITH_DEV_DBG_ASAN):
|
||||
|
||||
class LocalTimerServerTest(TestCase):
|
||||
def setUp(self):
|
||||
|
||||
@ -41,7 +41,6 @@ from torch.testing._internal.common_utils import (
|
||||
parametrize,
|
||||
run_tests,
|
||||
TEST_WITH_DEV_DBG_ASAN,
|
||||
TEST_XPU,
|
||||
TestCase,
|
||||
)
|
||||
from torch.testing._internal.inductor_utils import HAS_GPU
|
||||
@ -58,8 +57,6 @@ if TEST_WITH_DEV_DBG_ASAN:
|
||||
)
|
||||
sys.exit(0)
|
||||
|
||||
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
|
||||
|
||||
|
||||
class TestFSDPUseOrigParamsMultipleParamGroups(FSDPTest):
|
||||
"""Tests multiple parameter groups."""
|
||||
@ -161,7 +158,7 @@ class TestFSDPUseOrigParamsMultipleParamGroups(FSDPTest):
|
||||
device_init_mode == DEVICEInitMode.DEVICE_AFTER
|
||||
and not fsdp_model.cpu_offload.offload_params
|
||||
):
|
||||
fsdp_model = fsdp_model.to(device=device_type)
|
||||
fsdp_model = fsdp_model.cuda()
|
||||
return fsdp_model, fsdp_optim
|
||||
|
||||
def _check_train_parity(
|
||||
@ -174,7 +171,7 @@ class TestFSDPUseOrigParamsMultipleParamGroups(FSDPTest):
|
||||
num_iters: int = 10,
|
||||
):
|
||||
"""Checks training parity between DDP and FSDP."""
|
||||
device = torch.device(device_type)
|
||||
device = torch.device("cuda")
|
||||
for i in range(num_iters):
|
||||
iter_losses = []
|
||||
for model, optim in ((ddp_model, ddp_optim), (fsdp_model, fsdp_optim)):
|
||||
@ -265,7 +262,7 @@ class TestFSDPUseOrigParamsMultipleParamGroups(FSDPTest):
|
||||
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
|
||||
for _ in range(10):
|
||||
losses = []
|
||||
inp = ref_model.get_input(torch.device(device_type))
|
||||
inp = ref_model.get_input(torch.device("cuda"))
|
||||
for _model, _optim in ((ref_model, ref_optim), (model, optim)):
|
||||
_optim.zero_grad()
|
||||
loss = _model(*inp).sum()
|
||||
@ -473,7 +470,7 @@ class TestFSDPUseOrigParamsMultipleParamGroups(FSDPTest):
|
||||
):
|
||||
ddp_optims.append(optim_ctor(ddp_param_group["params"]))
|
||||
fsdp_optims.append(optim_ctor(fsdp_param_group["params"]))
|
||||
device = torch.device(device_type)
|
||||
device = torch.device("cuda")
|
||||
|
||||
# Check that there exists a `FlatParameter` that has both a weight and
|
||||
# a bias in this rank's shard
|
||||
@ -646,7 +643,7 @@ class TestFSDPUseOrigParamsUnshardReshard(FSDPTest):
|
||||
fsdp_model_orig_params,
|
||||
optim_orig_params,
|
||||
) = self._get_fsdp_models_and_optims(sharding_strategy, cpu_offload)
|
||||
device = torch.device(device_type)
|
||||
device = torch.device("cuda")
|
||||
for _ in range(3):
|
||||
inp1 = fsdp_model.get_input(device)
|
||||
_inp2 = fsdp_model.get_input(device)
|
||||
@ -704,7 +701,7 @@ class TestFSDPUseOrigParamsUnshardReshard(FSDPTest):
|
||||
fsdp_model_orig_params,
|
||||
optim_orig_params,
|
||||
) = self._get_fsdp_models_and_optims(sharding_strategy, cpu_offload)
|
||||
device = torch.device(device_type)
|
||||
device = torch.device("cuda")
|
||||
for _ in range(3):
|
||||
optim.zero_grad()
|
||||
optim_orig_params.zero_grad()
|
||||
@ -831,9 +828,9 @@ class TestFSDPUseOrigParamsParamAccess(FSDPTest):
|
||||
p1 = p1.flatten()
|
||||
torch.testing.assert_close(p1, p2)
|
||||
|
||||
ddp_model = DDP(Model().to(device=device_type), device_ids=[self.rank])
|
||||
ddp_model = DDP(Model().cuda(), device_ids=[self.rank])
|
||||
fsdp_model = FSDP(
|
||||
Model().to(device=device_type),
|
||||
Model().cuda(),
|
||||
sharding_strategy=sharding_strategy,
|
||||
auto_wrap_policy=always_wrap_policy,
|
||||
use_orig_params=True,
|
||||
@ -841,7 +838,7 @@ class TestFSDPUseOrigParamsParamAccess(FSDPTest):
|
||||
LR = 1e-2
|
||||
ddp_optim = torch.optim.Adam(ddp_model.parameters(), lr=LR)
|
||||
fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=LR)
|
||||
device = torch.device(device_type)
|
||||
device = torch.device("cuda")
|
||||
|
||||
inp = fsdp_model.get_input(device)
|
||||
ddp_out = ddp_model(*inp)
|
||||
@ -916,11 +913,11 @@ class TestFSDPUseOrigParamsWriteback(FSDPTest):
|
||||
|
||||
# Check that the writeback propagates
|
||||
ddp_model = DDP(
|
||||
TestFSDPUseOrigParamsWriteback.Model(torch.device(device_type)),
|
||||
TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda")),
|
||||
device_ids=[self.rank],
|
||||
)
|
||||
fsdp_model = FSDP(
|
||||
TestFSDPUseOrigParamsWriteback.Model(torch.device(device_type)),
|
||||
TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda")),
|
||||
use_orig_params=True,
|
||||
)
|
||||
ddp = ddp_model.module # for brevity
|
||||
@ -969,11 +966,11 @@ class TestFSDPUseOrigParamsWriteback(FSDPTest):
|
||||
return None if set_to_none else torch.ones_like(param) * 2
|
||||
|
||||
ddp_model = DDP(
|
||||
TestFSDPUseOrigParamsWriteback.Model(torch.device(device_type)),
|
||||
TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda")),
|
||||
device_ids=[self.rank],
|
||||
)
|
||||
fsdp_model = FSDP(
|
||||
TestFSDPUseOrigParamsWriteback.Model(torch.device(device_type)),
|
||||
TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda")),
|
||||
use_orig_params=True,
|
||||
)
|
||||
LR = 1e-2
|
||||
@ -984,7 +981,7 @@ class TestFSDPUseOrigParamsWriteback(FSDPTest):
|
||||
fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=LR)
|
||||
|
||||
# Generate an initial gradient
|
||||
inp = fsdp_model.get_input(torch.device(device_type))
|
||||
inp = fsdp_model.get_input(torch.device("cuda"))
|
||||
ddp_out = ddp_model(*inp)
|
||||
fsdp_out = fsdp_model(*inp)
|
||||
ddp_out.sum().backward()
|
||||
@ -1014,7 +1011,7 @@ class TestFSDPUseOrigParamsWriteback(FSDPTest):
|
||||
self._check_param_parity(ddp_model, fsdp_model) # triggers a writeback
|
||||
|
||||
# Intentionally do not zero the gradient to check writeback
|
||||
inp = fsdp_model.get_input(torch.device(device_type))
|
||||
inp = fsdp_model.get_input(torch.device("cuda"))
|
||||
ddp_out = ddp_model(*inp)
|
||||
fsdp_out = fsdp_model(*inp)
|
||||
ddp_out.sum().backward()
|
||||
@ -1026,7 +1023,7 @@ class TestFSDPUseOrigParamsWriteback(FSDPTest):
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_writeback_shape_mismatch(self):
|
||||
fsdp_model = FSDP(
|
||||
TestFSDPUseOrigParamsWriteback.Model(torch.device(device_type)),
|
||||
TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda")),
|
||||
use_orig_params=True,
|
||||
)
|
||||
# Check that writing back with mismatched shape errors
|
||||
@ -1076,9 +1073,9 @@ class TestFSDPUseOrigParamsWriteback(FSDPTest):
|
||||
# Test changing the parameter storage to no longer be a view into the
|
||||
# flat parameter
|
||||
fsdp_model = fsdp_wrapper(
|
||||
TestFSDPUseOrigParamsWriteback.Model(torch.device(device_type))
|
||||
TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda"))
|
||||
)
|
||||
inp = fsdp_model.get_input(torch.device(device_type))
|
||||
inp = fsdp_model.get_input(torch.device("cuda"))
|
||||
loss = fsdp_model(*inp).sum()
|
||||
fsdp_model.lin1.weight.data = fsdp_model.lin1.weight.clone()
|
||||
assert_msg = (
|
||||
@ -1089,9 +1086,9 @@ class TestFSDPUseOrigParamsWriteback(FSDPTest):
|
||||
|
||||
# Test changing the parameter variable itself
|
||||
fsdp_model = fsdp_wrapper(
|
||||
TestFSDPUseOrigParamsWriteback.Model(torch.device(device_type))
|
||||
TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda"))
|
||||
)
|
||||
inp = fsdp_model.get_input(torch.device(device_type))
|
||||
inp = fsdp_model.get_input(torch.device("cuda"))
|
||||
loss = fsdp_model(*inp).sum()
|
||||
fsdp_model.lin1._fsdp_wrapped_module.weight = nn.Parameter(
|
||||
fsdp_model.lin1.weight.clone()
|
||||
@ -1125,10 +1122,9 @@ class TestFSDPUseOrigParamsWriteback(FSDPTest):
|
||||
|
||||
# Train forward -> full-precision unshard -> train forward
|
||||
fsdp_model = FSDP(
|
||||
TestFSDPUseOrigParamsWriteback.Model(torch.device(device_type)),
|
||||
**fsdp_kwargs,
|
||||
TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda")), **fsdp_kwargs
|
||||
)
|
||||
inp = fsdp_model.get_input(torch.device(device_type))
|
||||
inp = fsdp_model.get_input(torch.device("cuda"))
|
||||
fsdp_model(*inp)
|
||||
with FSDP.summon_full_params(fsdp_model):
|
||||
...
|
||||
@ -1187,13 +1183,13 @@ class TestFSDPUseOrigParamsFQNs(FSDPTest):
|
||||
assert_equal_fn(params[1].shape, param_shapes[1])
|
||||
return self.lin(x)
|
||||
|
||||
model = Model().to(device=device_type)
|
||||
model = Model().cuda()
|
||||
# Save the *unsharded* original parameter shapes and check the shapes
|
||||
# match in the forward pass
|
||||
param_shapes[0] = model.lin.weight.shape
|
||||
param_shapes[1] = model.lin.bias.shape
|
||||
fsdp_model = FSDP(model, use_orig_params=True)
|
||||
inp = torch.randn((2, 5), device=torch.device(device_type))
|
||||
inp = torch.randn((2, 5), device=torch.device("cuda"))
|
||||
fsdp_model(inp)
|
||||
|
||||
|
||||
@ -1220,7 +1216,7 @@ class TestFSDPUseOrigParamsNoSync(FSDPTest):
|
||||
)
|
||||
|
||||
def _test_no_sync_correctness(self, sharding_strategy: ShardingStrategy):
|
||||
model = nn.Linear(7, 1, bias=False, device=device_type)
|
||||
model = nn.Linear(7, 1, bias=False, device="cuda")
|
||||
fsdp_kwargs = {
|
||||
"sharding_strategy": sharding_strategy,
|
||||
}
|
||||
@ -1270,8 +1266,8 @@ class TestFSDPUseOrigParamsNoSync(FSDPTest):
|
||||
orig_param.grad,
|
||||
)
|
||||
|
||||
inp = torch.randn((2, 7), device=device_type)
|
||||
grad = torch.randn((2, 1), device=device_type)
|
||||
inp = torch.randn((2, 7), device="cuda")
|
||||
grad = torch.randn((2, 1), device="cuda")
|
||||
|
||||
# Compute some reference gradients using one forward/backward
|
||||
out_use_flat_params = model_use_flat_params(inp)
|
||||
@ -1337,7 +1333,7 @@ class TestFSDPUseOrigParamsNoSync(FSDPTest):
|
||||
)
|
||||
|
||||
def _test_no_sync_mixed_precision(self, sharding_strategy: ShardingStrategy):
|
||||
model = nn.Linear(3, 3, device=device_type)
|
||||
model = nn.Linear(3, 3, device="cuda")
|
||||
mixed_precision = MixedPrecision(
|
||||
param_dtype=torch.float16,
|
||||
reduce_dtype=torch.float32,
|
||||
@ -1348,7 +1344,7 @@ class TestFSDPUseOrigParamsNoSync(FSDPTest):
|
||||
"use_orig_params": True,
|
||||
}
|
||||
fsdp_model = FSDP(model, **fsdp_kwargs)
|
||||
inp = torch.randn((2, 3), device=device_type)
|
||||
inp = torch.randn((2, 3), device="cuda")
|
||||
with fsdp_model.no_sync():
|
||||
# For each of these `no_sync()` backward passes, check that the
|
||||
# gradients are in the low precision parameter dtype (FP16)
|
||||
@ -1372,8 +1368,8 @@ class TestFSDPUseOrigParamsInit(FSDPTest):
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_non_uniform_requires_grad(self):
|
||||
model = nn.Sequential(
|
||||
nn.Linear(3, 3, device=device_type),
|
||||
nn.Linear(3, 3, device=device_type),
|
||||
nn.Linear(3, 3, device="cuda"),
|
||||
nn.Linear(3, 3, device="cuda"),
|
||||
)
|
||||
# Freeze biases only and flatten both weights and biases into the same
|
||||
# `FlatParameter` to exercise non-uniform `requires_grad`
|
||||
@ -1396,10 +1392,10 @@ class TestMultiTensorApply(TestCase):
|
||||
# Check that this does not segfault
|
||||
torch._foreach_mul_(size0_tensors, 0.1)
|
||||
|
||||
@unittest.skipIf(not TEST_CUDA and not TEST_XPU, "no cuda and no xpu")
|
||||
@unittest.skipIf(not TEST_CUDA, "no cuda")
|
||||
def test_multi_tensor_apply_size0_tensors_cuda(self):
|
||||
size0_tensors = [
|
||||
torch.empty(0, device=device_type) for _ in range(NUM_SIZE0_TENSORS)
|
||||
torch.empty(0, device="cuda") for _ in range(NUM_SIZE0_TENSORS)
|
||||
]
|
||||
# Check that this does not segfault
|
||||
torch._foreach_mul_(size0_tensors, 0.1)
|
||||
|
||||
@ -15,9 +15,6 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
|
||||
)
|
||||
|
||||
|
||||
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
|
||||
|
||||
|
||||
class TestShardUtilsDistributed(FSDPTest):
|
||||
@property
|
||||
def world_size(self):
|
||||
@ -26,7 +23,7 @@ class TestShardUtilsDistributed(FSDPTest):
|
||||
def _create_tensor(self, *size):
|
||||
# Keep everything deterministic.
|
||||
torch.manual_seed(0)
|
||||
return torch.rand(*size).to(device=device_type)
|
||||
return torch.rand(*size).cuda()
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_create_chunk_sharded_tensor(self):
|
||||
@ -37,12 +34,10 @@ class TestShardUtilsDistributed(FSDPTest):
|
||||
tensor,
|
||||
self.rank,
|
||||
self.world_size,
|
||||
torch.accelerator.device_count(),
|
||||
torch.cuda.device_count(),
|
||||
_get_default_group(),
|
||||
)
|
||||
output = (
|
||||
torch.empty(*size).to(device=device_type) if self.rank == 0 else None
|
||||
)
|
||||
output = torch.empty(*size).cuda() if self.rank == 0 else None
|
||||
sharded_tensor.gather(0, output)
|
||||
if self.rank == 0:
|
||||
self.assertEqual(tensor, output)
|
||||
@ -56,7 +51,7 @@ class TestShardUtilsDistributedDTensor(DTensorTestBase):
|
||||
def _create_tensor(self, *size):
|
||||
# Keep everything deterministic.
|
||||
torch.manual_seed(0)
|
||||
return torch.rand(*size).to(device=device_type)
|
||||
return torch.rand(*size).cuda()
|
||||
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(2)
|
||||
|
||||
@ -50,15 +50,10 @@ from torch.testing._internal.common_utils import (
|
||||
parametrize,
|
||||
run_tests,
|
||||
TEST_CUDA,
|
||||
TEST_XPU,
|
||||
TestCase,
|
||||
)
|
||||
|
||||
|
||||
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
|
||||
backend = torch.distributed.get_default_backend_for_device(device_type)
|
||||
|
||||
|
||||
class BatchNormNet(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
@ -137,14 +132,14 @@ class TestFSDPWrap(FSDPTest):
|
||||
|
||||
class NestedSequentialModel:
|
||||
@staticmethod
|
||||
def get_model(device=True):
|
||||
def get_model(cuda=True):
|
||||
sequential = nn.Sequential(
|
||||
nn.Linear(5, 5),
|
||||
nn.Linear(5, 5),
|
||||
nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5)),
|
||||
)
|
||||
if device:
|
||||
sequential = sequential.to(device=device_type)
|
||||
if cuda:
|
||||
sequential = sequential.cuda()
|
||||
return sequential
|
||||
|
||||
@staticmethod
|
||||
@ -219,7 +214,7 @@ class TestFSDPWrap(FSDPTest):
|
||||
nested=nested, device_init_mode=device_init_mode
|
||||
)
|
||||
if device_init_mode == DEVICEInitMode.DEVICE_AFTER:
|
||||
wrapped_fsdp = wrapped_fsdp.to(device=device_type)
|
||||
wrapped_fsdp = wrapped_fsdp.cuda()
|
||||
|
||||
wrapped_module_name = "lin1.1" if nested else "lin1"
|
||||
with self.assertRaisesRegex(
|
||||
@ -374,7 +369,7 @@ class TestFSDPWrap(FSDPTest):
|
||||
forward_prefetch=forward_prefetch,
|
||||
)
|
||||
if device_init_mode == DEVICEInitMode.DEVICE_AFTER:
|
||||
wrapped_model = wrapped_model.to(device=device_type)
|
||||
wrapped_model = wrapped_model.cuda()
|
||||
|
||||
modules_in_fsdp_graph_order = [
|
||||
wrapped_model.module.lin1,
|
||||
@ -393,7 +388,7 @@ class TestFSDPWrap(FSDPTest):
|
||||
|
||||
# Run model a few times for sanity check.
|
||||
optim = torch.optim.SGD(wrapped_model.parameters(), lr=1e-2, momentum=0.9)
|
||||
inp = torch.ones(1).to(device=device_type)
|
||||
inp = torch.ones(1).cuda()
|
||||
for _ in range(6):
|
||||
optim.zero_grad()
|
||||
loss = wrapped_model(inp).sum()
|
||||
@ -466,13 +461,13 @@ class TestAutoWrap(TestCase):
|
||||
self.assertEqual(layer.rank, 0)
|
||||
self.assertEqual(layer.world_size, 2)
|
||||
|
||||
@unittest.skipIf(not TEST_CUDA and not TEST_XPU, "Test Requires CUDA or XPU")
|
||||
@unittest.skipIf(not TEST_CUDA, "Test Requires CUDA")
|
||||
def test_always_wrap(self):
|
||||
"""
|
||||
Test to ensure that if `always_wrap_policy` is
|
||||
passed into FSDP, all submodules are wrapped.
|
||||
"""
|
||||
seq = TestFSDPWrap.NestedSequentialModel.get_model(device=True)
|
||||
seq = TestFSDPWrap.NestedSequentialModel.get_model(cuda=True)
|
||||
model = FSDP(
|
||||
seq, process_group=self.process_group, auto_wrap_policy=always_wrap_policy
|
||||
)
|
||||
@ -634,7 +629,7 @@ class TestAutoWrap(TestCase):
|
||||
Test to ensure with auto wrap, we wrap child modules correctly based on the min_num_params.
|
||||
``nn.Linear(5, 5)`` does not exceed the bucket size, but combined they do.
|
||||
"""
|
||||
sequential = TestFSDPWrap.NestedSequentialModel.get_model(device=False)
|
||||
sequential = TestFSDPWrap.NestedSequentialModel.get_model(cuda=False)
|
||||
my_auto_wrap_policy = functools.partial(
|
||||
size_based_auto_wrap_policy, min_num_params=40
|
||||
)
|
||||
@ -731,7 +726,7 @@ class TestAutoWrap(TestCase):
|
||||
self.assertTrue(isinstance(model.module[0], nn.Linear))
|
||||
self.assertTrue(isinstance(model.module[1], nn.ModuleList))
|
||||
|
||||
@unittest.skipIf(not TEST_CUDA and not TEST_XPU, "Test Requires CUDA or XPU")
|
||||
@unittest.skipIf(not TEST_CUDA, "Test Requires CUDA")
|
||||
@parametrize(
|
||||
"device_init_mode", [DEVICEInitMode.DEVICE_BEFORE, DEVICEInitMode.DEVICE_AFTER]
|
||||
)
|
||||
@ -748,12 +743,10 @@ class TestAutoWrap(TestCase):
|
||||
):
|
||||
return
|
||||
|
||||
device = torch.device(device_type)
|
||||
torch.accelerator.set_device_index(0)
|
||||
device = torch.device("cuda")
|
||||
torch.cuda.set_device(0)
|
||||
device_id = (
|
||||
torch.device(device_type, torch.accelerator.current_device_index())
|
||||
if use_device_id
|
||||
else None
|
||||
torch.device("cuda", torch.cuda.current_device()) if use_device_id else None
|
||||
)
|
||||
|
||||
# Random port in case the next test run quickly, same port would cause conflict.
|
||||
@ -762,18 +755,18 @@ class TestAutoWrap(TestCase):
|
||||
|
||||
file_name = tempfile.NamedTemporaryFile(delete=False).name
|
||||
torch.distributed.init_process_group(
|
||||
backend=backend,
|
||||
backend="nccl",
|
||||
init_method=f"{FILE_SCHEMA}_{file_name}",
|
||||
rank=0,
|
||||
world_size=1,
|
||||
)
|
||||
|
||||
# NOTE: We move model to GPU after init with FSDP to simulate real use
|
||||
# NOTE: We move model to CUDA after init with FSDP to simulate real use
|
||||
# cases where full model cannot be loaded onto GPU, but their shards can.
|
||||
device_after_init = device_init_mode == DEVICEInitMode.DEVICE_AFTER
|
||||
cuda_after_init = device_init_mode == DEVICEInitMode.DEVICE_AFTER
|
||||
try:
|
||||
sequential = TestFSDPWrap.NestedSequentialModel.get_model(
|
||||
device=(not device_after_init)
|
||||
cuda=(not cuda_after_init)
|
||||
)
|
||||
my_auto_wrap_policy = functools.partial(
|
||||
size_based_auto_wrap_policy, min_num_params=40
|
||||
@ -785,8 +778,8 @@ class TestAutoWrap(TestCase):
|
||||
device_id=device_id,
|
||||
)
|
||||
TestFSDPWrap.NestedSequentialModel.verify_model(self, model)
|
||||
if device_after_init:
|
||||
model = model.to(device=device_type)
|
||||
if cuda_after_init:
|
||||
model = model.cuda()
|
||||
input = torch.rand((1, 5), dtype=torch.float).to(device)
|
||||
output = model(input)
|
||||
loss = F.mse_loss(input, output)
|
||||
@ -802,7 +795,7 @@ class TestAutoWrap(TestCase):
|
||||
@unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
|
||||
@parametrize("wrap_method", [WrapMethod.FSDP_CTOR, WrapMethod.WRAP_API])
|
||||
def test_always_wrap_with_ignored_modules(self, wrap_method: WrapMethod):
|
||||
sequential = TestFSDPWrap.NestedSequentialModel.get_model(device=False)
|
||||
sequential = TestFSDPWrap.NestedSequentialModel.get_model(cuda=False)
|
||||
ignored_modules = [sequential[1], sequential[2][0]]
|
||||
fsdp_kwargs = {
|
||||
"process_group": self.process_group,
|
||||
@ -827,7 +820,7 @@ class TestAutoWrap(TestCase):
|
||||
@unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
|
||||
@parametrize("wrap_method", [WrapMethod.FSDP_CTOR, WrapMethod.WRAP_API])
|
||||
def test_auto_wrap_with_ignored_modules(self, wrap_method: WrapMethod):
|
||||
sequential = TestFSDPWrap.NestedSequentialModel.get_model(device=False)
|
||||
sequential = TestFSDPWrap.NestedSequentialModel.get_model(cuda=False)
|
||||
ignored_modules = [sequential[1], sequential[2][0]]
|
||||
my_auto_wrap_policy = functools.partial(
|
||||
size_based_auto_wrap_policy,
|
||||
@ -890,7 +883,7 @@ class TestAutoWrap(TestCase):
|
||||
self._test_frozen_params(use_orig_params, policy)
|
||||
|
||||
def _test_frozen_params(self, use_orig_params: bool, policy: _Policy):
|
||||
model = LoraModel().to(device=device_type)
|
||||
model = LoraModel().cuda()
|
||||
msg = "layers.0.attn has both parameters with requires_grad=True and False. "
|
||||
if use_orig_params:
|
||||
msg += "We do not recommend wrapping such modules"
|
||||
|
||||
@ -1,250 +0,0 @@
|
||||
# Owner(s): ["oncall: distributed"]
|
||||
|
||||
import contextlib
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||
from torch.distributed.tensor import DeviceMesh, DTensor, Partial, Replicate, Shard
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
requires_cuda,
|
||||
run_tests,
|
||||
TestCase,
|
||||
)
|
||||
from torch.testing._internal.distributed.fake_pg import FakeStore
|
||||
from torch.utils._python_dispatch import TorchDispatchMode
|
||||
from torch.utils.debug_mode import DebugMode
|
||||
|
||||
|
||||
@requires_cuda
|
||||
class TestDTensorDebugMode(TestCase):
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
dist.destroy_process_group()
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.world_size = 8
|
||||
store = FakeStore()
|
||||
dist.init_process_group(
|
||||
backend="fake", rank=0, world_size=self.world_size, store=store
|
||||
)
|
||||
self.device_type = "cuda"
|
||||
|
||||
def test_debug_mode_mm(self):
|
||||
mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
|
||||
|
||||
x = torch.randn(1, 8, requires_grad=False)
|
||||
y = torch.randn(1, 32, requires_grad=True)
|
||||
x_dtensor = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
|
||||
y_dtensor = DTensor.from_local(y, mesh, [Shard(0)], run_check=False)
|
||||
|
||||
with DebugMode() as debug_mode:
|
||||
torch.mm(x_dtensor, y_dtensor).sum()
|
||||
|
||||
self.assertExpectedInline(
|
||||
debug_mode.debug_string(),
|
||||
"""\
|
||||
torch.mm(dt: f32[8, 8][S(0)], dt: f32[8, 32][S(0)])
|
||||
aten::mm(dt: f32[8, 8][S(0)], dt: f32[8, 32][S(0)])
|
||||
redistribute_input(1, [S(0)] -> [R])
|
||||
_c10d_functional::all_gather_into_tensor(t: f32[1, 32], 8, 0)
|
||||
_c10d_functional::wait_tensor(t: f32[8, 32])
|
||||
aten::mm(t: f32[1, 8], t: f32[8, 32])
|
||||
<method 'sum' of 'torch._C.TensorBase' objects>(dt: f32[8, 32][S(0)])
|
||||
aten::sum(dt: f32[8, 32][S(0)])
|
||||
aten::sum(t: f32[1, 32])""",
|
||||
)
|
||||
|
||||
def test_debug_string_inside_context(self):
|
||||
mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
|
||||
|
||||
x = torch.randn(1, 8, requires_grad=False)
|
||||
y = torch.randn(1, 32, requires_grad=True)
|
||||
x_dtensor = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
|
||||
y_dtensor = DTensor.from_local(y, mesh, [Shard(0)], run_check=False)
|
||||
|
||||
with DebugMode() as debug_mode:
|
||||
torch.mm(x_dtensor, y_dtensor).sum()
|
||||
s0 = debug_mode.debug_string()
|
||||
s1 = debug_mode.debug_string()
|
||||
self.assertEqual(s0, s1)
|
||||
|
||||
def test_debug_mode_backward(self):
|
||||
mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
|
||||
|
||||
x = torch.randn(1, 8, requires_grad=True)
|
||||
y = torch.randn(8, 1, requires_grad=True)
|
||||
x_dtensor = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
|
||||
y_dtensor = DTensor.from_local(y, mesh, [Shard(1)], run_check=False)
|
||||
|
||||
with DebugMode() as debug_mode:
|
||||
z = x_dtensor + y_dtensor
|
||||
z.sum().backward()
|
||||
|
||||
self.assertExpectedInline(
|
||||
debug_mode.debug_string(),
|
||||
"""\
|
||||
<method 'add' of 'torch._C.TensorBase' objects>(dt: f32[8, 8][S(0)], dt: f32[8, 8][S(1)])
|
||||
aten::add.Tensor(dt: f32[8, 8][S(0)], dt: f32[8, 8][S(1)])
|
||||
redistribute_input(1, [S(1)] -> [S(0)])
|
||||
_dtensor::shard_dim_alltoall(t: f32[8, 1], 1, 0, 0)
|
||||
aten::add.Tensor(t: f32[1, 8], t: f32[1, 8])
|
||||
<method 'sum' of 'torch._C.TensorBase' objects>(dt: f32[8, 8][S(0)])
|
||||
aten::sum(dt: f32[8, 8][S(0)])
|
||||
aten::sum(t: f32[1, 8])
|
||||
torch._tensor.backward(dt: f32[][P], gradient=None, retain_graph=None, create_graph=False, inputs=None)
|
||||
aten::ones_like(dt: f32[][P], pin_memory=False, memory_format=torch.preserve_format)
|
||||
aten::ones_like(t: f32[], pin_memory=False, memory_format=torch.preserve_format)
|
||||
aten::expand(dt: f32[][R], [8, 8])
|
||||
aten::expand(t: f32[], [8, 8])
|
||||
aten::split.Tensor(t: f32[8, 8], 1, 1)
|
||||
aten::clone(t: f32[8, 1])
|
||||
aten::_to_copy(t: f32[8, 1], dtype=torch.float32, layout=torch.strided, device=cpu)
|
||||
aten::detach(t: f32[8, 1])
|
||||
aten::split.Tensor(t: f32[8, 8], 1)
|
||||
aten::clone(t: f32[1, 8])
|
||||
aten::_to_copy(t: f32[1, 8], dtype=torch.float32, layout=torch.strided, device=cpu)
|
||||
aten::detach(t: f32[1, 8])""",
|
||||
)
|
||||
|
||||
def test_debug_mode_einsum(self):
|
||||
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size).view(4, 2))
|
||||
|
||||
# Create test tensors
|
||||
a = torch.randn(16, 6, 8)
|
||||
b = torch.randn(8, 4, 4)
|
||||
|
||||
a_dt = DTensor.from_local(a, mesh, [Partial(), Replicate()], run_check=False)
|
||||
b_dt = DTensor.from_local(b, mesh, [Replicate(), Partial()], run_check=False)
|
||||
|
||||
# Capture the operator decomposition
|
||||
with DebugMode() as debug_mode:
|
||||
torch.einsum("bld,dnh->blnh", a_dt, b_dt)
|
||||
|
||||
self.assertExpectedInline(
|
||||
debug_mode.debug_string(),
|
||||
"""\
|
||||
torch.functional.einsum(bld,dnh->blnh, dt: f32[16, 6, 8][P, R], dt: f32[8, 4, 4][R, P])
|
||||
aten::unsqueeze(dt: f32[16, 6, 8][P, R], 3)
|
||||
aten::unsqueeze(t: f32[16, 6, 8], 3)
|
||||
aten::unsqueeze(dt: f32[16, 6, 8, 1][P, R], 4)
|
||||
aten::unsqueeze(t: f32[16, 6, 8, 1], 4)
|
||||
aten::permute(dt: f32[16, 6, 8, 1, 1][P, R], [0, 1, 3, 4, 2])
|
||||
aten::permute(t: f32[16, 6, 8, 1, 1], [0, 1, 3, 4, 2])
|
||||
aten::unsqueeze(dt: f32[8, 4, 4][R, P], 3)
|
||||
aten::unsqueeze(t: f32[8, 4, 4], 3)
|
||||
aten::unsqueeze(dt: f32[8, 4, 4, 1][R, P], 4)
|
||||
aten::unsqueeze(t: f32[8, 4, 4, 1], 4)
|
||||
aten::permute(dt: f32[8, 4, 4, 1, 1][R, P], [3, 4, 1, 2, 0])
|
||||
aten::permute(t: f32[8, 4, 4, 1, 1], [3, 4, 1, 2, 0])
|
||||
aten::permute(dt: f32[16, 6, 1, 1, 8][P, R], [0, 1, 4, 2, 3])
|
||||
aten::permute(t: f32[16, 6, 1, 1, 8], [0, 1, 4, 2, 3])
|
||||
aten::view(dt: f32[16, 6, 8, 1, 1][P, R], [1, 96, 8])
|
||||
aten::view(t: f32[16, 6, 8, 1, 1], [1, 96, 8])
|
||||
aten::permute(dt: f32[1, 1, 4, 4, 8][R, P], [4, 2, 3, 0, 1])
|
||||
aten::permute(t: f32[1, 1, 4, 4, 8], [4, 2, 3, 0, 1])
|
||||
aten::view(dt: f32[8, 4, 4, 1, 1][R, P], [1, 8, 16])
|
||||
aten::view(t: f32[8, 4, 4, 1, 1], [1, 8, 16])
|
||||
aten::bmm(dt: f32[1, 96, 8][P, R], dt: f32[1, 8, 16][R, P])
|
||||
redistribute_input(0, [P, R] -> [S(2), S(2)])
|
||||
aten::chunk(t: f32[1, 96, 8], 4, 2)
|
||||
aten::cat(['t: f32[1, 96, 2]', 't: f32[1, 96, 2]', 't: f32[1, 96, 2]', 't: f32[1, 96, 2]'])
|
||||
_c10d_functional::reduce_scatter_tensor(t: f32[4, 96, 2], sum, 4, 1)
|
||||
_c10d_functional::wait_tensor(t: f32[1, 96, 2])
|
||||
aten::chunk(t: f32[1, 96, 2], 2, 2)
|
||||
aten::clone(t: f32[1, 96, 1])
|
||||
redistribute_input(1, [R, P] -> [S(1), S(1)])
|
||||
aten::chunk(t: f32[1, 8, 16], 4, 1)
|
||||
aten::clone(t: f32[1, 2, 16])
|
||||
aten::chunk(t: f32[1, 2, 16], 2, 1)
|
||||
aten::cat(['t: f32[1, 1, 16]', 't: f32[1, 1, 16]'])
|
||||
_c10d_functional::reduce_scatter_tensor(t: f32[2, 1, 16], sum, 2, 3)
|
||||
_c10d_functional::wait_tensor(t: f32[1, 1, 16])
|
||||
aten::bmm(t: f32[1, 96, 1], t: f32[1, 1, 16])
|
||||
aten::view(dt: f32[1, 96, 16][P, P], [16, 6, 1, 4, 4])
|
||||
aten::view(t: f32[1, 96, 16], [16, 6, 1, 4, 4])
|
||||
aten::permute(dt: f32[16, 6, 1, 4, 4][P, P], [0, 1, 3, 4, 2])
|
||||
aten::permute(t: f32[16, 6, 1, 4, 4], [0, 1, 3, 4, 2])
|
||||
aten::view(dt: f32[16, 6, 4, 4, 1][P, P], [16, 6, 4, 4])
|
||||
aten::view(t: f32[16, 6, 4, 4, 1], [16, 6, 4, 4])""",
|
||||
)
|
||||
|
||||
def test_real_tensor(self):
|
||||
x = torch.randn(8, 8, 8)
|
||||
linear = torch.nn.Linear(8, 8)
|
||||
|
||||
with DebugMode() as debug_mode:
|
||||
linear(x).sum()
|
||||
|
||||
self.assertExpectedInline(
|
||||
debug_mode.debug_string(),
|
||||
"""\
|
||||
torch._C._nn.linear(t: f32[8, 8, 8], t: f32[8, 8], t: f32[8])
|
||||
aten::view(t: f32[8, 8, 8], [64, 8])
|
||||
aten::t(t: f32[8, 8])
|
||||
aten::addmm(t: f32[8], t: f32[64, 8], t: f32[8, 8])
|
||||
aten::view(t: f32[64, 8], [8, 8, 8])
|
||||
<method 'sum' of 'torch._C.TensorBase' objects>(t: f32[8, 8, 8])
|
||||
aten::sum(t: f32[8, 8, 8])""",
|
||||
)
|
||||
|
||||
def test_fake_tensor(self):
|
||||
with FakeTensorMode():
|
||||
x = torch.randn(8, 8)
|
||||
y = torch.randn(8, 8, 8)
|
||||
|
||||
with DebugMode(record_faketensor=True) as debug_mode:
|
||||
torch.matmul(y, x)
|
||||
|
||||
self.assertExpectedInline(
|
||||
debug_mode.debug_string(),
|
||||
"""\
|
||||
torch.matmul(ft: f32[8, 8, 8], ft: f32[8, 8])
|
||||
aten::view(ft: f32[8, 8, 8], [64, 8])
|
||||
aten::mm(ft: f32[64, 8], ft: f32[8, 8])
|
||||
aten::_unsafe_view(ft: f32[64, 8], [8, 8, 8])""",
|
||||
)
|
||||
|
||||
@parametrize("has_inner_mode", [True, False])
|
||||
@parametrize("has_outer_mode", [True, False])
|
||||
def test_nested_debug_mode(self, has_inner_mode, has_outer_mode):
|
||||
class DummyTorchDispatchMode1(TorchDispatchMode):
|
||||
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
class DummyTorchDispatchMode2(TorchDispatchMode):
|
||||
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
|
||||
|
||||
x = torch.randn(1, 8, requires_grad=True)
|
||||
y = torch.randn(1, 32, requires_grad=True)
|
||||
x_dtensor = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
|
||||
y_dtensor = DTensor.from_local(y, mesh, [Shard(0)], run_check=False)
|
||||
|
||||
inner_mode = (
|
||||
DummyTorchDispatchMode1() if has_inner_mode else contextlib.nullcontext()
|
||||
)
|
||||
outer_mode = (
|
||||
DummyTorchDispatchMode2() if has_outer_mode else contextlib.nullcontext()
|
||||
)
|
||||
|
||||
with outer_mode:
|
||||
with DebugMode() as debug_mode:
|
||||
with inner_mode:
|
||||
torch.mm(x_dtensor, y_dtensor)
|
||||
|
||||
self.assertTrue(
|
||||
"redistribute_input(1, [S(0)] -> [R])" in debug_mode.debug_string()
|
||||
)
|
||||
|
||||
|
||||
instantiate_parametrized_tests(TestDTensorDebugMode)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
@ -28,7 +28,6 @@ from torch.distributed.tensor.parallel import parallelize_module
|
||||
from torch.nn.attention import sdpa_kernel, SDPBackend
|
||||
from torch.nn.attention.flex_attention import (
|
||||
_mask_mod_signature,
|
||||
AuxRequest,
|
||||
create_block_mask,
|
||||
flex_attention,
|
||||
)
|
||||
@ -575,8 +574,8 @@ class RingFlexAttentionTest(DTensorTestBase):
|
||||
device=self.device_type,
|
||||
)
|
||||
|
||||
expect_out, expect_aux = compiled_flex_attention(
|
||||
q, k, v, block_mask=block_mask, return_aux=AuxRequest(lse=True)
|
||||
expect_out, expect_lse = compiled_flex_attention(
|
||||
q, k, v, block_mask=block_mask, return_lse=True
|
||||
)
|
||||
expect_out.sum().backward()
|
||||
|
||||
@ -636,12 +635,12 @@ class RingFlexAttentionTest(DTensorTestBase):
|
||||
cp_k.requires_grad = True
|
||||
cp_v.requires_grad = True
|
||||
|
||||
cp_out, cp_aux = compiled_flex_attention(
|
||||
cp_out, cp_lse = compiled_flex_attention(
|
||||
cp_q,
|
||||
cp_k,
|
||||
cp_v,
|
||||
block_mask=cp_block_mask,
|
||||
return_aux=AuxRequest(lse=True),
|
||||
return_lse=True,
|
||||
)
|
||||
|
||||
# check block_mask rewrite doesn't escape to the outside
|
||||
@ -658,11 +657,9 @@ class RingFlexAttentionTest(DTensorTestBase):
|
||||
cp_v.requires_grad = False
|
||||
|
||||
# unshard the output
|
||||
cp_out, cp_lse = context_parallel_unshard(
|
||||
device_mesh, [cp_out, cp_aux.lse], [2, 2]
|
||||
)
|
||||
cp_out, cp_lse = context_parallel_unshard(device_mesh, [cp_out, cp_lse], [2, 2])
|
||||
torch.testing.assert_close(cp_out, expect_out, atol=atol, rtol=rtol)
|
||||
torch.testing.assert_close(cp_lse, expect_aux.lse, atol=atol, rtol=rtol)
|
||||
torch.testing.assert_close(cp_lse, expect_lse, atol=atol, rtol=rtol)
|
||||
|
||||
# unshard the gradient
|
||||
cp_q_grad, cp_k_grad, cp_v_grad = context_parallel_unshard(
|
||||
|
||||
@ -211,8 +211,8 @@ def forward(self, b_parametrizations_buffer_original0, x):
|
||||
_assert_tensor_metadata = torch.ops.aten._assert_tensor_metadata.default(x, None, None, torch.float64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata = None
|
||||
_to_copy = torch.ops.aten._to_copy.default(x, dtype = torch.float64, layout = torch.strided, device = device(type='cuda', index=0)); x = None
|
||||
view = torch.ops.aten.view.default(_to_copy, [4, 4]); _to_copy = None
|
||||
add = torch.ops.aten.add.Tensor(b_parametrizations_buffer_original0, view); b_parametrizations_buffer_original0 = view = None
|
||||
view_1 = torch.ops.aten.view.default(add, [4, 4]); add = None
|
||||
add_1 = torch.ops.aten.add.Tensor(b_parametrizations_buffer_original0, view); b_parametrizations_buffer_original0 = view = None
|
||||
view_1 = torch.ops.aten.view.default(add_1, [4, 4]); add_1 = None
|
||||
return (view_1,)""", # noqa: B950
|
||||
)
|
||||
|
||||
@ -317,9 +317,6 @@ def forward(self, b_parametrizations_buffer_original0, x):
|
||||
self.assertEqual(res, ref)
|
||||
|
||||
@skipIfHpu
|
||||
@unittest.skip(
|
||||
"DTensor + dynamic fails - s77 + 8 is not tracked with proxy .. proxy_tensor.PythonKeyTracer"
|
||||
)
|
||||
def test_dtensor_dynamic_slice(self):
|
||||
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
|
||||
|
||||
@ -361,9 +358,6 @@ def forward(self, b_parametrizations_buffer_original0, x):
|
||||
res = opt_fn(x)
|
||||
self.assertEqual(res, ref)
|
||||
|
||||
@unittest.skip(
|
||||
"DTensor + dynamic fails - s77 + 8 is not tracked with proxy .. proxy_tensor.PythonKeyTracer"
|
||||
)
|
||||
def test_dtensor_dynamic_cat(self):
|
||||
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
|
||||
|
||||
|
||||
@ -7,7 +7,7 @@ import warnings
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.testing._internal.common_methods_invocations as common_ops
|
||||
from torch.distributed.tensor import DTensor, init_device_mesh
|
||||
from torch.distributed.tensor import DeviceMesh, DTensor
|
||||
from torch.overrides import resolve_name
|
||||
from torch.testing._internal.common_device_type import (
|
||||
instantiate_device_type_tests,
|
||||
@ -159,6 +159,7 @@ dtensor_fails = {
|
||||
xfail("geometric"),
|
||||
xfail("geqrf"),
|
||||
xfail("grid_sampler_2d"),
|
||||
xfail("gradient"),
|
||||
xfail("heaviside"),
|
||||
xfail("histogram"),
|
||||
xfail("histogramdd"),
|
||||
@ -417,6 +418,8 @@ dtensor_fails = {
|
||||
xfail("tensor_split"),
|
||||
xfail("to_sparse"),
|
||||
xfail("trace"),
|
||||
xfail("trapezoid"),
|
||||
xfail("trapz"),
|
||||
xfail("triangular_solve"),
|
||||
xfail("unbind"),
|
||||
xfail("unbind_copy"),
|
||||
@ -508,7 +511,7 @@ class TestDTensorOps(DTensorOpTestBase):
|
||||
def run_opinfo_test(
|
||||
self, dtype, op, requires_grad=True, sample_inputs_filter=lambda s: True
|
||||
):
|
||||
self.mesh = init_device_mesh(DEVICE_TYPE, (self.world_size,))
|
||||
self.mesh = DeviceMesh(DEVICE_TYPE, torch.arange(self.world_size))
|
||||
|
||||
# test each op with dist tensor inputs and normal inputs
|
||||
def test():
|
||||
@ -633,7 +636,7 @@ class TestDTensorOps(DTensorOpTestBase):
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"{str(e)}\n\nfailed to run: {resolve_name(func)}, with (*{dtensor_args}, **{dtensor_kwargs})"
|
||||
f"failed to run: {resolve_name(func)}, with (*{dtensor_args}, **{dtensor_kwargs})"
|
||||
) from e
|
||||
return rs
|
||||
|
||||
|
||||
@ -43,7 +43,6 @@ from torch.testing._internal.common_utils import (
|
||||
retry_on_connect_failures,
|
||||
run_tests,
|
||||
TEST_WITH_DEV_DBG_ASAN,
|
||||
TEST_XPU,
|
||||
TestCase,
|
||||
)
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
@ -64,8 +63,6 @@ else:
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
|
||||
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
|
||||
|
||||
|
||||
def gpus_for_rank(world_size):
|
||||
"""Multigpu tests are designed to simulate the multi nodes with multi
|
||||
@ -73,9 +70,8 @@ def gpus_for_rank(world_size):
|
||||
On a single node, all visible GPUs are evenly
|
||||
divided to subsets, each process only uses a subset.
|
||||
"""
|
||||
device_count = torch.accelerator.device_count()
|
||||
visible_devices = list(range(device_count))
|
||||
gpus_per_process = device_count // world_size
|
||||
visible_devices = list(range(torch.cuda.device_count()))
|
||||
gpus_per_process = torch.cuda.device_count() // world_size
|
||||
gpus_for_rank = []
|
||||
for rank in range(world_size):
|
||||
gpus_for_rank.append(
|
||||
@ -405,7 +401,7 @@ class CommonDistributedDataParallelTest:
|
||||
gradient_as_bucket_view=gradient_as_bucket_view,
|
||||
)
|
||||
|
||||
input = torch.randn(global_batch_size, 2).to(devices[0])
|
||||
input = torch.randn(global_batch_size, 2).cuda(devices[0])
|
||||
target = torch.randn(global_batch_size, 4)
|
||||
|
||||
return model, ddp_model, input, target
|
||||
@ -439,10 +435,10 @@ class CommonDistributedDataParallelTest:
|
||||
allow_none_grads=False,
|
||||
):
|
||||
# to reproduce the same training results
|
||||
torch.accelerator.set_device_index(self.rank)
|
||||
torch.cuda.set_device(self.rank)
|
||||
torch.manual_seed(31415)
|
||||
model = copy.deepcopy(input_model).to(device_type)
|
||||
ddp_model = copy.deepcopy(input_model).to(device_type)
|
||||
model = copy.deepcopy(input_model).cuda()
|
||||
ddp_model = copy.deepcopy(input_model).cuda()
|
||||
ddp_model = nn.parallel.DistributedDataParallel(
|
||||
ddp_model,
|
||||
bucket_cap_mb=1,
|
||||
@ -558,8 +554,8 @@ class CommonDistributedDataParallelTest:
|
||||
def _prepare_dummy_data(self):
|
||||
ddp_bs = 16
|
||||
bs = ddp_bs * self.world_size
|
||||
input = torch.rand((bs, 20), device=device_type, requires_grad=True)
|
||||
target = torch.randn((bs, 20), device=device_type)
|
||||
input = torch.rand((bs, 20), device="cuda", requires_grad=True)
|
||||
target = torch.randn((bs, 20), device="cuda")
|
||||
offset = self.rank * ddp_bs
|
||||
ddp_input = input[offset : offset + ddp_bs]
|
||||
ddp_target = target[offset : offset + ddp_bs]
|
||||
@ -719,7 +715,7 @@ class CommonDistributedDataParallelTest:
|
||||
Test that checkpointing with weight sharing works.
|
||||
"""
|
||||
process_group = self._get_process_group()
|
||||
torch.accelerator.set_device_index(self.rank)
|
||||
torch.cuda.set_device(self.rank)
|
||||
for use_bucket_view, static_graph in product((False, True), (False, True)):
|
||||
torch.manual_seed(31415)
|
||||
l1 = nn.Linear(20, 20)
|
||||
@ -742,7 +738,7 @@ class CommonDistributedDataParallelTest:
|
||||
same layer twice and having weights shared across layers.
|
||||
"""
|
||||
process_group = self._get_process_group()
|
||||
torch.accelerator.set_device_index(self.rank)
|
||||
torch.cuda.set_device(self.rank)
|
||||
for use_bucket_view in (True, False):
|
||||
self._test_ddp_checkpointing(
|
||||
self.CheckpointTwiceModuleWeightSharing(),
|
||||
@ -1166,7 +1162,7 @@ class AbstractCommTest:
|
||||
|
||||
# Verify sequence numbers are appropriately incremented
|
||||
for i in range(10):
|
||||
t = torch.ones(1, device=device_type)
|
||||
t = torch.ones(1, device=torch.cuda.current_device())
|
||||
dist.all_reduce(t, group=process_group)
|
||||
if not c10d._rank_not_in_group(process_group):
|
||||
seq_num = self._verify_sequence_number_across_pg(
|
||||
@ -1197,7 +1193,7 @@ class AbstractCommTest:
|
||||
self.assertEqual(rank_to_seq_num[0] + 1, rank_to_seq_num[1])
|
||||
|
||||
def _test_sequence_num_incremented_default_group(self, backend_name):
|
||||
torch.accelerator.set_device_index(self.rank)
|
||||
torch.cuda.set_device(self.rank)
|
||||
store = dist.FileStore(self.file_name, self.world_size)
|
||||
dist.init_process_group(
|
||||
backend_name,
|
||||
@ -1211,7 +1207,7 @@ class AbstractCommTest:
|
||||
)
|
||||
|
||||
def _test_sequence_num_incremented_subgroup(self, backend_name):
|
||||
torch.accelerator.set_device_index(self.rank)
|
||||
torch.cuda.set_device(self.rank)
|
||||
store = dist.FileStore(self.file_name, self.world_size)
|
||||
dist.init_process_group(
|
||||
backend_name,
|
||||
@ -1266,8 +1262,8 @@ class AbstractCommTest:
|
||||
in_group_ranks = list(filter(lambda x: x % 2 == 0, range(self.world_size)))
|
||||
group = dist.new_group(in_group_ranks)
|
||||
|
||||
x = torch.zeros(2, 2).to(self.rank)
|
||||
xs = [torch.zeros(2, 2).to(self.rank) for _ in range(len(in_group_ranks))]
|
||||
x = torch.zeros(2, 2).cuda(self.rank)
|
||||
xs = [torch.zeros(2, 2).cuda(self.rank) for _ in range(len(in_group_ranks))]
|
||||
if self.rank not in in_group_ranks:
|
||||
msg = ".*{}.*does not belong to.*"
|
||||
with self.assertWarnsOnceRegex(UserWarning, msg.format("all_gather")):
|
||||
@ -1396,7 +1392,7 @@ class AbstractCommTest:
|
||||
rank=self.rank,
|
||||
store=store,
|
||||
)
|
||||
device = "cuda" if backend == "nccl" else "xpu" if backend == "xccl" else "cpu"
|
||||
device = "cuda" if backend == "nccl" else "cpu"
|
||||
# test alltoall_base
|
||||
tensor = torch.tensor([1, 0, 0, 1], dtype=torch.bool, device=device)
|
||||
zeros = torch.tensor([0, 0, 0, 0], dtype=torch.bool, device=device)
|
||||
@ -1578,8 +1574,8 @@ class CommTest(AbstractCommTest, MultiProcessTestCase):
|
||||
|
||||
class DummyWork(dist._Work):
|
||||
def wait(self, timeout=5.0):
|
||||
if torch.accelerator.is_available():
|
||||
torch.accelerator.current_stream().synchronize()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.current_stream().synchronize()
|
||||
return True
|
||||
|
||||
|
||||
@ -1794,18 +1790,6 @@ class PythonProcessGroupExtensionTest(MultiProcessTestCase):
|
||||
("cpu:gloo,cuda:nccl", "cpu:gloo,cuda:nccl"),
|
||||
]
|
||||
|
||||
if TEST_XPU:
|
||||
# Override backend_config_strings_and_expected_values for Intel GPU.
|
||||
backend_config_strings_and_expected_values[4:10] = [
|
||||
(dist.Backend.DUMMY, "cpu:dummy,cuda:dummy,xpu:dummy"),
|
||||
("DUMMY", "cpu:dummy,cuda:dummy,xpu:dummy"),
|
||||
("dummy", "cpu:dummy,cuda:dummy,xpu:dummy"),
|
||||
("cpu:dummy,xpu:dummy", "cpu:dummy,xpu:dummy"),
|
||||
("cpu:dummy,xpu:xccl", "cpu:dummy,xpu:xccl"),
|
||||
("cpu:gloo,xpu:dummy", "cpu:gloo,xpu:dummy"),
|
||||
("cpu:gloo,xpu:xccl", "cpu:gloo,xpu:xccl"),
|
||||
]
|
||||
|
||||
for config_str, expected_value in backend_config_strings_and_expected_values:
|
||||
with self.subTest(config_str):
|
||||
# ensures these configs strings are valid and no ValueError is raised
|
||||
@ -1816,8 +1800,6 @@ class PythonProcessGroupExtensionTest(MultiProcessTestCase):
|
||||
invalid_backend_config_strings = [
|
||||
"cpu:gloo,cuda:nccl,", # trailing comma
|
||||
"cpu:gloo,cuda:nccl,cpu:dummy", # duplicate device
|
||||
"cpu:gloo,xpu:xccl,", # trailing comma
|
||||
"cpu:gloo,xpu:xccl,cpu:dummy", # duplicate device
|
||||
]
|
||||
for config_str in invalid_backend_config_strings:
|
||||
with self.subTest(config_str):
|
||||
@ -1832,7 +1814,7 @@ class PythonProcessGroupExtensionTest(MultiProcessTestCase):
|
||||
os.environ["MASTER_ADDR"] = "localhost"
|
||||
os.environ["MASTER_PORT"] = "6789"
|
||||
dist.init_process_group(
|
||||
"cpu:dummy,cuda:dummy,xpu:dummy", rank=self.rank, world_size=self.world_size
|
||||
"cpu:dummy,cuda:dummy", rank=self.rank, world_size=self.world_size
|
||||
)
|
||||
|
||||
# test all_gather
|
||||
@ -2071,7 +2053,7 @@ dist.init_process_group(rank=0, world_size=1, store=dist.HashStore())
|
||||
# correctly dispatched
|
||||
|
||||
# TODO: this will be updated in the future to not be backend specific
|
||||
device = "cuda" if backend == "nccl" else "xpu" if backend == "xccl" else "cpu"
|
||||
device = "cuda" if backend == "nccl" else "cpu"
|
||||
# ensure supported devices (cpu, cuda) succeeds during dispatch call
|
||||
tensor = torch.zeros(2, 2, device=torch.device(device))
|
||||
# multi tensor collectives
|
||||
@ -2137,7 +2119,7 @@ dist.init_process_group(rank=0, world_size=1, store=dist.HashStore())
|
||||
rank=self.rank,
|
||||
store=store,
|
||||
)
|
||||
device = "cuda" if backend == "nccl" else "xpu" if backend == "xccl" else "cpu"
|
||||
device = "cuda" if backend == "nccl" else "cpu"
|
||||
# test alltoall_base
|
||||
input_tensor = torch.ones(2, 2, device=torch.device(device))
|
||||
output_tensor = torch.zeros(2, 2, device=torch.device(device))
|
||||
@ -2269,9 +2251,8 @@ class LocalRankTest(MultiProcessTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if device_type != "cpu":
|
||||
assert not torch.get_device_module()._initialized, (
|
||||
"test_distributed must not have initialized {device_type} context on main process"
|
||||
)
|
||||
assert not torch.cuda._initialized, (
|
||||
"test_distributed must not have initialized CUDA context on main process"
|
||||
)
|
||||
|
||||
run_tests()
|
||||
|
||||
@ -21,15 +21,15 @@ from torch.distributed._functional_collectives import (
|
||||
reduce_scatter_tensor,
|
||||
reduce_scatter_tensor_coalesced,
|
||||
)
|
||||
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8
|
||||
from torch.testing._internal.common_device_type import e4m3_type
|
||||
from torch.testing._internal.common_cuda import SM90OrLater
|
||||
from torch.testing._internal.common_distributed import (
|
||||
MultiProcessTestCase,
|
||||
requires_accelerator_dist_backend,
|
||||
requires_nccl,
|
||||
skip_if_lt_x_gpu,
|
||||
)
|
||||
from torch.testing._internal.common_utils import ( # type: ignore[attr-defined]
|
||||
run_tests,
|
||||
skipIfRocm,
|
||||
TestCase,
|
||||
)
|
||||
from torch.testing._internal.distributed.fake_pg import FakeStore
|
||||
@ -59,7 +59,7 @@ if not dist.is_available():
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
@requires_accelerator_dist_backend(["nccl", "xccl"])
|
||||
@requires_nccl()
|
||||
class TestWithNCCL(MultiProcessTestCase):
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
@ -75,15 +75,13 @@ class TestWithNCCL(MultiProcessTestCase):
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return torch.device(self.rank)
|
||||
return torch.device(f"cuda:{self.rank}")
|
||||
|
||||
def _init_process_group(self) -> None:
|
||||
torch.accelerator.set_device_index(self.rank)
|
||||
torch.cuda.set_device(self.device)
|
||||
store = dist.FileStore(self.file_name, self.world_size)
|
||||
backend = dist.get_default_backend_for_device(self.device.type)
|
||||
|
||||
dist.init_process_group(
|
||||
backend=backend,
|
||||
backend="nccl",
|
||||
world_size=self.world_size,
|
||||
rank=self.rank,
|
||||
store=store,
|
||||
@ -275,7 +273,7 @@ class TestWithNCCL(MultiProcessTestCase):
|
||||
)
|
||||
# check memory leak
|
||||
for i in range(1, 10):
|
||||
mem_usage[i] = torch.accelerator.max_memory_allocated()
|
||||
mem_usage[i] = torch.cuda.max_memory_allocated()
|
||||
compiled(arg)
|
||||
|
||||
assert mem_usage[9] == mem_usage[8]
|
||||
@ -372,16 +370,14 @@ class TestWithNCCL(MultiProcessTestCase):
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_all_to_all_single(self) -> None:
|
||||
self._init_process_group()
|
||||
torch.accelerator.set_device_index(self.rank)
|
||||
torch.cuda.set_device(self.device)
|
||||
|
||||
torch.manual_seed(42)
|
||||
send_sz_matrix = torch.randint(0, 20, (self.world_size, self.world_size))
|
||||
|
||||
input_split_sizes = send_sz_matrix[self.rank].tolist()
|
||||
output_split_sizes = send_sz_matrix[:, self.rank].tolist()
|
||||
input = torch.full((sum(input_split_sizes),), float(self.rank)).to(
|
||||
self.device.type
|
||||
)
|
||||
input = torch.full((sum(input_split_sizes),), float(self.rank)).cuda()
|
||||
|
||||
output = torch.ops._c10d_functional.all_to_all_single(
|
||||
input,
|
||||
@ -392,7 +388,7 @@ class TestWithNCCL(MultiProcessTestCase):
|
||||
output = torch.ops._c10d_functional.wait_tensor(output)
|
||||
expect = torch.cat(
|
||||
[
|
||||
torch.full((sz,), float(rank)).to(self.device.type)
|
||||
torch.full((sz,), float(rank)).cuda()
|
||||
for rank, sz in enumerate(output_split_sizes)
|
||||
]
|
||||
)
|
||||
@ -468,7 +464,7 @@ class TestWithNCCL(MultiProcessTestCase):
|
||||
@fresh_cache()
|
||||
def test_threading(self):
|
||||
self._init_process_group()
|
||||
device = self.device
|
||||
device = torch.device(f"cuda:{self.rank}")
|
||||
|
||||
def func(arg: torch.Tensor) -> torch.Tensor:
|
||||
buf0 = arg + 42
|
||||
@ -505,9 +501,10 @@ class TestWithNCCL(MultiProcessTestCase):
|
||||
t.start()
|
||||
t.join()
|
||||
|
||||
@skipIfRocm
|
||||
@unittest.skipIf(
|
||||
not PLATFORM_SUPPORTS_FP8,
|
||||
"_scaled_mm currently only supports sm>=90 on cuda and gfx94/95 on ROCm",
|
||||
not SM90OrLater,
|
||||
"_scaled_mm currently only supports sm>=90",
|
||||
)
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@fresh_cache()
|
||||
@ -516,9 +513,10 @@ class TestWithNCCL(MultiProcessTestCase):
|
||||
|
||||
def scale(t):
|
||||
scale = (
|
||||
torch.finfo(e4m3_type).max / t.abs().amax(dim=-1, keepdim=True).float()
|
||||
torch.finfo(torch.float8_e4m3fn).max
|
||||
/ t.abs().amax(dim=-1, keepdim=True).float()
|
||||
)
|
||||
t = t.mul(scale).to(e4m3_type)
|
||||
t = t.mul(scale).to(torch.float8_e4m3fn)
|
||||
return t, scale
|
||||
|
||||
def fp8_rowwise_backward(in_, w, out_grad):
|
||||
@ -548,9 +546,9 @@ class TestWithNCCL(MultiProcessTestCase):
|
||||
return in_grad, w_grad
|
||||
|
||||
m, n, k = 128, 256, 64
|
||||
in_ = torch.randn((m, k), device=self.device.type, dtype=torch.bfloat16)
|
||||
w = torch.randn((n, k), device=self.device.type, dtype=torch.bfloat16)
|
||||
out_grad = torch.randn((m, n), device=self.device.type, dtype=torch.bfloat16)
|
||||
in_ = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
|
||||
w = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)
|
||||
out_grad = torch.randn((m, n), device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
eager_in_grad, eager_w_grad = fp8_rowwise_backward(in_, w, out_grad)
|
||||
compile_in_grad, compile_w_grad = torch.compile(fp8_rowwise_backward)(
|
||||
@ -779,8 +777,7 @@ class CompileTest(TestCase):
|
||||
|
||||
self.rank = 0
|
||||
self.world_size = 2
|
||||
torch.accelerator.set_device_index(0)
|
||||
self.device = torch.accelerator.current_accelerator()
|
||||
torch.cuda.set_device("cuda:0")
|
||||
|
||||
store = FakeStore()
|
||||
dist.init_process_group(
|
||||
@ -806,7 +803,7 @@ class CompileTest(TestCase):
|
||||
ar1 = funcol.wait_tensor(ar1)
|
||||
return ar0, ar1
|
||||
|
||||
arg = torch.rand(4, 4, device=self.device)
|
||||
arg = torch.rand(4, 4, device="cuda")
|
||||
compiled = torch.compile(func)
|
||||
|
||||
code = run_and_get_triton_code(compiled, arg)
|
||||
@ -839,7 +836,7 @@ class CompileTest(TestCase):
|
||||
|
||||
# Test aoti
|
||||
AOTIRunnerUtil.run(func, (arg,))
|
||||
torch.accelerator.synchronize()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@fresh_cache()
|
||||
@ -854,7 +851,7 @@ class CompileTest(TestCase):
|
||||
ar1 = [funcol.wait_tensor(out) for out in ar1]
|
||||
return ar0, ar1
|
||||
|
||||
args = [torch.rand(4, 4, device=self.device.type) for _ in range(2)]
|
||||
args = [torch.rand(4, 4, device="cuda") for _ in range(2)]
|
||||
compiled = torch.compile(func)
|
||||
code = run_and_get_triton_code(compiled, args)
|
||||
buf0, buf1, buf2, buf3 = find_buffer_assignments(code)
|
||||
@ -884,7 +881,7 @@ class CompileTest(TestCase):
|
||||
|
||||
# Test aoti
|
||||
out = AOTIRunnerUtil.run(func, (args,)) # noqa: F841
|
||||
torch.accelerator.synchronize()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@fresh_cache()
|
||||
@ -895,7 +892,7 @@ class CompileTest(TestCase):
|
||||
ar0 = funcol.wait_tensor(ar0)
|
||||
return ar0
|
||||
|
||||
arg = torch.rand(4, 4, device=self.device.type)
|
||||
arg = torch.rand(4, 4, device="cuda")
|
||||
compiled = torch.compile(func)
|
||||
|
||||
code = run_and_get_triton_code(compiled, arg)
|
||||
@ -920,7 +917,7 @@ class CompileTest(TestCase):
|
||||
# Expect allocation
|
||||
return ar0
|
||||
|
||||
arg = torch.rand(4, 4, device=self.device.type).T
|
||||
arg = torch.rand(4, 4, device="cuda").T
|
||||
compiled = torch.compile(func)
|
||||
|
||||
code = run_and_get_triton_code(compiled, arg)
|
||||
@ -951,7 +948,7 @@ class CompileTest(TestCase):
|
||||
buf2 = torch.mm(arg, buf1)
|
||||
return buf1, buf2
|
||||
|
||||
arg = torch.rand(4, 4, device=self.device.type)
|
||||
arg = torch.rand(4, 4, device="cuda")
|
||||
compiled = torch.compile(func)
|
||||
code = run_and_get_triton_code(compiled, arg)
|
||||
buf0, buf1 = find_buffer_assignments(code)
|
||||
@ -981,7 +978,7 @@ class CompileTest(TestCase):
|
||||
ag0 = funcol.wait_tensor(ag0)
|
||||
return ag0
|
||||
|
||||
arg = torch.rand(4, 4, device=self.device.type)
|
||||
arg = torch.rand(4, 4, device="cuda")
|
||||
compiled = torch.compile(func)
|
||||
code = run_and_get_triton_code(compiled, arg)
|
||||
(
|
||||
@ -998,7 +995,7 @@ class CompileTest(TestCase):
|
||||
|
||||
# Test aoti
|
||||
AOTIRunnerUtil.run(func, (arg,))
|
||||
torch.accelerator.synchronize()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@fresh_cache()
|
||||
@ -1008,7 +1005,7 @@ class CompileTest(TestCase):
|
||||
ag0 = [funcol.wait_tensor(out) for out in ag0]
|
||||
return ag0
|
||||
|
||||
args = [torch.rand(4, 4, device=self.device.type) for _ in range(4)]
|
||||
args = [torch.rand(4, 4, device="cuda") for _ in range(4)]
|
||||
compiled = torch.compile(func)
|
||||
code = run_and_get_triton_code(compiled, args)
|
||||
(
|
||||
@ -1032,7 +1029,7 @@ class CompileTest(TestCase):
|
||||
|
||||
# Test aoti
|
||||
out = AOTIRunnerUtil.run(func, (args,)) # noqa: F841
|
||||
torch.accelerator.synchronize()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "This is a GPU test!")
|
||||
@fresh_cache()
|
||||
@ -1042,7 +1039,7 @@ class CompileTest(TestCase):
|
||||
return funcol.wait_tensor(t)
|
||||
|
||||
# Test aoti
|
||||
arg = torch.rand(4, 4, device=self.device.type)
|
||||
arg = torch.rand(4, 4, device="cuda")
|
||||
compiled = torch.compile(func)
|
||||
code = run_and_get_triton_code(compiled, arg)
|
||||
(
|
||||
@ -1054,7 +1051,7 @@ class CompileTest(TestCase):
|
||||
|
||||
# Test aoti
|
||||
AOTIRunnerUtil.run(func, (arg,))
|
||||
torch.accelerator.synchronize()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@fresh_cache()
|
||||
@ -1064,7 +1061,7 @@ class CompileTest(TestCase):
|
||||
rs0 = funcol.wait_tensor(rs0)
|
||||
return rs0
|
||||
|
||||
arg = torch.rand(4, 4, device=self.device.type)
|
||||
arg = torch.rand(4, 4, device="cuda")
|
||||
compiled = torch.compile(func)
|
||||
code = run_and_get_triton_code(compiled, arg)
|
||||
(
|
||||
@ -1080,7 +1077,7 @@ class CompileTest(TestCase):
|
||||
|
||||
# Test aoti
|
||||
AOTIRunnerUtil.run(func, (arg,))
|
||||
torch.accelerator.synchronize()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@fresh_cache()
|
||||
@ -1092,7 +1089,7 @@ class CompileTest(TestCase):
|
||||
rs0 = [funcol.wait_tensor(out) for out in rs0]
|
||||
return rs0
|
||||
|
||||
args = [torch.rand(4, 4, device=self.device.type) for _ in range(4)]
|
||||
args = [torch.rand(4, 4, device="cuda") for _ in range(4)]
|
||||
compiled = torch.compile(func)
|
||||
code = run_and_get_triton_code(compiled, args)
|
||||
(
|
||||
@ -1116,7 +1113,7 @@ class CompileTest(TestCase):
|
||||
|
||||
# Test aoti
|
||||
AOTIRunnerUtil.run(func, (args,))
|
||||
torch.accelerator.synchronize()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@fresh_cache()
|
||||
@ -1145,9 +1142,7 @@ class CompileTest(TestCase):
|
||||
|
||||
input_split_sizes = send_sz_matrix[self.rank]
|
||||
output_split_sizes = send_sz_matrix[:, self.rank].contiguous()
|
||||
input = torch.full((input_split_sizes.sum().item(),), float(self.rank)).to(
|
||||
self.device.type
|
||||
)
|
||||
input = torch.full((input_split_sizes.sum().item(),), float(self.rank)).cuda()
|
||||
|
||||
with torch._dynamo.config.patch(
|
||||
dynamic_shapes=True,
|
||||
@ -1181,7 +1176,7 @@ class CompileTest(TestCase):
|
||||
br1 = funcol.wait_tensor(br1)
|
||||
return br0, br1
|
||||
|
||||
arg = torch.rand(4, 4, device=self.device.type)
|
||||
arg = torch.rand(4, 4, device="cuda")
|
||||
compiled = torch.compile(func)
|
||||
|
||||
code = run_and_get_triton_code(compiled, arg)
|
||||
@ -1204,7 +1199,7 @@ class CompileTest(TestCase):
|
||||
|
||||
# Test aoti
|
||||
AOTIRunnerUtil.run(func, (arg,))
|
||||
torch.accelerator.synchronize()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@fresh_cache()
|
||||
@ -1219,7 +1214,7 @@ class CompileTest(TestCase):
|
||||
ar1 = funcol.wait_tensor(ar1)
|
||||
return ar0, ar1
|
||||
|
||||
arg = torch.rand(4, 4, device=self.device.type)
|
||||
arg = torch.rand(4, 4, device="cuda")
|
||||
compiled = torch.compile(func, fullgraph=True)
|
||||
|
||||
code = run_and_get_triton_code(compiled, arg)
|
||||
|
||||
@ -3145,24 +3145,19 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
|
||||
class NcclUserBufferRegistrationTest(MultiProcessTestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
# TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING hence tests
|
||||
# that use TORCH_NCCL_BLOCKING_WAIT will test it as expected.
|
||||
os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
|
||||
nccl_debug_file = tempfile.NamedTemporaryFile()
|
||||
nccl_env = {
|
||||
# TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING hence tests
|
||||
# that use TORCH_NCCL_BLOCKING_WAIT will test it as expected.
|
||||
"TORCH_NCCL_ASYNC_ERROR_HANDLING": "1",
|
||||
"NCCL_ALGO": "NVLS",
|
||||
"NCCL_DEBUG": "INFO",
|
||||
"NCCL_DEBUG_SUBSYS": "NVLS",
|
||||
"NCCL_DEBUG_FILE": nccl_debug_file.name,
|
||||
}
|
||||
os.environ["NCCL_ALGO"] = "NVLS"
|
||||
os.environ["NCCL_DEBUG"] = "INFO"
|
||||
os.environ["NCCL_DEBUG_SUBSYS"] = "NVLS"
|
||||
if torch.cuda.nccl.version() >= (2, 24, 3):
|
||||
nccl_env["NCCL_DEBUG_SUBSYS"] = "REG,TUNING"
|
||||
self.env_patcher = mock.patch.dict(os.environ, nccl_env)
|
||||
self.env_patcher.start()
|
||||
os.environ["NCCL_DEBUG_SUBSYS"] = "REG,TUNING"
|
||||
os.environ["NCCL_DEBUG_FILE"] = nccl_debug_file.name
|
||||
self._spawn_processes()
|
||||
|
||||
def tearDown(self):
|
||||
self.env_patcher.stop()
|
||||
super().tearDown()
|
||||
try:
|
||||
os.remove(self.file_name)
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
# Owner(s): ["oncall: distributed"]
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@ -27,7 +26,7 @@ from torch.distributed.tensor._collective_utils import (
|
||||
)
|
||||
from torch.distributed.tensor.placement_types import _Partial, Shard
|
||||
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
|
||||
from torch.testing._internal.common_utils import run_tests, TEST_XPU
|
||||
from torch.testing._internal.common_utils import run_tests
|
||||
from torch.testing._internal.distributed._tensor.common_dtensor import (
|
||||
DTensorTestBase,
|
||||
with_comms,
|
||||
@ -36,10 +35,6 @@ from torch.testing._internal.distributed.fake_pg import FakeProcessGroup, FakeSt
|
||||
from torch.utils._typing_utils import not_none
|
||||
|
||||
|
||||
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
|
||||
device_count = torch.accelerator.device_count()
|
||||
|
||||
|
||||
def _set_env_var(addr="localhost", port="25364", world_size=1, rank=0, local_rank=-1):
|
||||
os.environ["MASTER_ADDR"] = addr
|
||||
os.environ["MASTER_PORT"] = port
|
||||
@ -49,7 +44,6 @@ def _set_env_var(addr="localhost", port="25364", world_size=1, rank=0, local_ran
|
||||
os.environ["LOCAL_RANK"] = f"{local_rank}"
|
||||
|
||||
|
||||
@unittest.skipIf(TEST_XPU, "XPU does not support gloo backend.")
|
||||
class DeviceMeshTestGlooBackend(DTensorTestBase):
|
||||
@property
|
||||
def backend(self):
|
||||
@ -79,16 +73,14 @@ class DeviceMeshSetDeviceTest(DTensorTestBase):
|
||||
|
||||
# Set the device on each process before DeviceMesh constructor,
|
||||
# and device to be different than the default world rank
|
||||
torch.accelerator.set_device_index((self.rank + 2) % self.world_size)
|
||||
torch.cuda.set_device((self.rank + 2) % self.world_size)
|
||||
_set_env_var(world_size=self.world_size, rank=self.rank)
|
||||
DeviceMesh(self.device_type, mesh_tensor)
|
||||
self.assertTrue(is_initialized())
|
||||
|
||||
# check that the device is set to the correct device
|
||||
# and respect the previous set_device calls
|
||||
self.assertEqual(
|
||||
torch.accelerator.current_device_idx(), (self.rank + 2) % self.world_size
|
||||
)
|
||||
self.assertEqual(torch.cuda.current_device(), (self.rank + 2) % self.world_size)
|
||||
self.destroy_pg()
|
||||
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@ -109,7 +101,7 @@ class DeviceMeshSetDeviceTest(DTensorTestBase):
|
||||
|
||||
# check that the device is set to the correct device
|
||||
# and respect the LOCAL_RANK env var
|
||||
self.assertEqual(torch.accelerator.current_device_idx(), local_rank)
|
||||
self.assertEqual(torch.cuda.current_device(), local_rank)
|
||||
self.destroy_pg()
|
||||
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@ -128,7 +120,7 @@ class DeviceMeshSetDeviceTest(DTensorTestBase):
|
||||
self.assertTrue(is_initialized())
|
||||
|
||||
# check that the device is set to the correct device
|
||||
self.assertEqual(torch.accelerator.current_device_idx(), self.rank)
|
||||
self.assertEqual(torch.cuda.current_device(), self.rank)
|
||||
self.destroy_pg()
|
||||
|
||||
|
||||
@ -230,7 +222,7 @@ class DeviceMeshTest(DTensorTestBase):
|
||||
@with_comms
|
||||
def test_device_mesh_2d(self):
|
||||
mesh_tensor = torch.arange(4).reshape(2, 2)
|
||||
# construct a device mesh for self.device_type
|
||||
# construct a cuda device mesh
|
||||
mesh = DeviceMesh(self.device_type, mesh_tensor)
|
||||
|
||||
# check all dim groups
|
||||
@ -268,11 +260,7 @@ class DeviceMeshTest(DTensorTestBase):
|
||||
def test_fake_pg_device_mesh(self):
|
||||
fake_store = FakeStore()
|
||||
init_process_group("fake", store=fake_store, rank=0, world_size=self.world_size)
|
||||
device_type = (
|
||||
torch.accelerator.current_accelerator().type
|
||||
if torch.accelerator.is_available()
|
||||
else "cpu"
|
||||
)
|
||||
device_type = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
mesh = DeviceMesh(device_type, torch.arange(self.world_size))
|
||||
|
||||
local_tensor = torch.randn(2, 8)
|
||||
@ -312,7 +300,7 @@ class DeviceMeshTest(DTensorTestBase):
|
||||
regex = r"Invalid mesh \[\[0, 1\], \[2, 3\]\] for ProcessGroup with ranks \[0, 1, 2, 3\]"
|
||||
with self.assertRaisesRegex(ValueError, regex):
|
||||
DeviceMesh.from_group(
|
||||
global_pg, device_type, invalid_mesh, mesh_dim_names=("dim0", "dim1")
|
||||
global_pg, "cuda", invalid_mesh, mesh_dim_names=("dim0", "dim1")
|
||||
)
|
||||
|
||||
device_mesh = init_device_mesh(self.device_type, (2, 2))
|
||||
@ -332,16 +320,12 @@ class DeviceMeshTest(DTensorTestBase):
|
||||
# test init_device_mesh with an invalid device type that contains a GPU index
|
||||
mesh_shape = (2, self.world_size // 2)
|
||||
init_device_mesh(
|
||||
f"{device_type}:0", mesh_shape=mesh_shape, mesh_dim_names=("dp", "tp")
|
||||
"cuda:0", mesh_shape=mesh_shape, mesh_dim_names=("dp", "tp")
|
||||
)
|
||||
|
||||
@with_comms
|
||||
def test_set_mesh_dim_group_options(self):
|
||||
device_type = (
|
||||
torch.accelerator.current_accelerator().type
|
||||
if torch.accelerator.is_available()
|
||||
else "cpu"
|
||||
)
|
||||
device_type = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
_mesh_resources._set_mesh_dim_group_options(1, "fake", None)
|
||||
|
||||
mesh_tensor = torch.arange(4).reshape(2, 2)
|
||||
@ -357,7 +341,7 @@ class DeviceMeshTestNDim(DTensorTestBase):
|
||||
|
||||
@with_comms
|
||||
def test_device_mesh_nd(self):
|
||||
# construct a device mesh for self.device_type
|
||||
# construct a cuda device mesh
|
||||
mesh_tensor = torch.arange(8).reshape(2, 2, 2)
|
||||
mesh = DeviceMesh(self.device_type, mesh_tensor)
|
||||
|
||||
@ -726,9 +710,7 @@ class TestDeviceMeshGetItem(DTensorTestBase):
|
||||
with self.assertRaisesRegex(KeyError, "Invalid mesh_dim_name"):
|
||||
mesh_dim_names = ("DP", "TP")
|
||||
mesh = init_device_mesh(
|
||||
self.device_type,
|
||||
(2, 4),
|
||||
mesh_dim_names=mesh_dim_names,
|
||||
self.device_type, (2, 4), mesh_dim_names=mesh_dim_names
|
||||
)
|
||||
mesh[child_mesh_dim_name]
|
||||
|
||||
@ -956,9 +938,7 @@ class TestMeshEnv(DTensorTestBase):
|
||||
@with_comms
|
||||
def test_get_root_mesh(self):
|
||||
mesh_3d = init_device_mesh(
|
||||
self.device_type,
|
||||
(2, 2, 2),
|
||||
mesh_dim_names=("dp", "cp", "tp"),
|
||||
self.device_type, (2, 2, 2), mesh_dim_names=("dp", "cp", "tp")
|
||||
)
|
||||
|
||||
dp_cp_mesh = mesh_3d["dp", "cp"]
|
||||
@ -1006,9 +986,7 @@ class TestMeshEnv(DTensorTestBase):
|
||||
@with_comms
|
||||
def test_get_all_submeshes(self):
|
||||
mesh_2d = init_device_mesh(
|
||||
self.device_type,
|
||||
(2, 4),
|
||||
mesh_dim_names=("replicate", "shard"),
|
||||
self.device_type, (2, 4), mesh_dim_names=("replicate", "shard")
|
||||
)
|
||||
all_submeshes = _mesh_resources._get_all_submeshes(mesh_2d, "replicate")
|
||||
self.assertEqual(len(all_submeshes), 4)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user