Compare commits


1 Commit

Author SHA1 Message Date
46eaea0232 update test_quantization tests to run weekly 2025-09-23 09:02:34 -07:00
280 changed files with 2282 additions and 6699 deletions

View File

@ -0,0 +1,40 @@
#!/bin/bash
# This is where the local pytorch install in the docker image is located
pt_checkout="/var/lib/jenkins/workspace"
source "$pt_checkout/.ci/pytorch/common_utils.sh"
echo "functorch_doc_push_script.sh: Invoked with $*"
set -ex -o pipefail
version=${DOCS_VERSION:-nightly}
echo "version: $version"
# Build functorch docs
pushd $pt_checkout/functorch/docs
make html
popd
git clone https://github.com/pytorch/functorch -b gh-pages --depth 1 functorch_ghpages
pushd functorch_ghpages
if [ "$version" == "main" ]; then
version=nightly
fi
git rm -rf "$version" || true
mv "$pt_checkout/functorch/docs/build/html" "$version"
git add "$version" || true
git status
git config user.email "soumith+bot@pytorch.org"
git config user.name "pytorchbot"
# If there aren't changes, don't make a commit; push is no-op
git commit -m "Generate Python docs from pytorch/pytorch@${GITHUB_SHA}" || true
git status
if [[ "${WITH_PUSH:-}" == true ]]; then
git push -u origin gh-pages
fi
popd

View File

@ -59,7 +59,7 @@ test_python_shard() {
setup_test_python
time python test/run_test.py --verbose --exclude-jit-executor --exclude-distributed-tests --shard "$1" "$NUM_TEST_SHARDS"
time python test/run_test.py --verbose --exclude-jit-executor --exclude-distributed-tests --exclude-quantization-tests --shard "$1" "$NUM_TEST_SHARDS"
assert_git_not_dirty
}

View File

@ -1,25 +0,0 @@
From 6e08c9d08e9de59c7af28b720289debbbd384764 Mon Sep 17 00:00:00 2001
From: Michael Wang <13521008+isVoid@users.noreply.github.com>
Date: Tue, 1 Apr 2025 17:28:05 -0700
Subject: [PATCH] Avoid bumping certain driver API to avoid future breakage
(#185)
Co-authored-by: isVoid <isVoid@users.noreply.github.com>
---
numba_cuda/numba/cuda/cudadrv/driver.py | 3 +++
1 file changed, 3 insertions(+)
diff --git a/numba_cuda/numba/cuda/cudadrv/driver.py b/numba_cuda/numba/cuda/cudadrv/driver.py
index 1641bf77..233e9ed7 100644
--- a/numba_cuda/numba/cuda/cudadrv/driver.py
+++ b/numba_cuda/numba/cuda/cudadrv/driver.py
@@ -365,6 +365,9 @@ def _find_api(self, fname):
else:
variants = ('_v2', '')
+ if fname in ("cuCtxGetDevice", "cuCtxSynchronize"):
+ return getattr(self.lib, fname)
+
for variant in variants:
try:
return getattr(self.lib, f'{fname}{variant}')

View File

@ -32,16 +32,6 @@ if [[ "$BUILD_ENVIRONMENT" != *rocm* && "$BUILD_ENVIRONMENT" != *s390x* && -d /v
git config --global --add safe.directory /var/lib/jenkins/workspace
fi
# Patch numba to avoid CUDA-13 crash, see https://github.com/pytorch/pytorch/issues/162878
NUMBA_CUDA_DIR=$(python -c "import os;import numba.cuda; print(os.path.dirname(numba.cuda.__file__))" 2>/dev/null || true)
if [ -n "$NUMBA_CUDA_DIR" ]; then
NUMBA_PATCH="$(dirname "$(realpath "${BASH_SOURCE[0]}")")/numba-cuda-13.patch"
pushd "$NUMBA_CUDA_DIR"
patch -p4 <"$NUMBA_PATCH"
popd
fi
echo "Environment variables:"
env
@ -322,14 +312,14 @@ test_python_shard() {
# modify LD_LIBRARY_PATH to ensure it has the conda env.
# This set of tests has been shown to be buggy without it for the split-build
time python test/run_test.py --exclude-jit-executor --exclude-distributed-tests $INCLUDE_CLAUSE --shard "$1" "$NUM_TEST_SHARDS" --verbose $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
time python test/run_test.py --exclude-jit-executor --exclude-distributed-tests --exclude-quantization-tests $INCLUDE_CLAUSE --shard "$1" "$NUM_TEST_SHARDS" --verbose $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
assert_git_not_dirty
}
test_python() {
# shellcheck disable=SC2086
time python test/run_test.py --exclude-jit-executor --exclude-distributed-tests $INCLUDE_CLAUSE --verbose $PYTHON_TEST_EXTRA_OPTION
time python test/run_test.py --exclude-jit-executor --exclude-distributed-tests --exclude-quantization-tests $INCLUDE_CLAUSE --verbose $PYTHON_TEST_EXTRA_OPTION
assert_git_not_dirty
}
@ -384,6 +374,7 @@ test_dynamo_wrapped_shard() {
--exclude-distributed-tests \
--exclude-torch-export-tests \
--exclude-aot-dispatch-tests \
--exclude-quantization-tests \
--shard "$1" "$NUM_TEST_SHARDS" \
--verbose \
--upload-artifacts-while-running
@ -1156,6 +1147,12 @@ test_distributed() {
fi
}
test_quantization() {
echo "Testing quantization"
python test/test_quantization.py
}
test_rpc() {
echo "Testing RPC C++ tests"
# NB: the ending test_rpc must match the current function name for the current
@ -1582,7 +1579,6 @@ test_linux_aarch64() {
python test/run_test.py --include test_modules test_mkldnn test_mkldnn_fusion test_openmp test_torch test_dynamic_shapes \
test_transformers test_multiprocessing test_numpy_interop test_autograd test_binary_ufuncs test_complex test_spectral_ops \
test_foreach test_reductions test_unary_ufuncs test_tensor_creation_ops test_ops \
distributed/elastic/timer/api_test distributed/elastic/timer/local_timer_example distributed/elastic/timer/local_timer_test \
--shard "$SHARD_NUMBER" "$NUM_TEST_SHARDS" --verbose
# Dynamo tests
@ -1657,6 +1653,8 @@ elif [[ "${TEST_CONFIG}" == *executorch* ]]; then
test_executorch
elif [[ "$TEST_CONFIG" == 'jit_legacy' ]]; then
test_python_legacy_jit
elif [[ "$TEST_CONFIG" == 'quantization' ]]; then
test_quantization
elif [[ "${BUILD_ENVIRONMENT}" == *libtorch* ]]; then
# TODO: run some C++ tests
echo "no-op at the moment"

View File

@ -25,7 +25,7 @@ echo Copying over test times file
robocopy /E "%PYTORCH_FINAL_PACKAGE_DIR_WIN%\.additional_ci_files" "%PROJECT_DIR_WIN%\.additional_ci_files"
echo Run nn tests
python run_test.py --exclude-jit-executor --exclude-distributed-tests --shard "%SHARD_NUMBER%" "%NUM_TEST_SHARDS%" --verbose
python run_test.py --exclude-jit-executor --exclude-distributed-tests --exclude-quantization-tests --shard "%SHARD_NUMBER%" "%NUM_TEST_SHARDS%" --verbose
if ERRORLEVEL 1 goto fail
popd

View File

@ -1 +1 @@
d119fc86140785e7efc8f125c17153544d1e0f20
973c9d01da863cac9c51e8a5c0d390fc84b84fbc

View File

@ -19,6 +19,7 @@ ciflow_push_tags:
- ciflow/nightly
- ciflow/periodic
- ciflow/periodic-rocm-mi300
- ciflow/quantization-periodic
- ciflow/rocm
- ciflow/rocm-mi300
- ciflow/s390

View File

@ -75,6 +75,10 @@ jobs:
runner: ${{ inputs.runner_prefix }}linux.2xlarge
# It takes less than 30m to finish python docs unless there are issues
timeout-minutes: 30
- docs_type: functorch
runner: ${{ inputs.runner_prefix }}linux.2xlarge
# It takes less than 15m to finish functorch docs unless there are issues
timeout-minutes: 15
# Set a fixed name for this job instead of using the current matrix-generated name, i.e. build-docs (cpp, linux.12xlarge, 180)
# The current name requires updating the database last docs push query from test-infra every time the matrix is updated
name: build-docs-${{ matrix.docs_type }}-${{ inputs.push }}
@ -207,6 +211,16 @@ jobs:
path: cppdocs/
s3-prefix: pytorch/pytorch/${{ github.event.pull_request.number }}/cppdocs
- name: Upload functorch Docs Preview
uses: seemethere/upload-artifact-s3@baba72d0712b404f646cebe0730933554ebce96a # v5.1.0
if: ${{ github.event_name == 'pull_request' && matrix.docs_type == 'functorch' && steps.build-docs.outcome == 'success' }}
with:
retention-days: 14
s3-bucket: doc-previews
if-no-files-found: error
path: functorch_ghpages/nightly/
s3-prefix: pytorch/pytorch/${{ github.event.pull_request.number }}/functorchdocs
- name: Teardown Linux
uses: pytorch/test-infra/.github/actions/teardown-linux@main
if: always()

View File

@ -169,7 +169,7 @@ jobs:
id: install-nvidia-driver
uses: pytorch/test-infra/.github/actions/setup-nvidia@main
with:
driver-version: ${{ matrix.config == 'legacy_nvidia_driver' && '525.105.17' || '580.82.07' }}
driver-version: ${{ matrix.config == 'legacy_nvidia_driver' && '525.105.17' || '570.133.07' }}
if: ${{ contains(inputs.build-environment, 'cuda') && !contains(matrix.config, 'nogpu') && steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'false' && !contains(matrix.runner, 'b200') }}
- name: Setup GPU_FLAG for docker run

View File

@ -62,11 +62,6 @@ on:
required: false
type: number
default: 1
secrets:
HUGGING_FACE_HUB_TOKEN:
required: false
description: |
HF Auth token to avoid rate limits when downloading models or datasets from hub
env:
GIT_DEFAULT_BRANCH: ${{ github.event.repository.default_branch }}
@ -81,9 +76,10 @@ jobs:
strategy:
matrix: ${{ fromJSON(inputs.test-matrix) }}
fail-fast: false
runs-on: ${{ matrix.runner }}
timeout-minutes: ${{ matrix.mem_leak_check == 'mem_leak_check' && 600 || inputs.timeout-minutes }}
runs-on: ${{ matrix.runner }}
steps:
# [see note: pytorch repo ref]
- name: Checkout PyTorch
uses: pytorch/pytorch/.github/actions/checkout-pytorch@main
with:
@ -135,9 +131,6 @@ jobs:
- name: Start monitoring script
id: monitor-script
if: ${{ !inputs.disable-monitor }}
shell: bash
continue-on-error: true
env:
JOB_ID: ${{ steps.get-job-id.outputs.job-id }}
JOB_NAME: ${{ steps.get-job-id.outputs.job-name }}
@ -145,6 +138,9 @@ jobs:
WORKFLOW_RUN_ID: ${{github.run_id}}
MONITOR_LOG_INTERVAL: ${{ inputs.monitor-log-interval }}
MONITOR_DATA_COLLECT_INTERVAL: ${{ inputs.monitor-data-collect-interval }}
if: ${{ !inputs.disable-monitor }}
shell: bash
continue-on-error: true
run: |
python3 -m pip install psutil==5.9.8 dataclasses_json==0.6.7
python3 -m tools.stats.monitor --log-interval "$MONITOR_LOG_INTERVAL" --data-collect-interval "$MONITOR_DATA_COLLECT_INTERVAL" > usage_log.txt 2>&1 &
@ -182,12 +178,6 @@ jobs:
run: |
echo "timeout=$((JOB_TIMEOUT-30))" >> "${GITHUB_OUTPUT}"
- name: Preserve github env variables for use in docker
shell: bash
run: |
env | grep '^GITHUB' >> "/tmp/github_env_${GITHUB_RUN_ID}"
env | grep '^CI' >> "/tmp/github_env_${GITHUB_RUN_ID}"
- name: Test
id: test
env:
@ -203,22 +193,20 @@ jobs:
JOB_NAME: ${{ steps.get-job-id.outputs.job-name }}
BRANCH: ${{ steps.parse-ref.outputs.branch }}
SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
BASE_SHA: ${{ github.event.pull_request.base.sha || github.sha }}
TEST_CONFIG: ${{ matrix.config }}
SHARD_NUMBER: ${{ matrix.shard }}
NUM_TEST_SHARDS: ${{ matrix.num_shards }}
REENABLED_ISSUES: ${{ steps.keep-going.outputs.reenabled-issues }}
CONTINUE_THROUGH_ERROR: ${{ steps.keep-going.outputs.keep-going }}
VERBOSE_TEST_LOGS: ${{ steps.keep-going.outputs.ci-verbose-test-logs }}
TEST_SHOWLOCALS: ${{ steps.keep-going.outputs.ci-test-showlocals }}
NO_TEST_TIMEOUT: ${{ steps.keep-going.outputs.ci-no-test-timeout }}
NO_TD: ${{ steps.keep-going.outputs.ci-no-td }}
TEST_CONFIG: ${{ matrix.config }}
SHARD_NUMBER: ${{ matrix.shard }}
NUM_TEST_SHARDS: ${{ matrix.num_shards }}
REENABLED_ISSUES: ${{ steps.keep-going.outputs.reenabled-issues }}
DOCKER_IMAGE: ${{ inputs.docker-image }}
PYTORCH_TEST_CUDA_MEM_LEAK_CHECK: ${{ matrix.mem_leak_check && '1' || '0' }}
PYTORCH_TEST_RERUN_DISABLED_TESTS: ${{ matrix.rerun_disabled_tests && '1' || '0' }}
TESTS_TO_INCLUDE: ${{ inputs.tests-to-include }}
DASHBOARD_TAG: ${{ inputs.dashboard-tag }}
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
timeout-minutes: ${{ fromJson(steps.test-timeout.outputs.timeout) }}
run: |
set -x
@ -248,7 +236,6 @@ jobs:
-e GITHUB_RUN_ATTEMPT \
-e JOB_ID \
-e JOB_NAME \
-e BASE_SHA \
-e BRANCH \
-e SHA1 \
-e AWS_DEFAULT_REGION \
@ -266,12 +253,10 @@ jobs:
-e PYTORCH_TEST_CUDA_MEM_LEAK_CHECK \
-e PYTORCH_TEST_RERUN_DISABLED_TESTS \
-e TESTS_TO_INCLUDE \
-e HUGGING_FACE_HUB_TOKEN \
-e DASHBOARD_TAG \
--env-file="${RUNNER_TEMP}/github_env_${GITHUB_RUN_ID}" \
--ulimit stack=10485760:83886080 \
--ulimit core=0 \
--env-file="/tmp/github_env_${GITHUB_RUN_ID}" \
--security-opt seccomp=unconfined \
--cap-add=SYS_PTRACE \
--shm-size="8g" \

View File

@ -178,12 +178,12 @@ jobs:
contents: read
container:
image: continuumio/miniconda3:4.12.0
environment: ${{ ((github.event_name == 'push' && github.event.ref == 'refs/heads/main') || github.event_name == 'schedule' || github.event_name == 'workflow_dispatch') && 'nightly-wheel-upload' || '' }}
environment: ${{ (github.event_name == 'push' && github.event.ref == 'refs/heads/main') && 'nightly-wheel-upload' || '' }}
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Configure AWS credentials(PyTorch account) for main
if: ${{ (github.event_name == 'push' && github.event.ref == 'refs/heads/main') || github.event_name == 'schedule' || github.event_name == 'workflow_dispatch' }}
if: ${{ github.event_name == 'push' && github.event.ref == 'refs/heads/main' }}
uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0
with:
role-to-assume: arn:aws:iam::749337293305:role/gha_workflow_nightly_build_wheels

View File

@ -43,11 +43,6 @@ on:
required: false
type: boolean
default: false
freezing:
description: Run freezing?
required: false
type: boolean
default: true
benchmark_configs:
description: The list of configs used the benchmark
required: false
@ -107,7 +102,7 @@ jobs:
if: github.event.schedule == '0 7 * * *'
with:
build-environment: linux-jammy-py3.10-gcc11-build
dashboard-tag: training-false-inference-true-default-true-dynamic-true-cppwrapper-true-aotinductor-true-freezing-true
dashboard-tag: training-false-inference-true-default-true-dynamic-true-cppwrapper-true-aotinductor-true
docker-image: ${{ needs.inductor-build.outputs.docker-image }}
test-matrix: ${{ needs.inductor-build.outputs.test-matrix }}
timeout-minutes: 720
@ -121,9 +116,10 @@ jobs:
name: inductor-test
uses: ./.github/workflows/_linux-test.yml
needs: inductor-build
if: github.event_name == 'workflow_dispatch'
with:
build-environment: linux-jammy-py3.10-gcc11-build
dashboard-tag: training-${{ inputs.training || 'false' }}-inference-${{ inputs.inference || 'true' }}-default-${{ inputs.default || 'true' }}-dynamic-${{ inputs.dynamic || 'true' }}-cppwrapper-${{ inputs.cppwrapper || 'true' }}-aotinductor-${{ inputs.aotinductor || 'true' }}-freezing-${{ inputs.freezing || 'true' }}
dashboard-tag: training-${{ inputs.training }}-inference-${{ inputs.inference }}-default-${{ inputs.default }}-dynamic-${{ inputs.dynamic }}-cppwrapper-${{ inputs.cppwrapper }}-aotinductor-${{ inputs.aotinductor }}
docker-image: ${{ needs.inductor-build.outputs.docker-image }}
test-matrix: ${{ needs.inductor-build.outputs.test-matrix }}
timeout-minutes: 720

View File

@ -105,7 +105,7 @@ jobs:
# NB: A shallow checkout won't work here because calculate-docker-image requires a full checkout
# to run git rev-parse HEAD~:.ci/docker when a new image is needed
fetch-depth: 0
submodules: true
submodules: false
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
script: |
CHANGED_FILES="${{ needs.get-changed-files.outputs.changed-files }}"

View File

@ -127,8 +127,6 @@ jobs:
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
# More memory is needed to build with asan
runner: linux.2xlarge.memory
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-py3.10-clang18-asan
docker-image-name: ci-image:pytorch-linux-jammy-py3-clang18-asan

View File

@ -0,0 +1,54 @@
name: quantization-periodic
on:
push:
tags:
- ciflow/quantization-periodic/*
workflow_dispatch:
schedule:
# run weekly
- cron: "45 0 * * 0"
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
cancel-in-progress: true
permissions:
id-token: write
contents: read
jobs:
get-default-label-prefix:
name: get-default-label-prefix
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
opt_out_experiments: lf
periodic-quantization-build:
name: periodic-quantization-build
uses: ./.github/workflows/_linux-build.yml
needs: get-default-label-prefix
with:
runner_prefix: "${{ needs.get-default-label-prefix.outputs.label-type }}"
build-environment: linux-jammy-cuda12.8-cudnn9-py3-gcc11
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
cuda-arch-list: '8.9'
test-matrix: |
{ include: [
{ config: "quantization", shard: 1, num_shards: 1, runner: "${{ needs.get-default-label-prefix.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
]}
secrets: inherit
periodic-test-quantization:
name: periodic-test-quantization
uses: ./.github/workflows/_linux-test.yml
needs: periodic-quantization-build
with:
build-environment: linux-jammy-cuda12.8-cudnn9-py3-gcc11
docker-image: ${{ needs.periodic-quantization-build.outputs.docker-image }}
test-matrix: ${{ needs.periodic-quantization-build.outputs.test-matrix }}
secrets: inherit

View File

@ -140,8 +140,6 @@ jobs:
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
# More memory is needed to build with asan
runner: linux.2xlarge.memory
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-py3.10-clang18-asan
docker-image-name: ci-image:pytorch-linux-jammy-py3-clang18-asan

View File

@ -891,7 +891,7 @@ IF(USE_FBGEMM_GENAI AND USE_ROCM AND NOT "gfx942" IN_LIST PYTORCH_ROCM_ARCH)
endif()
# Set USE_FBGEMM_GENAI to ON for CUDA build on SM100.
if(USE_CUDA AND "$ENV{TORCH_CUDA_ARCH_LIST}" MATCHES "10.0" AND CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8 AND NOT WIN32)
if(USE_CUDA AND "$ENV{TORCH_CUDA_ARCH_LIST}" MATCHES "10.0" AND CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8)
message(STATUS "Setting USE_FBGEMM_GENAI to ON, doing CUDA build for SM100a")
set(USE_FBGEMM_GENAI ON)
endif()

View File

@ -65,24 +65,14 @@ DLDataType getDLDataType(const Tensor& t) {
break;
// TODO(#146647): use macro here instead of spelling out each shell dtype
case ScalarType::Float8_e5m2:
dtype.code = DLDataTypeCode::kDLFloat8_e5m2;
break;
case ScalarType::Float8_e5m2fnuz:
dtype.code = DLDataTypeCode::kDLFloat8_e5m2fnuz;
break;
case ScalarType::Float8_e4m3fn:
dtype.code = DLDataTypeCode::kDLFloat8_e4m3fn;
break;
case ScalarType::Float8_e4m3fnuz:
dtype.code = DLDataTypeCode::kDLFloat8_e4m3fnuz;
break;
case ScalarType::Float8_e8m0fnu:
dtype.code = DLDataTypeCode::kDLFloat8_e8m0fnu;
TORCH_CHECK_BUFFER(false, "float8 types are not supported by dlpack");
break;
case ScalarType::Float4_e2m1fn_x2:
dtype.code = DLDataTypeCode::kDLFloat4_e2m1fn;
dtype.lanes = 2;
dtype.bits = 4;
TORCH_CHECK_BUFFER(false, "float4 types are not supported by dlpack");
break;
case ScalarType::QInt8:
case ScalarType::QUInt8:
@ -187,11 +177,7 @@ static Device getATenDevice(DLDeviceType type, c10::DeviceIndex index, void* dat
ScalarType toScalarType(const DLDataType& dtype) {
ScalarType stype = ScalarType::Undefined;
if (dtype.code != DLDataTypeCode::kDLFloat4_e2m1fn) {
TORCH_CHECK_BUFFER(
dtype.lanes == 1,
"ATen does not support lanes != 1 for dtype code", std::to_string(dtype.code));
}
TORCH_CHECK_BUFFER(dtype.lanes == 1, "ATen does not support lanes != 1");
switch (dtype.code) {
case DLDataTypeCode::kDLUInt:
switch (dtype.bits) {
@ -283,73 +269,6 @@ ScalarType toScalarType(const DLDataType& dtype) {
false, "Unsupported kDLBool bits ", std::to_string(dtype.bits));
}
break;
case DLDataTypeCode::kDLFloat8_e5m2:
switch (dtype.bits) {
case 8:
stype = ScalarType::Float8_e5m2;
break;
default:
TORCH_CHECK_BUFFER(
false, "Unsupported kDLFloat8_e5m2 bits ", std::to_string(dtype.bits));
}
break;
case DLDataTypeCode::kDLFloat8_e5m2fnuz:
switch (dtype.bits) {
case 8:
stype = ScalarType::Float8_e5m2fnuz;
break;
default:
TORCH_CHECK_BUFFER(
false, "Unsupported kDLFloat8_e5m2fnuz bits ", std::to_string(dtype.bits));
}
break;
case DLDataTypeCode::kDLFloat8_e4m3fn:
switch (dtype.bits) {
case 8:
stype = ScalarType::Float8_e4m3fn;
break;
default:
TORCH_CHECK_BUFFER(
false, "Unsupported kDLFloat8_e4m3fn bits ", std::to_string(dtype.bits));
}
break;
case DLDataTypeCode::kDLFloat8_e4m3fnuz:
switch (dtype.bits) {
case 8:
stype = ScalarType::Float8_e4m3fnuz;
break;
default:
TORCH_CHECK_BUFFER(
false, "Unsupported kDLFloat8_e4m3fnuz bits ", std::to_string(dtype.bits));
}
break;
case DLDataTypeCode::kDLFloat8_e8m0fnu:
switch (dtype.bits) {
case 8:
stype = ScalarType::Float8_e8m0fnu;
break;
default:
TORCH_CHECK_BUFFER(
false, "Unsupported kDLFloat8_e8m0fnu bits ", std::to_string(dtype.bits));
}
break;
case DLDataTypeCode::kDLFloat4_e2m1fn:
switch (dtype.bits) {
case 4:
switch (dtype.lanes) {
case 2:
stype = ScalarType::Float4_e2m1fn_x2;
break;
default:
TORCH_CHECK_BUFFER(
false, "Unsupported kDLFloat4_e2m1fn lanes ", std::to_string(dtype.lanes));
}
break;
default:
TORCH_CHECK_BUFFER(
false, "Unsupported kDLFloat4_e2m1fn bits ", std::to_string(dtype.bits));
}
break;
default:
TORCH_CHECK_BUFFER(false, "Unsupported code ", std::to_string(dtype.code));
}
@ -435,8 +354,8 @@ T* toDLPackImpl(const Tensor& src) {
atDLMTensor->tensor.dl_tensor.device = torchDeviceToDLDevice(src.device());
atDLMTensor->tensor.dl_tensor.ndim = static_cast<int32_t>(src.dim());
atDLMTensor->tensor.dl_tensor.dtype = getDLDataType(src);
atDLMTensor->tensor.dl_tensor.shape = const_cast<int64_t*>(view.sizes().data());
atDLMTensor->tensor.dl_tensor.strides = const_cast<int64_t*>(view.strides().data());
atDLMTensor->tensor.dl_tensor.shape = view.sizes().data();
atDLMTensor->tensor.dl_tensor.strides = view.strides().data();
atDLMTensor->tensor.dl_tensor.byte_offset = 0;
fillVersion(&atDLMTensor->tensor);
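
The hunks above trim the DLPack dtype handling so that float8/float4 tensors are rejected at export. A minimal Python sketch (not part of this diff, assuming the torch.utils.dlpack helpers) of the conversion path these functions back:

    import torch
    from torch.utils.dlpack import to_dlpack, from_dlpack

    # Round-trip a plain float32 tensor through a DLPack capsule (zero-copy).
    t = torch.arange(6, dtype=torch.float32).reshape(2, 3)
    t2 = from_dlpack(to_dlpack(t))
    assert t2.data_ptr() == t.data_ptr()

    # With the state shown above, float8 tensors hit the
    # "float8 types are not supported by dlpack" check:
    # to_dlpack(torch.zeros(4, dtype=torch.float8_e5m2))  # would raise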

View File

@ -102,7 +102,7 @@ FunctionalStorageImpl::FunctionalStorageImpl(const Tensor& base)
// SparseTensorImpl has no storage, so we cannot query its nbytes.
// (original_storage_size is only used for storage resizing in fsdp anyway, which does not apply to sparse)
// Same for XLA
if (base.unsafeGetTensorImpl()->has_storage() && data_ptr().device().type() != c10::DeviceType::XLA) {
if (base.unsafeGetTensorImpl()->has_storage() && base.device().type() != c10::DeviceType::XLA) {
original_storage_size_ = base.unsafeGetTensorImpl()->unsafe_storage().unsafeGetStorageImpl()->sym_nbytes();
} else {
original_storage_size_ = -1;

View File

@ -266,14 +266,11 @@ CUDAGeneratorImpl::CUDAGeneratorImpl(
* See Note [Acquire lock when using random generators]
*/
void CUDAGeneratorImpl::set_current_seed(uint64_t seed) {
if (C10_LIKELY(at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None)) {
state_->seed_ = seed;
state_->philox_offset_per_thread_ = 0;
no_reset_rnn_state_.clear();
} else {
TORCH_CHECK(state_->seed_ == seed, "CUDAGeneratorImpl::set_current_seed can be called during stream capture only if new seed is the same as the original seed.");
// no-op case
}
at::cuda::assertNotCapturing(
"Cannot call CUDAGeneratorImpl::set_current_seed");
state_->seed_ = seed;
state_->philox_offset_per_thread_ = 0;
no_reset_rnn_state_.clear();
}
/**
@ -302,6 +299,9 @@ uint64_t CUDAGeneratorImpl::get_offset() const {
* Gets the current seed of CUDAGeneratorImpl.
*/
uint64_t CUDAGeneratorImpl::current_seed() const {
// Debatable if current_seed() should be allowed in captured regions.
// Conservatively disallow it for now.
at::cuda::assertNotCapturing("Cannot call CUDAGeneratorImpl::current_seed");
return state_->seed_;
}
@ -346,6 +346,8 @@ c10::intrusive_ptr<c10::TensorImpl> CUDAGeneratorImpl::get_state() const {
* and size of the internal state.
*/
void CUDAGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
at::cuda::assertNotCapturing(
"Please ensure to utilize the CUDAGeneratorImpl::set_state_index method during capturing.");
static const size_t seed_size = sizeof(uint64_t);
static const size_t offset_size = sizeof(int64_t);
static const size_t total_size = seed_size + offset_size;
@ -400,27 +402,15 @@ c10::intrusive_ptr<c10::GeneratorImpl> CUDAGeneratorImpl::graphsafe_get_state()
*/
void CUDAGeneratorImpl::set_philox_offset_per_thread(uint64_t offset) {
// see Note [Why enforce RNG offset % 4 == 0?]
// Note: If you use CUDNN RNN's, calling
// set_philox_offset_per_thread instead of set_offset will cause the
// cudnn RNN rng state to become stale.
TORCH_CHECK(offset % 4 == 0, "offset must be a multiple of 4");
if (C10_LIKELY(at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None)) {
state_->philox_offset_per_thread_ = offset;
} else {
state_->offset_intragraph_ = offset;
}
state_->philox_offset_per_thread_ = offset;
}
/**
* Gets the current philox_offset_per_thread_ of CUDAGeneratorImpl.
*/
uint64_t CUDAGeneratorImpl::philox_offset_per_thread() const {
if (C10_LIKELY(at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None)) {
return state_->philox_offset_per_thread_;
} else {
return state_->offset_intragraph_;
}
return state_->philox_offset_per_thread_;
}
/**

View File

@ -19,7 +19,7 @@
#define DLPACK_MAJOR_VERSION 1
/*! \brief The current minor version of dlpack */
#define DLPACK_MINOR_VERSION 1
#define DLPACK_MINOR_VERSION 0
/*! \brief DLPACK_DLL prefix for windows */
#ifdef _WIN32
@ -32,7 +32,9 @@
#define DLPACK_DLL
#endif
// NOLINTNEXTLINE(modernize-deprecated-headers)
#include <stdint.h>
// NOLINTNEXTLINE(modernize-deprecated-headers)
#include <stddef.h>
#ifdef __cplusplus
@ -157,26 +159,6 @@ typedef enum {
kDLComplex = 5U,
/*! \brief boolean */
kDLBool = 6U,
/*! \brief FP8 data types */
kDLFloat8_e3m4 = 7U,
kDLFloat8_e4m3 = 8U,
kDLFloat8_e4m3b11fnuz = 9U,
kDLFloat8_e4m3fn = 10U,
kDLFloat8_e4m3fnuz = 11U,
kDLFloat8_e5m2 = 12U,
kDLFloat8_e5m2fnuz = 13U,
kDLFloat8_e8m0fnu = 14U,
/*! \brief FP6 data types
* Setting bits != 6 is currently unspecified, and the producer must ensure it is set
* while the consumer must stop importing if the value is unexpected.
*/
kDLFloat6_e2m3fn = 15U,
kDLFloat6_e3m2fn = 16U,
/*! \brief FP4 data types
* Setting bits != 4 is currently unspecified, and the producer must ensure it is set
* while the consumer must stop importing if the value is unexpected.
*/
kDLFloat4_e2m1fn = 17U,
} DLDataTypeCode;
/*!
@ -190,12 +172,6 @@ typedef enum {
* - int8: type_code = 0, bits = 8, lanes = 1
* - std::complex<float>: type_code = 5, bits = 64, lanes = 1
* - bool: type_code = 6, bits = 8, lanes = 1 (as per common array library convention, the underlying storage size of bool is 8 bits)
* - float8_e4m3: type_code = 8, bits = 8, lanes = 1 (packed in memory)
* - float6_e3m2fn: type_code = 16, bits = 6, lanes = 1 (packed in memory)
* - float4_e2m1fn: type_code = 17, bits = 4, lanes = 1 (packed in memory)
*
* When a sub-byte type is packed, DLPack requires the data to be in little bit-endian, i.e.,
* for a packed data set D ((D >> (i * bits)) && bit_mask) stores the i-th element.
*/
typedef struct {
/*!
@ -253,12 +229,12 @@ typedef struct {
/*! \brief The data type of the pointer*/
DLDataType dtype;
/*! \brief The shape of the tensor */
int64_t* shape;
const int64_t* shape;
/*!
* \brief strides of the tensor (in number of elements, not bytes)
* can be NULL, indicating tensor is compact and row-majored.
*/
int64_t* strides;
const int64_t* strides;
/*! \brief The offset in bytes to the beginning pointer to data */
uint64_t byte_offset;
} DLTensor;
@ -293,7 +269,7 @@ typedef struct DLManagedTensor {
void (*deleter)(struct DLManagedTensor * self);
} DLManagedTensor;
// bit masks used in the DLManagedTensorVersioned
// bit masks used in in the DLManagedTensorVersioned
/*! \brief bit mask to indicate that the tensor is read only. */
#define DLPACK_FLAG_BITMASK_READ_ONLY (1UL << 0UL)
@ -306,14 +282,6 @@ typedef struct DLManagedTensor {
*/
#define DLPACK_FLAG_BITMASK_IS_COPIED (1UL << 1UL)
/*
* \brief bit mask to indicate that whether a sub-byte type is packed or padded.
*
* The default for sub-byte types (ex: fp4/fp6) is assumed packed. This flag can
* be set by the producer to signal that a tensor of sub-byte type is padded.
*/
#define DLPACK_FLAG_BITMASK_IS_SUBBYTE_TYPE_PADDED (1UL << 2UL)
/*!
* \brief A versioned and managed C Tensor object, manage memory of DLTensor.
*
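
For reference, the DLDataType example encodings listed earlier in this header map onto a small packed struct; a ctypes sketch (illustrative only, field layout per dlpack.h):

    import ctypes

    class DLDataType(ctypes.Structure):
        # dlpack.h layout: code (uint8), bits (uint8), lanes (uint16)
        _fields_ = [("code", ctypes.c_uint8),
                    ("bits", ctypes.c_uint8),
                    ("lanes", ctypes.c_uint16)]

    # Example encodings from the header comment above:
    int8_dt      = DLDataType(code=0, bits=8, lanes=1)   # kDLInt
    complex64_dt = DLDataType(code=5, bits=64, lanes=1)  # kDLComplex
    bool_dt      = DLDataType(code=6, bits=8, lanes=1)   # kDLBool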

View File

@ -171,8 +171,6 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
POINTWISE_BOXED(fill_.Scalar);
POINTWISE_BOXED(zero_);
// This is special because this op doesn't return anything
m.impl("_assert_tensor_metadata", native::_assert_tensor_metadata);
#undef UNARY_POINTWISE
#undef UNARY_POINTWISE_ALL

View File

@ -81,7 +81,7 @@ Tensor math_channel_shuffle(const Tensor& self, int64_t groups) {
// TODO: contiguous can be made to preserve the memory format
// of the input. However since the above reshape clobbers h and w
// it may not be safe to do that, since channels_last contiguous
// may think oc and the last dim correspond to h,w?
// may think oc and and the last dim correspond to h,w?
// It is not clear, however from initial looking around it feels that
// this may not be correct.
// In this case channels last will likely require custom implementation
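
The comment above concerns the reference (reshape-based) channel shuffle; a short Python sketch of that reference computation, assuming an NCHW input (illustrative, not the kernel code path):

    import torch

    def channel_shuffle_reference(x: torch.Tensor, groups: int) -> torch.Tensor:
        # Split channels into `groups`, swap the group/channel axes, flatten back.
        n, c, h, w = x.shape
        return (x.reshape(n, groups, c // groups, h, w)
                 .transpose(1, 2)
                 .reshape(n, c, h, w))

    x = torch.arange(2 * 4 * 2 * 2, dtype=torch.float32).reshape(2, 4, 2, 2)
    assert torch.equal(channel_shuffle_reference(x, 2), torch.nn.ChannelShuffle(2)(x))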

View File

@ -67,13 +67,13 @@ TORCH_PRECOMPUTE_META_FUNC(fractional_max_pool3d)(
int64_t inputH = input_.size(heightDim);
int64_t inputW = input_.size(widthDim);
TORCH_CHECK((poolSizeT <= inputT) && (outputT + poolSizeT - 1 < inputT),
TORCH_CHECK(outputT + poolSizeT - 1 < inputT,
"fractional_max_pool3d_out(): pool time ", poolSizeT,
" too large relative to input time ", inputT);
TORCH_CHECK((poolSizeW <= inputW) && (outputW + poolSizeW - 1 < inputW),
TORCH_CHECK(outputW + poolSizeW - 1 < inputW,
"fractional_max_pool3d_out(): pool width ", poolSizeW,
" too large relative to input width ", inputW);
TORCH_CHECK((poolSizeH <= inputH) && (outputH + poolSizeH - 1 < inputH),
TORCH_CHECK(outputH + poolSizeH - 1 < inputH,
"fractional_max_pool3d_out(): pool height ", poolSizeH,
" too large relative to input height ", inputH);

View File

@ -52,7 +52,6 @@ void apply_triu_tril_single(
int64_t self_col_stride,
bool upper) {
constexpr int64_t zero = 0;
k = std::clamp(k, -n, m); // Clamp k to [-n, m] to prevent i + k arithmetic overflow, especially if k approaches INT64_MAX/INT64_MIN.
if (upper) {
parallel_for(0, n, 0, [&](int64_t start, int64_t end) {

View File

@ -85,11 +85,11 @@ void cpu_max_unpool(
if constexpr (is_3d) {
TORCH_CHECK(false, "Found an invalid max index: ", optional_error_index.value(),
" (output volumes are of size ", output_depth,
"x", output_height, "x", output_width, ")");
"x", output_height, "x", output_width);
} else {
TORCH_CHECK(false, "Found an invalid max index: ", optional_error_index.value(),
" (output volumes are of size ", output_height,
"x", output_width, ")");
"x", output_width);
}
}
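
As a usage-level sketch of the error above (not part of this change): max_unpool expects indices produced by the paired max_pool call, and an out-of-range index is expected to trip the "invalid max index" check.

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 1, 4, 4)
    pooled, indices = F.max_pool2d(x, kernel_size=2, return_indices=True)

    # Normal round trip back to the original spatial size.
    y = F.max_unpool2d(pooled, indices, kernel_size=2, output_size=(4, 4))

    # An index outside the 4x4 output volume should hit the check above.
    bad = indices.clone()
    bad[0, 0, 0, 0] = 999
    # F.max_unpool2d(pooled, bad, kernel_size=2, output_size=(4, 4))  # raises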

View File

@ -416,7 +416,6 @@ struct ReduceOp {
if (config.should_block_y_reduce()) {
value = block_y_reduce<output_vec_size>(value, shared_memory);
}
__syncthreads();
if (config.should_block_x_reduce()) {
value = block_x_reduce<output_vec_size>(value, shared_memory);
}

View File

@ -17,11 +17,12 @@ __global__ static void compute_cuda_kernel(
index_t* result_ptr,
int64_t size,
int64_t result_size) {
CUDA_KERNEL_ASSERT_PRINTF(
result_size == cumsum_ptr[size - 1],
if (C10_UNLIKELY((result_size != cumsum_ptr[size - 1]))) {
printf("%s:%d:%s: block: [%d,%d,%d], thread: [%d,%d,%d] "
"Invalid input! In `repeat_interleave`, the `output_size` argument (%ld) must be the same as the sum of the elements in the `repeats` tensor (%ld).\n",
result_size,
cumsum_ptr[size - 1]);
__FILE__, __LINE__, __func__,blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, threadIdx.z, result_size, cumsum_ptr[size - 1 ]);
CUDA_KERNEL_ASSERT(result_size == cumsum_ptr[size - 1])
}
int64_t idx = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x;
int64_t stride = (blockDim.x * gridDim.x) / C10_WARP_SIZE;
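
The assertion above encodes the invariant that output_size must equal the sum of repeats (the last element of their cumulative sum); a quick Python sketch of that invariant:

    import torch

    repeats = torch.tensor([2, 1, 3])
    output_size = int(repeats.cumsum(0)[-1])   # == repeats.sum() == 6

    x = torch.tensor([10, 20, 30])
    out = torch.repeat_interleave(x, repeats, output_size=output_size)
    assert out.tolist() == [10, 10, 20, 30, 30, 30]
    # Any other output_size violates the check in the kernel above.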

View File

@ -5,20 +5,12 @@
namespace at::native {
__global__ void weight_int8pack_mm_kernel(
const float* x,
const int8_t* w,
const float* scale,
float* out,
int B,
int K,
int N) {
__global__ void weight_int8pack_mm_kernel(const float* x, const int8_t* w, const float* scale, float* out, int B, int K, int N) {
// one thread per output element: [B, N]
int b = blockIdx.y * blockDim.y + threadIdx.y;
int n = blockIdx.x * blockDim.x + threadIdx.x;
if (b >= B || n >= N)
return;
if (b >= B || n >= N) return;
float acc = 0.0f;
for (int k = 0; k < K; ++k) {
@ -28,11 +20,7 @@ __global__ void weight_int8pack_mm_kernel(
out[b * N + n] = acc * scale[n];
}
void launch_weight_int8pack_mm_cuda_kernel(
const Tensor& x,
const Tensor& w_int8,
const Tensor& scale,
Tensor& out) {
void launch_weight_int8pack_mm_cuda_kernel(const Tensor& x, const Tensor& w_int8, const Tensor& scale, Tensor& out) {
const int B = x.size(0);
const int K = x.size(1);
const int N = w_int8.size(0);
@ -47,16 +35,12 @@ void launch_weight_int8pack_mm_cuda_kernel(
w_int8.data_ptr<int8_t>(),
scale.data_ptr<float>(),
out.data_ptr<float>(),
B,
K,
N);
B, K, N);
}
// Main GPU entry point
at::Tensor _weight_int8pack_mm_cuda(
const at::Tensor& x,
const at::Tensor& w_int8,
const at::Tensor& scale) {
at::Tensor _weight_int8pack_mm_cuda(const at::Tensor& x, const at::Tensor& w_int8, const at::Tensor& scale) {
// --- Check inputs ---
TORCH_CHECK(x.is_cuda(), "x must be a CUDA tensor");
TORCH_CHECK(w_int8.is_cuda(), "w must be a CUDA tensor");
@ -66,16 +50,12 @@ at::Tensor _weight_int8pack_mm_cuda(
TORCH_CHECK(w_int8.dim() == 2, "w must be 2D");
TORCH_CHECK(scale.dim() == 1, "scale must be 1D");
TORCH_CHECK(
x.size(1) == w_int8.size(1),
"K dimension mismatch: x.size(1) != w.size(1)");
TORCH_CHECK(
w_int8.size(0) == scale.size(0),
"Output dim mismatch: w.size(0) != scale.size(0)");
TORCH_CHECK(x.size(1) == w_int8.size(1), "K dimension mismatch: x.size(1) != w.size(1)");
TORCH_CHECK(w_int8.size(0) == scale.size(0), "Output dim mismatch: w.size(0) != scale.size(0)");
// --- Determine shapes ---
auto B = x.size(0); // batch size
auto N = w_int8.size(0); // output dim
auto B = x.size(0); // batch size
auto N = w_int8.size(0); // output dim
// Ensure inputs are in the correct types for the kernel
auto x_f32 = x.to(at::kFloat);
@ -83,13 +63,12 @@ at::Tensor _weight_int8pack_mm_cuda(
auto scale_f32 = scale.to(at::kFloat);
// --- Allocate output ---
auto out = at::empty({B, N}, x_f32.options());
auto out = at::empty({B, N}, x.options().dtype(at::kFloat));
// --- Launch kernel ---
launch_weight_int8pack_mm_cuda_kernel(
x_f32, w_int8_contiguous, scale_f32, out);
launch_weight_int8pack_mm_cuda_kernel(x_f32, w_int8_contiguous, scale_f32, out);
return out.to(x.dtype());
return out;
}
} // namespace at::native
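
A reference computation for the kernel above (a Python sketch with illustrative shapes): each output element is the dot product of a row of x with an int8 weight row, scaled per output channel.

    import torch

    B, K, N = 4, 16, 8
    x = torch.randn(B, K)
    w_int8 = torch.randint(-128, 128, (N, K), dtype=torch.int8)
    scale = torch.rand(N)

    # Matches the kernel: out[b, n] = (sum_k x[b, k] * w[n, k]) * scale[n]
    out_ref = (x @ w_int8.to(torch.float32).t()) * scale
    assert out_ref.shape == (B, N)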

View File

@ -482,9 +482,7 @@ auto build_graph(
auto scaled_dot_product_flash_attention_options =
fe::graph::SDPA_attributes()
.set_name("CUDNN_SDPA")
.set_is_inference(return_softmaxstats == false)
// TODO(eqy): switch to this API once cuDNN FE is upgraded
// .set_generate_stats(return_softmaxstats)
.set_generate_stats(return_softmaxstats)
.set_causal_mask(is_causal)
.set_attn_scale(attn_scale);
if (use_ragged_in_dense(q, k, v, o, attn_bias.has_value())) {
@ -704,9 +702,7 @@ auto build_graph_nestedtensor(
auto scaled_dot_product_flash_attention_options =
fe::graph::SDPA_attributes()
.set_name("CUDNN_SDPA_NESTEDTENSOR")
.set_is_inference(return_softmaxstats == false)
// TODO(eqy): switch to this API once cuDNN FE is upgraded
// .set_generate_stats(return_softmaxstats)
.set_generate_stats(return_softmaxstats)
.set_causal_mask(is_causal)
.set_attn_scale(attn_scale)
.set_seq_len_q(SEQ_LEN_Q_)

View File

@ -1770,12 +1770,10 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> miopen_depthwise_convolution_back
// fusions
// ---------------------------------------------------------------------
void raw_miopen_convolution_add_relu_out(
void raw_miopen_convolution_relu_out(
const Tensor& output,
const Tensor& input,
const Tensor& weight,
const Tensor& z,
float alpha,
const Tensor& bias,
IntArrayRef stride,
IntArrayRef padding,
@ -1783,20 +1781,68 @@ void raw_miopen_convolution_add_relu_out(
int64_t groups,
bool benchmark,
bool deterministic) {
raw_miopen_convolution_forward_out(
output,
auto dataType = getMiopenDataType(input);
miopenConvolutionMode_t c_mode = miopenConvolution;
ConvolutionArgs args{ input, output, weight };
args.handle = getMiopenHandle();
at::MemoryFormat memory_format = miopen_conv_suggest_memory_format(input, weight);
setConvolutionParams(
&args.params,
args.handle,
input,
weight,
padding,
stride,
dilation,
groups,
deterministic,
memory_format);
args.idesc.set(input, memory_format);
args.wdesc.set(weight, memory_format, 0);
args.odesc.set(output, memory_format);
args.cdesc.set(
dataType,
c_mode,
input.dim() - 2,
args.params.padding,
args.params.stride,
args.params.dilation,
args.params.groups,
benchmark,
deterministic);
at::Tensor alpha_mul_z_add_bias =
at::native::reshape_bias(input.dim(), bias).add(z, alpha);
output.add_(alpha_mul_z_add_bias);
output.relu_();
TensorDescriptor bdesc;
bdesc.set(bias.expand({1, bias.size(0)}), output.dim());
// Create the fusion plan
miopenFusionPlanDescriptor_t fusePlanDesc;
miopenFusionOpDescriptor_t convoOp;
miopenFusionOpDescriptor_t biasOp;
miopenFusionOpDescriptor_t activOp;
MIOPEN_CHECK(miopenCreateFusionPlan(&fusePlanDesc, miopenVerticalFusion, args.idesc.desc()));
MIOPEN_CHECK(miopenCreateOpConvForward(fusePlanDesc, &convoOp, args.cdesc.desc(), args.wdesc.desc()));
MIOPEN_CHECK(miopenCreateOpBiasForward(fusePlanDesc, &biasOp, bdesc.desc()));
MIOPEN_CHECK(miopenCreateOpActivationForward(fusePlanDesc, &activOp, miopenActivationRELU));
// compile fusion plan
MIOPEN_CHECK(miopenCompileFusionPlan(args.handle, fusePlanDesc));
// Set the Args
float alpha = static_cast<float>(1);
float beta = static_cast<float>(0);
float activ_alpha = static_cast<float>(0);
float activ_beta = static_cast<float>(0);
float activ_gamma = static_cast<float>(0);
miopenOperatorArgs_t fusionArgs;
MIOPEN_CHECK(miopenCreateOperatorArgs(&fusionArgs));
MIOPEN_CHECK(miopenSetOpArgsConvForward(fusionArgs, convoOp, &alpha, &beta, weight.const_data_ptr()));
MIOPEN_CHECK(miopenSetOpArgsBiasForward(fusionArgs, biasOp, &alpha, &beta, bias.const_data_ptr()));
MIOPEN_CHECK(miopenSetOpArgsActivForward(fusionArgs, activOp, &alpha, &beta, activ_alpha, activ_beta, activ_gamma));
miopenExecuteFusionPlan(args.handle, fusePlanDesc, args.idesc.desc(), input.const_data_ptr(), args.odesc.desc(), output.data_ptr(), fusionArgs);
// Cleanup
miopenDestroyFusionPlan(fusePlanDesc);
}
static at::Tensor self_or_new_memory_format(at::Tensor& self, at::MemoryFormat memory_format) {
@ -1809,107 +1855,171 @@ static at::Tensor self_or_new_memory_format(at::Tensor& self, at::MemoryFormat m
Tensor miopen_convolution_add_relu(
const Tensor& input_t,
const Tensor& weight_t,
const Tensor& z_t,
const Tensor& z,
const std::optional<Scalar>& alpha,
const std::optional<Tensor>& bias_t,
const std::optional<Tensor>& bias,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
int64_t groups) {
auto memory_format = miopen_conv_suggest_memory_format(input_t, weight_t);
const Tensor input = input_t.contiguous(memory_format);
const Tensor weight = weight_t.contiguous(memory_format);
Tensor z = z_t;
if (z.suggest_memory_format() != memory_format) {
z = z.to(memory_format);
}
z = z.contiguous(memory_format);
// FuseFrozenConvAddRelu performs some tensor shape checking
Tensor output_t = at::detail::empty_cuda(
conv_output_size(
input.sizes(), weight.sizes(), padding, stride, dilation),
input.options().memory_format(memory_format));
if (output_t.numel() == 0) {
return output_t;
}
// MIOpen does not support fusion of add, the alpha2 * z step of the below cuDNN function:
// y = act ( alpha1 * conv(x) + alpha2 * z + bias )
auto memory_format = miopen_conv_suggest_memory_format(input_t, weight_t);
auto& ctx = at::globalContext();
bool benchmark = ctx.benchmarkCuDNN();
auto _alpha = alpha.has_value() ? alpha.value().to<float>() : 1.0;
auto _bias = bias_t.has_value()
? bias_t.value()
: at::zeros(
{output_t.size(1)},
optTypeMetaToScalarType(output_t.options().dtype_opt()),
output_t.options().layout_opt(),
output_t.options().device_opt(),
output_t.options().pinned_memory_opt());
raw_miopen_convolution_add_relu_out(
output_t,
TensorArg input { input_t, "input", 1 },
weight { weight_t, "weight", 2 };
Tensor output_t = at::detail::empty_cuda(
conv_output_size(
input_t.sizes(), weight_t.sizes(), padding, stride, dilation),
input_t.options().memory_format(memory_format));
if (output_t.numel() == 0){
return output_t;
}
// Avoid ambiguity of "output" when this is being used as backwards
TensorArg output{output_t, "result", 0};
miopen_convolution_forward_out(
output,
"miopen_convolution_add_relu",
input,
weight,
z,
_alpha,
_bias,
stride,
padding,
stride,
dilation,
groups,
benchmark,
true); // deterministic
false // deterministic
);
return output_t;
auto contig_output_t = self_or_new_memory_format(output_t, memory_format);
if (!output_t.is_same(contig_output_t)) {
contig_output_t.copy_(output_t);
}
auto _alpha = alpha.has_value() ? alpha.value().to<float>() : 1.0;
auto _bias = bias.has_value()
? bias.value()
: at::zeros(
{contig_output_t.size(1)},
optTypeMetaToScalarType(contig_output_t.options().dtype_opt()),
contig_output_t.options().layout_opt(),
contig_output_t.options().device_opt(),
contig_output_t.options().pinned_memory_opt());
at::Tensor alpha_mul_z_add_bias = at::native::reshape_bias(input_t.dim(), _bias).add(z, _alpha);
contig_output_t.add_(alpha_mul_z_add_bias);
contig_output_t.relu_();
return contig_output_t;
}
Tensor miopen_convolution_relu(
const Tensor& input_t,
const Tensor& weight_t,
const std::optional<Tensor>& bias_t,
const std::optional<Tensor>& bias,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
int64_t groups) {
auto memory_format = miopen_conv_suggest_memory_format(input_t, weight_t);
const Tensor input = input_t.contiguous(memory_format);
const Tensor weight = weight_t.contiguous(memory_format);
// FuseFrozenConvAddRelu performs some tensor shape checking
Tensor output_t = at::detail::empty_cuda(
conv_output_size(
input.sizes(), weight.sizes(), padding, stride, dilation),
input.options().memory_format(memory_format));
if (output_t.numel() == 0) {
return output_t;
}
auto& ctx = at::globalContext();
bool benchmark = ctx.benchmarkCuDNN();
auto _bias = bias_t.has_value()
? bias_t.value()
: at::zeros(
{output_t.size(1)},
optTypeMetaToScalarType(output_t.options().dtype_opt()),
output_t.options().layout_opt(),
output_t.options().device_opt(),
output_t.options().pinned_memory_opt());
raw_miopen_convolution_add_relu_out(
output_t,
input,
weight,
output_t, // use output_t as z to satisfy MIOpen API
0, // alpha
_bias,
stride,
padding,
dilation,
groups,
benchmark, // benchmark
true); // deterministic
// MIOpen currently only supports MemoryFormat::Contiguous and fp32 and 2d
if (input_t.suggest_memory_format() == at::MemoryFormat::Contiguous
&& input_t.scalar_type() == at::kFloat
&& input_t.ndimension() == 4) {
return output_t;
// FuseFrozenConvAddRelu performs some tensor shape checking
Tensor output_t = at::detail::empty_cuda(
conv_output_size(
input_t.sizes(), weight_t.sizes(), padding, stride, dilation),
input_t.options().memory_format(input_t.suggest_memory_format()));
if (output_t.numel() == 0) {
return output_t;
}
auto _bias = bias.has_value()
? bias.value()
: at::zeros(
{output_t.size(1)},
optTypeMetaToScalarType(output_t.options().dtype_opt()),
output_t.options().layout_opt(),
output_t.options().device_opt(),
output_t.options().pinned_memory_opt());
raw_miopen_convolution_relu_out(
output_t,
input_t,
weight_t,
_bias,
stride,
padding,
dilation,
groups,
benchmark, // benchmark
false // deterministic
);
return output_t;
}
else {
// fallback
auto memory_format = miopen_conv_suggest_memory_format(input_t, weight_t);
TensorArg input { input_t, "input", 1 },
weight { weight_t, "weight", 2 };
Tensor output_t = at::detail::empty_cuda(
conv_output_size(
input_t.sizes(), weight_t.sizes(), padding, stride, dilation),
input->options().memory_format(memory_format));
if (output_t.numel() == 0){
return output_t;
}
// Avoid ambiguity of "output" when this is being used as backwards
TensorArg output{output_t, "result", 0};
miopen_convolution_forward_out(
output,
"miopen_convolution_relu",
input,
weight,
padding,
stride,
dilation,
groups,
benchmark,
false // deterministic
);
auto contig_output_t = self_or_new_memory_format(output_t, memory_format);
if (!output_t.is_same(contig_output_t)) {
contig_output_t.copy_(output_t);
}
auto _bias = bias.has_value()
? bias.value()
: at::zeros(
{contig_output_t.size(1)},
optTypeMetaToScalarType(contig_output_t.options().dtype_opt()),
contig_output_t.options().layout_opt(),
contig_output_t.options().device_opt(),
contig_output_t.options().pinned_memory_opt());
at::Tensor reshaped_bias = at::native::reshape_bias(input_t.dim(), _bias);
contig_output_t.add_(reshaped_bias);
contig_output_t.relu_();
return contig_output_t;
}
}
REGISTER_CUDA_DISPATCH(miopen_convolution_backward_stub, &miopen_convolution_backward)

View File

@ -568,7 +568,7 @@ Placeholder::Placeholder(MPSGraphTensor* mpsGraphTensor,
MPSShape* mpsStrides = getMPSShape(_tensor.strides());
check_mps_shape(mpsShape);
auto storage_numel = src.storage().nbytes() / src.element_size() - src.storage_offset();
auto storage_numel = src.storage().nbytes() / src.element_size();
TORCH_CHECK(storage_numel <= std::numeric_limits<int32_t>::max(),
"MPSGaph does not support tensor dims larger than INT_MAX");
MPSNDArrayDescriptor* srcTensorDesc = [MPSNDArrayDescriptor descriptorWithDataType:dataType

View File

@ -4372,7 +4372,7 @@
variants: function, method
dispatch:
CPU: narrow_copy_dense_cpu
SparseCPU, SparseCUDA, SparseMPS: narrow_copy_sparse
SparseCPU, SparseCUDA: narrow_copy_sparse
CompositeExplicitAutogradNonFunctional: narrow_copy_dense_symint
tags: view_copy
@ -6660,7 +6660,7 @@
- func: zeros.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
CompositeExplicitAutograd: zeros_out
SparseCPU, SparseCUDA, SparseMPS, SparseMeta: zeros_sparse_out
SparseCPU, SparseCUDA, SparseMeta: zeros_sparse_out
- func: zeros_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
dispatch:
@ -10699,7 +10699,6 @@
dispatch:
CompositeExplicitAutograd: foreach_tensor_div_list_kernel_slow
CUDA: foreach_tensor_div_list_kernel_cuda
MTIA: foreach_tensor_div_list_kernel_mtia
- func: _foreach_div_.List(Tensor(a!)[] self, Tensor[] other) -> ()
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -10707,7 +10706,6 @@
dispatch:
CompositeExplicitAutograd: foreach_tensor_div_list_kernel_slow_
CUDA: foreach_tensor_div_list_kernel_cuda_
MTIA: foreach_tensor_div_list_kernel_mtia_
autogen: _foreach_div.List_out
- func: _foreach_div.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]
@ -10731,7 +10729,6 @@
dispatch:
CompositeExplicitAutograd: foreach_tensor_div_tensor_kernel_slow
CUDA: foreach_tensor_div_tensor_kernel_cuda
MTIA: foreach_tensor_div_tensor_kernel_mtia
- func: _foreach_div_.Tensor(Tensor(a!)[] self, Tensor other) -> ()
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -10739,7 +10736,6 @@
dispatch:
CompositeExplicitAutograd: foreach_tensor_div_tensor_kernel_slow_
CUDA: foreach_tensor_div_tensor_kernel_cuda_
MTIA: foreach_tensor_div_tensor_kernel_mtia_
autogen: _foreach_div.Tensor_out
- func: _foreach_clamp_max.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]

View File

@ -1,6 +1,5 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/ceil_div.h>
#include <ATen/native/cuda/Loops.cuh>
#include <c10/cuda/CUDAGuard.h>
@ -22,11 +21,10 @@
namespace at::native {
namespace {
template <typename T>
__global__ void ChooseQuantizationParamsKernelImpl(
const int64_t* fake_quant_on,
const T* x_min,
const T* x_max,
const float* x_min,
const float* x_max,
int32_t qmin,
int32_t qmax,
int size,
@ -95,44 +93,34 @@ __global__ void ChooseQuantizationParamsKernelImpl(
}
}
__device__ inline bool isinf_device(float v) {
return ::isinf(v);
}
__device__ inline bool isinf_device(c10::BFloat16 v) {
return ::isinf(static_cast<float>(v));
}
// CUDA kernel to compute Moving Average Min/Max of the tensor.
// It uses the running_min and running_max along with averaging const, c.
// The formula used to compute the new min/max is as follows
//
// running_min = (1 - c) * running_min + c * x_min, if running_min != inf
// running_min = x_min, if running_min == inf
template <typename T>
__global__ void MovingAverageMinMax(
const int64_t* observer_on,
const T* x_min,
const T* x_max,
T* running_min,
T* running_max,
const float* x_min,
const float* x_max,
float* running_min,
float* running_max,
const float averaging_const,
const int size) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (*observer_on == 1) {
if (i < size) {
T curr_min = x_min[i];
T curr_max = x_max[i];
float curr_min = x_min[i];
float curr_max = x_max[i];
T averaging_const_t = static_cast<T>(averaging_const);
float adjusted_min = ::isinf(running_min[i])
? curr_min
: (running_min[i]) + averaging_const * (curr_min - (running_min[i]));
T adjusted_min = isinf_device(running_min[i]) ? curr_min
: (running_min[i]) +
averaging_const_t * (curr_min - (running_min[i]));
T adjusted_max = isinf_device(running_max[i]) ? curr_max
: (running_max[i]) +
averaging_const_t * (curr_max - (running_max[i]));
float adjusted_max = ::isinf(running_max[i])
? curr_max
: (running_max[i]) + averaging_const * (curr_max - (running_max[i]));
running_min[i] = adjusted_min;
running_max[i] = adjusted_max;
@ -154,51 +142,40 @@ void _calculate_moving_average(
at::Tensor x_min, x_max;
int64_t* observer_on_data = observer_on.data_ptr<int64_t>();
float* running_min_data = running_min.data_ptr<float>();
float* running_max_data = running_max.data_ptr<float>();
cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream();
if (per_row_fq) {
std::tie(x_min, x_max) = at::aminmax(x, 1);
float* x_min_data = x_min.data_ptr<float>();
float* x_max_data = x_max.data_ptr<float>();
int num_threads = std::min(size, (int64_t)512);
const uint64_t num_blocks = ceil_div<uint64_t>(size, num_threads);
AT_DISPATCH_FLOATING_TYPES_AND(
at::kBFloat16, x.scalar_type(), "aminmax_kernel", [&] {
scalar_t* x_min_data = x_min.data_ptr<scalar_t>();
scalar_t* x_max_data = x_max.data_ptr<scalar_t>();
scalar_t* running_min_data = running_min.data_ptr<scalar_t>();
scalar_t* running_max_data = running_max.data_ptr<scalar_t>();
// Moving Average Min/Max observer for activations
MovingAverageMinMax<<<num_blocks, num_threads, 0, cuda_stream>>>(
observer_on_data,
x_min_data,
x_max_data,
running_min_data,
running_max_data,
averaging_const,
size);
});
// Moving Average Min/Max observer for activations
MovingAverageMinMax<<<num_blocks, num_threads, 0, cuda_stream>>>(
observer_on_data,
x_min_data,
x_max_data,
running_min_data,
running_max_data,
averaging_const,
size);
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else {
std::tie(x_min, x_max) = at::aminmax(x);
AT_DISPATCH_FLOATING_TYPES_AND(
at::kBFloat16, x.scalar_type(), "aminmax_kernel", [&] {
scalar_t* x_min_data = x_min.data_ptr<scalar_t>();
scalar_t* x_max_data = x_max.data_ptr<scalar_t>();
scalar_t* running_min_data = running_min.data_ptr<scalar_t>();
scalar_t* running_max_data = running_max.data_ptr<scalar_t>();
// Moving Average Min/Max observer for activations
MovingAverageMinMax<<<1, 1, 0, cuda_stream>>>(
observer_on_data,
x_min_data,
x_max_data,
running_min_data,
running_max_data,
averaging_const,
1 /*size*/);
});
float* x_min_data = x_min.data_ptr<float>();
float* x_max_data = x_max.data_ptr<float>();
// Moving Average Min/Max observer for activations
MovingAverageMinMax<<<1, 1, 0, cuda_stream>>>(
observer_on_data,
x_min_data,
x_max_data,
running_min_data,
running_max_data,
averaging_const,
1 /*size*/);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
}
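
A scalar Python sketch of the moving-average update described in the kernel comment above (running statistics start at +/-inf and are then blended with the averaging constant c):

    import math

    def update_running(running_min, running_max, x_min, x_max, c):
        # running = x                             if running is inf (first batch)
        # running = running + c * (x - running)   otherwise
        new_min = x_min if math.isinf(running_min) else running_min + c * (x_min - running_min)
        new_max = x_max if math.isinf(running_max) else running_max + c * (x_max - running_max)
        return new_min, new_max

    rmin, rmax = float("inf"), float("-inf")
    rmin, rmax = update_running(rmin, rmax, -1.0, 2.0, c=0.01)  # first batch: copied
    rmin, rmax = update_running(rmin, rmax, -3.0, 1.0, c=0.01)  # then blended
    assert math.isclose(rmin, -1.02) and math.isclose(rmax, 1.99)
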
@ -221,44 +198,34 @@ void _calc_moving_avg_qparams_helper(
cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream();
int64_t* fake_quant_on_data = fake_quant_on.data_ptr<int64_t>();
if (per_row_fq) {
AT_DISPATCH_FLOATING_TYPES_AND(
at::kBFloat16, x.scalar_type(), "aminmax_kernel", [&] {
scalar_t* running_min_data = running_min.data_ptr<scalar_t>();
scalar_t* running_max_data = running_max.data_ptr<scalar_t>();
int num_threads = std::min(size, (int64_t)512);
const uint64_t num_blocks = ceil_div<uint64_t>(size, num_threads);
ChooseQuantizationParamsKernelImpl<<<
num_blocks,
num_threads,
0,
cuda_stream>>>(
fake_quant_on_data,
running_min_data,
running_max_data,
qmin,
qmax,
size,
symmetric_quant,
scale_ptr,
zp_ptr);
});
float* running_min_data = running_min.data_ptr<float>();
float* running_max_data = running_max.data_ptr<float>();
int num_threads = std::min(size, (int64_t)512);
const uint64_t num_blocks = ceil_div<uint64_t>(size, num_threads);
ChooseQuantizationParamsKernelImpl<<<num_blocks, num_threads, 0, cuda_stream>>>(
fake_quant_on_data,
running_min_data,
running_max_data,
qmin,
qmax,
size,
symmetric_quant,
scale_ptr,
zp_ptr);
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else {
AT_DISPATCH_FLOATING_TYPES_AND(
at::kBFloat16, x.scalar_type(), "aminmax_kernel", [&] {
scalar_t* running_min_data = running_min.data_ptr<scalar_t>();
scalar_t* running_max_data = running_max.data_ptr<scalar_t>();
ChooseQuantizationParamsKernelImpl<<<1, 1, 0, cuda_stream>>>(
fake_quant_on_data,
running_min_data,
running_max_data,
qmin,
qmax,
1, // size
symmetric_quant, // preserve_sparsity
scale_ptr,
zp_ptr);
});
float* running_min_data = running_min.data_ptr<float>();
float* running_max_data = running_max.data_ptr<float>();
ChooseQuantizationParamsKernelImpl<<<1, 1, 0, cuda_stream>>>(
fake_quant_on_data,
running_min_data,
running_max_data,
qmin,
qmax,
1, // size
symmetric_quant, // preserve_sparsity
scale_ptr,
zp_ptr);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
}

View File

@ -42,7 +42,7 @@ TEST(MPSObjCInterfaceTest, MPSCustomKernel) {
id<MTLLibrary> customKernelLibrary = [device newLibraryWithSource: [NSString stringWithUTF8String:CUSTOM_KERNEL]
options: nil
error: &error];
TORCH_CHECK(customKernelLibrary, "Failed to create custom kernel library, error: ", error.localizedDescription.UTF8String);
TORCH_CHECK(customKernelLibrary, "Failed to to create custom kernel library, error: ", error.localizedDescription.UTF8String);
id<MTLFunction> customFunction = [customKernelLibrary newFunctionWithName: @"add_arrays"];
TORCH_CHECK(customFunction, "Failed to create function state object for the kernel");

View File

@ -76,23 +76,4 @@ int32_t getGlobalIdxFromDevice(DeviceIndex device) {
return device_global_idxs[device];
}
// Check if a device can access the memory of a peer device directly.
bool canDeviceAccessPeer(DeviceIndex device, DeviceIndex peer) {
if (device == -1) {
device = c10::xpu::current_device();
}
if (peer == -1) {
peer = c10::xpu::current_device();
}
check_device_index(device);
check_device_index(peer);
// A device can always access itself
if (device == peer) {
return true;
}
return c10::xpu::get_raw_device(device).ext_oneapi_can_access_peer(
c10::xpu::get_raw_device(peer),
sycl::ext::oneapi::peer_access::access_supported);
}
} // namespace at::xpu

View File

@ -17,6 +17,4 @@ TORCH_XPU_API DeviceProp* getDeviceProperties(DeviceIndex device);
TORCH_XPU_API int32_t getGlobalIdxFromDevice(DeviceIndex device);
TORCH_XPU_API bool canDeviceAccessPeer(DeviceIndex device, DeviceIndex peer);
} // namespace at::xpu

View File

@ -171,23 +171,3 @@ XLNetLMHeadModel,pass,0
YituTechConvBert,pass,0
meta-llama/Llama-3.2-1B,pass,5
google/gemma-2-2b,pass,5
google/gemma-3-4b-it,pass_due_to_skip,0
openai/whisper-tiny,pass,6
Qwen/Qwen3-0.6B,pass,5


View File

@ -171,23 +171,3 @@ XLNetLMHeadModel,pass,5
YituTechConvBert,pass,5
meta-llama/Llama-3.2-1B,eager_fail_to_run,0
google/gemma-2-2b,eager_fail_to_run,0
google/gemma-3-4b-it,eager_fail_to_run,0
openai/whisper-tiny,eager_fail_to_run,0
Qwen/Qwen3-0.6B,eager_fail_to_run,0


View File

@ -167,23 +167,3 @@ XLNetLMHeadModel,pass,0
YituTechConvBert,pass,0
meta-llama/Llama-3.2-1B,fail_accuracy,0
google/gemma-2-2b,fail_accuracy,0
google/gemma-3-4b-it,fail_accuracy,0
openai/whisper-tiny,fail_to_run,0
Qwen/Qwen3-0.6B,fail_accuracy,0


View File

@ -171,23 +171,3 @@ XLNetLMHeadModel,pass,0
YituTechConvBert,pass,0
meta-llama/Llama-3.2-1B,pass,5
google/gemma-2-2b,pass,5
google/gemma-3-4b-it,pass_due_to_skip,0
openai/whisper-tiny,pass,6
Qwen/Qwen3-0.6B,pass,5


View File

@ -171,23 +171,3 @@ XLNetLMHeadModel,pass,5
YituTechConvBert,pass,5
meta-llama/Llama-3.2-1B,eager_fail_to_run,0
google/gemma-2-2b,eager_fail_to_run,0
google/gemma-3-4b-it,eager_fail_to_run,0
openai/whisper-tiny,eager_fail_to_run,0
Qwen/Qwen3-0.6B,eager_fail_to_run,0


View File

@ -205,7 +205,7 @@ llama,pass,0
llama_v2_7b_16h,pass_due_to_skip,0
llama_v2_7b_16h,model_fail_to_load,0


View File

@ -171,23 +171,3 @@ XLNetLMHeadModel,pass,0
YituTechConvBert,pass,0
meta-llama/Llama-3.2-1B,pass,5
google/gemma-2-2b,pass,5
google/gemma-3-4b-it,pass,0
openai/whisper-tiny,pass,6
Qwen/Qwen3-0.6B,pass,5


View File

@ -171,23 +171,3 @@ XLNetLMHeadModel,pass,5
YituTechConvBert,pass,5
meta-llama/Llama-3.2-1B,eager_fail_to_run,0
google/gemma-2-2b,eager_fail_to_run,0
google/gemma-3-4b-it,eager_fail_to_run,0
openai/whisper-tiny,eager_fail_to_run,0
Qwen/Qwen3-0.6B,eager_fail_to_run,0


View File

@ -171,23 +171,3 @@ XLNetLMHeadModel,pass,0
YituTechConvBert,pass,0
meta-llama/Llama-3.2-1B,pass,5
google/gemma-2-2b,pass,5
google/gemma-3-4b-it,pass_due_to_skip,0
openai/whisper-tiny,pass,6
Qwen/Qwen3-0.6B,pass,5


View File

@ -205,7 +205,7 @@ llama,pass,0
llama_v2_7b_16h,pass_due_to_skip,0
llama_v2_7b_16h,model_fail_to_load,0


View File

@ -171,23 +171,3 @@ XLNetLMHeadModel,pass,0
YituTechConvBert,pass,0
meta-llama/Llama-3.2-1B,pass,5
google/gemma-2-2b,pass,5
google/gemma-3-4b-it,pass_due_to_skip,0
openai/whisper-tiny,pass,6
Qwen/Qwen3-0.6B,pass,5


View File

@ -171,23 +171,3 @@ XLNetLMHeadModel,pass,5
YituTechConvBert,pass,5
meta-llama/Llama-3.2-1B,eager_fail_to_run,0
google/gemma-2-2b,eager_fail_to_run,0
google/gemma-3-4b-it,eager_fail_to_run,0
openai/whisper-tiny,eager_fail_to_run,0
Qwen/Qwen3-0.6B,eager_fail_to_run,0


View File

@ -61,22 +61,6 @@ struct C10_API Storage {
allocator,
resizable)) {}
// Creates storage with pre-allocated memory buffer. Allocator is given for
// potential future reallocations, however it can be nullptr if the storage
// is non-resizable
Storage(
use_byte_size_t /*use_byte_size*/,
SymInt size_bytes,
at::DataPtr data_ptr,
at::Allocator* allocator = nullptr,
bool resizable = false)
: storage_impl_(c10::make_intrusive<StorageImpl>(
StorageImpl::use_byte_size_t(),
std::move(size_bytes),
std::move(data_ptr),
allocator,
resizable)) {}
protected:
explicit Storage(unsafe_borrow_t, const Storage& rhs)
: storage_impl_(c10::intrusive_ptr<c10::StorageImpl>::reclaim(

View File

@ -3269,7 +3269,7 @@ class C10_TensorImpl_Size_Check_Dummy_Class : private TensorImpl {
is_le<sizeof(autograd_meta_), 16, FieldNameEnum::autograd_meta_>();
is_le<sizeof(extra_meta_), 16, FieldNameEnum::extra_meta_>();
are_equal<sizeof(version_counter_), 8, FieldNameEnum::version_counter_>();
are_equal<sizeof(pyobj_slot_), 8, FieldNameEnum::pyobj_slot_>();
are_equal<sizeof(pyobj_slot_), 16, FieldNameEnum::pyobj_slot_>();
are_equal<sizeof(sizes_and_strides_), 88, FieldNameEnum::sizes_and_strides_>();
are_equal<sizeof(storage_offset_), 8, FieldNameEnum::storage_offset_>();
are_equal<sizeof(numel_), 8, FieldNameEnum::numel_>();
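The 8 → 16 byte expectation for pyobj_slot_ is consistent with the PyObjectSlot change later in this diff, which stores an atomic interpreter pointer alongside the PyObject pointer. A minimal sketch with illustrative names only:

#include <atomic>

// Two pointer-sized members on a typical 64-bit build, hence 16 bytes.
struct PyObjectSlotLayoutSketch {
  std::atomic<void*> pyobj_interpreter_;  // interpreter tag
  void* pyobj_;                           // cached PyObject*
};
static_assert(sizeof(PyObjectSlotLayoutSketch) >= 2 * sizeof(void*),
              "slot holds at least two pointers");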

View File

@ -13,10 +13,11 @@ struct C10_API PyInterpreterHooksInterface {
// Get the PyInterpreter instance
// Stub implementation throws error when Python is not available
// We return nullptr rather than throwing an error since there are bits of c10
// that expect an empty PyObjectSlot when python is not available.
virtual PyInterpreter* getPyInterpreter() const {
return nullptr;
TORCH_CHECK(
false,
"PyTorch was compiled without Python support. "
"Cannot access Python interpreter from C++.");
}
};

View File

@ -2,7 +2,7 @@
namespace c10::impl {
PyObjectSlot::PyObjectSlot() : pyobj_(nullptr) {}
PyObjectSlot::PyObjectSlot() : pyobj_interpreter_(nullptr), pyobj_(nullptr) {}
PyObjectSlot::~PyObjectSlot() {
maybe_destroy_pyobj();
@ -10,9 +10,9 @@ PyObjectSlot::~PyObjectSlot() {
void PyObjectSlot::maybe_destroy_pyobj() {
if (owns_pyobj()) {
TORCH_INTERNAL_ASSERT(getGlobalPyInterpreter() != nullptr);
TORCH_INTERNAL_ASSERT(pyobj_interpreter_ != nullptr);
TORCH_INTERNAL_ASSERT(pyobj_ != nullptr);
(*getGlobalPyInterpreter())
(*pyobj_interpreter_.load(std::memory_order_acquire))
->decref(_unchecked_untagged_pyobj(), /*has_pyobj_slot*/ true);
// NB: this destructor can only be entered when there are no
// references to this C++ object (obviously), NOR any references
@ -25,7 +25,7 @@ void PyObjectSlot::maybe_destroy_pyobj() {
}
PyInterpreter* PyObjectSlot::pyobj_interpreter() {
return getGlobalPyInterpreter();
return pyobj_interpreter_.load(std::memory_order_acquire);
}
PyObject* PyObjectSlot::_unchecked_untagged_pyobj() const {
@ -35,7 +35,7 @@ PyObject* PyObjectSlot::_unchecked_untagged_pyobj() const {
}
PyInterpreter& PyObjectSlot::load_pyobj_interpreter() const {
auto interpreter = getGlobalPyInterpreter();
auto interpreter = pyobj_interpreter_.load(std::memory_order_acquire);
if (interpreter) {
return *interpreter;
}

View File

@ -6,17 +6,10 @@
#include <c10/util/python_stub.h>
#include <optional>
#include <atomic>
namespace c10::impl {
// Function pointer type for getting the global interpreter
using GetPyInterpreterFn = PyInterpreter* (*)();
// Global function pointer (set by csrc initialization)
C10_API extern GetPyInterpreterFn g_get_pyinterpreter_fn;
// Helper function to get the global interpreter
C10_API PyInterpreter* getGlobalPyInterpreter();
struct C10_API PyObjectSlot {
public:
PyObjectSlot();
@ -33,6 +26,8 @@ struct C10_API PyObjectSlot {
// NB: THIS FUNCTION CAN RAISE AN EXCEPTION. Make sure to clean up after
// PyObject if necessary!
void init_pyobj(PyObject* pyobj) {
pyobj_interpreter_.store(
getGlobalPyInterpreter(), std::memory_order_relaxed);
pyobj_ = pyobj;
}
@ -60,15 +55,18 @@ struct C10_API PyObjectSlot {
// @todo alban: I'm not too sure what's going on here, we can probably delete
// it but it's worthwhile making sure
std::optional<PyObject*> check_pyobj() const {
impl::PyInterpreter* interpreter = getGlobalPyInterpreter();
if (interpreter == nullptr || pyobj_ == nullptr) {
std::optional<PyObject*> check_pyobj(bool ignore_hermetic_tls = false) const {
impl::PyInterpreter* interpreter =
pyobj_interpreter_.load(std::memory_order_acquire);
if (interpreter == nullptr) {
return std::nullopt;
}
if (c10::impl::HermeticPyObjectTLS::get_state()) {
if (!ignore_hermetic_tls && c10::impl::HermeticPyObjectTLS::get_state()) {
return std::nullopt;
} else {
return _unchecked_untagged_pyobj();
}
return _unchecked_untagged_pyobj();
}
PyInterpreter& load_pyobj_interpreter() const;
@ -78,6 +76,30 @@ struct C10_API PyObjectSlot {
void set_owns_pyobj(bool b);
private:
// This field contains the interpreter tag for this object. See
// Note [Python interpreter tag] for general context
//
// Note [Memory ordering on Python interpreter tag]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// What memory_order do we need when accessing this atomic? We don't
// need a single total modification order (as provided by
// memory_order_seq_cst) as pyobj_interpreter_ is monotonic: it can only
// transition from -1 to some positive integer and never changes afterwards.
// Because there is only one modification, it trivially already has a total
// modification order (e.g., we don't need fences or locked instructions on
// x86)
//
// In fact, one could make a reasonable argument that relaxed reads are OK,
// due to the presence of external locking (GIL) to ensure that interactions
// with other data structures are still correctly synchronized, so that
// we fall in the "Single-Location Data Structures" case as described in
// http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p2055r0.pdf
// However, on x86, it doesn't matter if I use acquire or relaxed on the load
// as I get the same assembly in both cases. So I just use the more
// conservative acquire (which will impede compiler optimizations but I don't
// care)
std::atomic<PyInterpreter*> pyobj_interpreter_;
// This field contains a reference to a PyObject representing this Tensor.
// If pyobj is nullptr, when we transfer Tensor to Python, we allocate a new
// PyObject for it and set this field. This field does not have to be
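A compact sketch of the publish/read pattern the memory-ordering note describes (illustrative only): the pointer is written once and afterwards only read, so a relaxed store at initialization paired with acquire loads is sufficient.

#include <atomic>

struct MonotonicPointerSketch {
  std::atomic<void*> ptr{nullptr};

  void publish(void* p) {
    // single monotonic write; external locking (e.g. the GIL) orders the
    // surrounding state, so relaxed is enough here
    ptr.store(p, std::memory_order_relaxed);
  }

  void* read() const {
    // conservative acquire load, mirroring the loads above
    return ptr.load(std::memory_order_acquire);
  }
};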

View File

@ -52,7 +52,7 @@ struct maybe_bool {
template <typename src_t>
struct maybe_bool<true, src_t> {
C10_HOST_DEVICE static inline decltype(auto) apply(src_t src) {
// Don't use bool operator so as to also compile for ComplexHalf.
// Don't use bool operator so as to to also compile for ComplexHalf.
return src.real() || src.imag();
}
};

View File

@ -1,243 +0,0 @@
#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE) && defined(CAFFE2_PERF_WITH_SVE128)
#include <arm_neon.h>
#include <arm_neon_sve_bridge.h>
#include <arm_sve.h>
#include "c10/macros/Macros.h"
// Log and exp approximations inspired from ACL implementation
inline float32x4_t vtaylor_polyq_for_log_f32(float32x4_t x) {
const float32x4_t log_tab_1 = vdupq_n_f32(-2.29561495781f);
const float32x4_t log_tab_2 = vdupq_n_f32(-2.47071170807f);
const float32x4_t log_tab_3 = vdupq_n_f32(-5.68692588806f);
const float32x4_t log_tab_4 = vdupq_n_f32(-0.165253549814f);
const float32x4_t log_tab_5 = vdupq_n_f32(5.17591238022f);
const float32x4_t log_tab_6 = vdupq_n_f32(0.844007015228f);
const float32x4_t log_tab_7 = vdupq_n_f32(4.58445882797f);
const float32x4_t log_tab_8 = vdupq_n_f32(0.0141278216615f);
float32x4_t A = vmlaq_f32(log_tab_1, log_tab_5, x);
float32x4_t B = vmlaq_f32(log_tab_3, log_tab_7, x);
float32x4_t C = vmlaq_f32(log_tab_2, log_tab_6, x);
float32x4_t x2 = vmulq_f32(x, x);
float32x4_t D = svget_neonq(svmad_f32_x(
svptrue_b8(),
svset_neonq(svundef_f32(), x),
svset_neonq(svundef_f32(), log_tab_8),
svset_neonq(svundef_f32(), log_tab_4)));
float32x4_t x4 = vmulq_f32(x2, x2);
float32x4_t res = vmlaq_f32(vmlaq_f32(A, B, x2), vmlaq_f32(C, D, x2), x4);
return res;
}
inline float32x4_t vlogq_f32(float32x4_t x) {
const float32x4_t CONST_LN2 = vdupq_n_f32(0.6931471805f); // ln(2)
// Extract exponent
int32x4_t m = svget_neonq(svsub_n_s32_x(
svptrue_b8(),
svset_neonq(
svundef_s32(),
vreinterpretq_s32_u32(vshrq_n_u32(vreinterpretq_u32_f32(x), 23))),
127));
float32x4_t val = vreinterpretq_f32_s32(
vsubq_s32(vreinterpretq_s32_f32(x), vshlq_n_s32(m, 23)));
// Polynomial Approximation
float32x4_t poly = vtaylor_polyq_for_log_f32(val);
// Reconstruct
poly = vmlaq_f32(poly, vcvtq_f32_s32(m), CONST_LN2);
return poly;
}
inline float32x4_t vexpq_f32(float32x4_t x) {
const auto c1 = vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(0x3f7ffff6)));
const auto c2 = vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(0x3efffedb)));
const auto c3 = vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(0x3e2aaf33)));
const auto c4 = vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(0x3d2b9f17)));
const auto c5 = vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(0x3c072010)));
const auto shift = vreinterpretq_f32_u32(
svget_neonq(svdup_n_u32(0x4b00007f))); // 2^23 + 127 = 0x1.0000fep23f
const auto inv_ln2 = vreinterpretq_f32_u32(
svget_neonq(svdup_n_u32(0x3fb8aa3b))); // 1 / ln(2) = 0x1.715476p+0f
const auto neg_ln2_hi = vreinterpretq_f32_u32(svget_neonq(
svdup_n_u32(0xbf317200))); // -ln(2) from bits -1 to -19: -0x1.62e400p-1f
const auto neg_ln2_lo = vreinterpretq_f32_u32(svget_neonq(svdup_n_u32(
0xb5bfbe8e))); // -ln(2) from bits -20 to -42: -0x1.7f7d1cp-20f
const auto inf = svdup_n_f32(std::numeric_limits<float>::infinity());
const auto max_input = svdup_n_f32(88.37f); // Approximately ln(2^127.5)
const auto zero = svdup_n_f32(0.f);
const auto min_input = svdup_n_f32(-86.64f); // Approximately ln(2^-125)
// Range reduction:
// e^x = 2^n * e^r
// where:
// n = floor(x / ln(2))
// r = x - n * ln(2)
//
// By adding x / ln(2) with 2^23 + 127 (shift):
// * As FP32 fraction part only has 23-bits, the addition of 2^23 + 127
// forces decimal part
// of x / ln(2) out of the result. The integer part of x / ln(2) (i.e. n)
// + 127 will occupy the whole fraction part of z in FP32 format.
// Subtracting 2^23 + 127 (shift) from z will result in the integer part
// of x / ln(2) (i.e. n) because the decimal part has been pushed out and
// lost.
// * The addition of 127 makes the FP32 fraction part of z ready to be used
// as the exponent
// in FP32 format. Left shifting z by 23 bits will result in 2^n.
const auto z = vfmaq_f32(shift, x, inv_ln2);
const auto n = z - shift;
const auto scale =
vreinterpretq_f32_u32(vreinterpretq_u32_f32(z) << 23); // 2^n
// The calculation of n * ln(2) is done using 2 steps to achieve accuracy
// beyond FP32. This outperforms longer Taylor series (3-4 tabs) both in term
// of accuracy and performance.
const auto r_hi = vfmaq_f32(x, n, neg_ln2_hi);
const auto r = vfmaq_f32(r_hi, n, neg_ln2_lo);
// Compute the truncated Taylor series of e^r.
// poly = scale * (1 + c1 * r + c2 * r^2 + c3 * r^3 + c4 * r^4 + c5 * r^5)
const auto r2 = r * r;
const auto p1 = c1 * r;
const auto p23 = vfmaq_f32(c2, c3, r);
const auto p45 = vfmaq_f32(c4, c5, r);
const auto p2345 = vfmaq_f32(p23, p45, r2);
const auto p12345 = vfmaq_f32(p1, p2345, r2);
auto poly = svset_neonq(svundef_f32(), vfmaq_f32(scale, p12345, scale));
// Handle underflow and overflow.
poly = svsel_f32(
svcmplt_f32(svptrue_b8(), svset_neonq(svundef_f32(), x), min_input),
zero,
poly);
poly = svsel_f32(
svcmpgt_f32(svptrue_b8(), svset_neonq(svundef_f32(), x), max_input),
inf,
poly);
return svget_neonq(poly);
}
// ln(x) = log2(x) * ln(2)
// pow(x, n) = exp(n * ln(x))
inline float32x4_t compute_batch_box_cox_vec_sve128_float(
svfloat32_t lambda1_v,
svfloat32_t lambda2_v,
svfloat32_t data_v,
svfloat32_t k_eps) {
// sum_v = lambda2_v + data_v
float32x4_t sum_v = vaddq_f32(svget_neonq(data_v), svget_neonq(lambda2_v));
// test lambda1_v: predNZ == 1 iff lambda1_v != 0
svbool_t predNZ = svcmpne_n_f32(svptrue_b8(), lambda1_v, 0.0f);
// clamp sum_v: sum_v = max(sum_v, k_eps)
sum_v = vmaxq_f32(sum_v, svget_neonq(k_eps));
// lnData = log(sum_v)
svfloat32_t lnData = svset_neonq(svundef_f32(), vlogq_f32(sum_v));
// if any lambda1 != 0, compute pow(sum_v, lambda1) using lnData
// pow(sum_v, lambda1) == exp(lambda1 * ln(sum_v))
if (C10_LIKELY(svptest_any(predNZ, predNZ))) {
// mult = lambda1 * ln(sum_v)
float32x4_t mult = vmulq_f32(svget_neonq(lnData), svget_neonq(lambda1_v));
// lambda1_r = 1 / lambda1
svfloat32_t lambda1_r = svdivr_f32_m(predNZ, lambda1_v, svdup_n_f32(1.0f));
// pow = exp(mult)
float32x4_t pow = vexpq_f32(mult);
// merge results
// lnData if lambda1 == 0, (lambda1_r * pow - lambda1_r) if lambda1 != 0
lnData = svsel_f32(predNZ, lambda1_r, lnData);
lnData =
svnmsb_f32_m(predNZ, lnData, svset_neonq(svundef_f32(), pow), lnData);
}
return svget_neonq(lnData);
}
template <typename T>
void compute_batch_box_cox_vec_sve128(
std::size_t N,
std::size_t D,
const T* data_ptr,
const T* __restrict lambda1_ptr,
const T* __restrict lambda2_ptr,
T* output_ptr);
template <>
void compute_batch_box_cox_vec_sve128(
std::size_t N,
std::size_t D,
const float* data_ptr,
const float* __restrict lambda1_ptr,
const float* __restrict lambda2_ptr,
float* output_ptr) {
svfloat32_t k_eps = svdup_n_f32(static_cast<float>(1e-6));
std::size_t remainder = D % 4;
std::size_t loopBound = D - remainder;
svbool_t remainderPred = svwhilelt_b32_u64(0, remainder);
for (; C10_LIKELY(N > 0); --N) {
for (std::size_t j = 0; C10_LIKELY(j != loopBound);
j += 4, data_ptr += 4, output_ptr += 4) {
svfloat32_t lambda1_v =
svset_neonq(svundef_f32(), vld1q_f32(lambda1_ptr + j));
svfloat32_t lambda2_v =
svset_neonq(svundef_f32(), vld1q_f32(lambda2_ptr + j));
svfloat32_t data_v = svset_neonq(svundef_f32(), vld1q_f32(data_ptr));
float32x4_t result = compute_batch_box_cox_vec_sve128_float(
lambda1_v, lambda2_v, data_v, k_eps);
vst1q_f32(output_ptr, result);
}
if (C10_LIKELY(remainder > 0)) {
svfloat32_t lambda1_v = svld1_f32(remainderPred, lambda1_ptr + loopBound);
svfloat32_t lambda2_v = svld1_f32(remainderPred, lambda2_ptr + loopBound);
svfloat32_t data_v = svld1_f32(remainderPred, data_ptr);
float32x4_t result = compute_batch_box_cox_vec_sve128_float(
lambda1_v, lambda2_v, data_v, k_eps);
svst1_f32(remainderPred, output_ptr, svset_neonq(svundef_f32(), result));
data_ptr += remainder;
output_ptr += remainder;
}
}
}
namespace caffe2::details {
template <typename T>
void compute_batch_box_cox__sve128(
std::size_t N,
std::size_t D,
const T* self_data,
const T* __restrict lambda1_data,
const T* __restrict lambda2_data,
T* output_data) {
compute_batch_box_cox_vec_sve128<T>(
N, D, self_data, lambda1_data, lambda2_data, output_data);
}
// Vectorized version specializations for float and double
template void compute_batch_box_cox__sve128<float>(
std::size_t N,
std::size_t D,
const float* self_data,
const float* __restrict lambda1_data,
const float* __restrict lambda2_data,
float* output_data);
} // namespace caffe2::details
#endif // __aarch64__ && __ARM_FEATURE_SVE && CAFFE2_PERF_WITH_SVE128
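For reference, the kernel above vectorizes the two-parameter Box-Cox transform (the small epsilon clamp on the shifted input is omitted here):

\[
y =
\begin{cases}
\ln(x + \lambda_2), & \lambda_1 = 0,\\[4pt]
\dfrac{(x + \lambda_2)^{\lambda_1} - 1}{\lambda_1}, & \lambda_1 \neq 0,
\end{cases}
\qquad\text{using } (x + \lambda_2)^{\lambda_1} = e^{\lambda_1 \ln(x + \lambda_2)}.
\]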

View File

@ -107,12 +107,6 @@ if(INTERN_BUILD_ATEN_OPS)
list(APPEND _file_compile_flags "-gencode;arch=compute_100a,code=sm_100a")
endif()
endif()
# We will need to gate against CUDA version, because sm_103a is available on CUDA 12.9+
if("${_arch}" STREQUAL "103a" AND CUDA_VERSION VERSION_GREATER_EQUAL 12.9)
if(_existing_arch_flags MATCHES ".*compute_100.*")
list(APPEND _file_compile_flags "-gencode;arch=compute_103a,code=sm_103a")
endif()
endif()
if("${_arch}" STREQUAL "120a")
if(_existing_arch_flags MATCHES ".*compute_120.*")
list(APPEND _file_compile_flags "-gencode;arch=compute_120a,code=sm_120a")
@ -126,13 +120,13 @@ if(INTERN_BUILD_ATEN_OPS)
_BUILD_FOR_ADDITIONAL_ARCHS(
"${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/RowwiseScaledMM.cu"
"89;90a;100a;103a;120a")
"89;90a;100a;120a")
_BUILD_FOR_ADDITIONAL_ARCHS(
"${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/ScaledGroupMM.cu"
"90a")
_BUILD_FOR_ADDITIONAL_ARCHS(
"${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/GroupMM.cu"
"90a;100a;103a")
"90a;100a")
endif()

Binary files not shown (four images removed: 563 KiB, 281 KiB, 348 KiB, 747 KiB)

View File

@ -210,6 +210,10 @@ templates_path = [
coverage_ignore_functions = [
# torch
"typename",
# torch.cuda
"check_error",
"cudart",
"is_bf16_supported",
# torch.cuda._sanitizer
"zip_arguments",
"zip_by_key",
@ -3176,8 +3180,6 @@ coverage_ignore_classes = [
"WeakIdKeyDictionary",
"WeakIdRef",
"WeakTensorKeyDictionary",
# torch.utils.debug_mode
"DebugMode",
]
# The suffix(es) of source filenames.

View File

@ -15,7 +15,6 @@
StreamContext
can_device_access_peer
check_error
current_blas_handle
current_device
current_stream
@ -35,7 +34,6 @@
init
ipc_collect
is_available
is_bf16_supported
is_initialized
is_tf32_supported
memory_usage

View File

@ -22,7 +22,6 @@ The MTIA backend is implemented out of the tree, only interfaces are be defined
device_count
init
is_available
is_bf16_supported
is_initialized
memory_stats
get_device_capability

View File

@ -30,6 +30,9 @@
.. autofunction:: create_mask
```
```{eval-rst}
.. autofunction:: create_nested_block_mask
```
```{eval-rst}
.. autofunction:: and_masks
```
```{eval-rst}

View File

@ -3,6 +3,12 @@
TorchInductor and AOTInductor Provenance Tracking
=================================================
.. warning::
This feature is a prototype under active development and there will be
breaking changes in future releases.
The current compatibility of this tool is limited to the latest nightly build of PyTorch.
This section describes how to use the provenance tracking feature for TorchInductor and AOTInductor in ``tlparse``.
Provenance tracking helps you visualize the relationships between the input GraphModule to (AOT)Inductor and the optimized code generated. This feature allows you to trace how your original operations are transformed during compilation.
@ -31,7 +37,7 @@ Follow these steps to enable and use provenance tracking in your PyTorch project
.. code-block:: bash
TORCH_TRACE=~/my_trace_log_dir INDUCTOR_PROVENANCE=1 python your_program.py
TORCH_TRACE=~/my_trace_log_dir TORCH_LOGS="+inductor" TORCH_COMPILE_DEBUG=1 python your_program.py
This will generate a log file in ``~/my_trace_log_dir``. The log file will be used by tlparse to generate the provenance tracking highlighter.
3. Run ``tlparse`` on the log with ``--inductor-provenance`` flag. For example:
@ -56,24 +62,6 @@ For a demo, see: https://github.com/pytorch/tlparse/pull/93
.. image:: _static/img/inductor_provenance/index.png
Source code corresponding to each Inductor kernel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
With ``INDUCTOR_PROVENANCE=1``, you can also view the source code corresponding to each Inductor kernel in tlparse. To access it, click the "readable_html" link next to "inductor_provenance_tracking_kernel_stack_traces.json" in the tlparse output.
.. image:: _static/img/inductor_provenance/index_2.png
Below are some example screenshots. The ``:1`` and ``:467`` suffixes at the end of the kernel names are used to distinguish different calls to the same kernel. We refer to these suffixes as debug handles.
.. image:: _static/img/inductor_provenance/kernel_source_1.png
.. image:: _static/img/inductor_provenance/kernel_source_2.png
You can also find the debug handle in the comments within the kernel source code.
.. image:: _static/img/inductor_provenance/kernel_source_3.png
See Also
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -78,7 +78,6 @@ for tracking purposes -->
.. py:module:: torch.utils.data.graph
.. py:module:: torch.utils.data.graph_settings
.. py:module:: torch.utils.data.sampler
.. py:module:: torch.utils.debug_mode
.. py:module:: torch.utils.dlpack
.. py:module:: torch.utils.file_baton
.. py:module:: torch.utils.flop_counter

View File

@ -12,7 +12,6 @@
:nosignatures:
StreamContext
can_device_access_peer
current_device
current_stream
device
@ -26,7 +25,6 @@
get_stream_from_external
init
is_available
is_bf16_supported
is_initialized
set_device
set_stream

View File

@ -1187,7 +1187,8 @@ int64_t _Tensor_ndim(mpy::handle h) {
mpy::handle handle_from_tensor(Arena& A, TensorRef t) {
// fast case: tensor is live in python
std::optional<PyObject*> mb_obj =
t->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj();
t->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
/*ignore_hermetic_tls=*/false);
if (mb_obj.has_value() &&
!t->unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()) {
return *mb_obj;

View File

@ -879,15 +879,12 @@ void test_cuda_alloc_test() {
if (cudaStatus != cudaSuccess || device_idx == -1) {
throw std::runtime_error("cudaGetDevice failed!");
}
c10::cuda::CUDACachingAllocator::emptyCache();
c10::cuda::CUDACachingAllocator::DeviceStats stats =
c10::cuda::CUDACachingAllocator::getDeviceStats(device_idx);
size_t initTorchActive = stats.allocated_bytes[0].current;
size_t initTorchActive = stats.active_bytes[0].current;
auto runner = std::make_unique<torch::inductor::AOTIModelContainerRunnerCuda>(
model_so_path);
stats = c10::cuda::CUDACachingAllocator::getDeviceStats(device_idx);
size_t torchActive = stats.allocated_bytes[0].current;
size_t torchActive = stats.active_bytes[0].current;
ASSERT_EQ(initTorchActive + DATASIZE, torchActive);
@ -1116,7 +1113,8 @@ TEST(AotInductorTest, MultiStreamTestCuda) {
test_multi_cuda_streams("cuda");
}
TEST(AotInductorTest, CudaAllocTestCuda) {
// TODO: ENABLE CUDACachingAllocator Test
TEST(DISABLED_AotInductorTest, CudaAllocTestCuda) {
test_cuda_alloc_test();
}
#endif

View File

@ -1490,8 +1490,8 @@ class TestFullyShardWorldSize1(FSDPTest):
@skip_if_lt_x_gpu(1)
def test_train_parity_single_worldsize1(self):
"""
Tests train parity with DDP for a single FSDP group
when sharding parameters on dim-0.
Tests train parity with DDP for a single FSDP group when sharding
parameters on dim-0.
"""
self.run_subtests(
{
@ -1539,7 +1539,9 @@ class TestFullyShardWorldSize1(FSDPTest):
losses.append(model(*inp).sum())
losses[-1].backward()
self.assertEqual(comm_mode.get_total_counts(), 0)
# Before there was 1 all-gather and 1 reduce-scatter
# Now there is 1 reduce-scatter
self.assertEqual(comm_mode.get_total_counts(), 1)
optim.step()
self.assertEqual(losses[0], losses[1])

View File

@ -294,11 +294,11 @@ class TestFullyShard2DTraining(FSDPTest):
with CommDebugMode() as bwd_comm_mode:
loss.backward()
bwd_comm_counts = bwd_comm_mode.get_comm_counts()
self.assertEqual(len(bwd_comm_counts), 1)
self.assertEqual(len(bwd_comm_counts), 2)
# First MLP's input gradient does not need to be all-reduced
self.assertEqual(bwd_comm_counts[funcol.all_reduce], num_mlps - 1)
self.assertEqual(bwd_comm_counts[c10d_ops._allgather_base_], 0)
self.assertEqual(bwd_comm_counts[c10d_ops._reduce_scatter_base_], 0)
self.assertEqual(bwd_comm_counts[c10d_ops._reduce_scatter_base_], num_mlps)
ref_loss.backward()
optim.step()

View File

@ -1,173 +0,0 @@
# Owner(s): ["oncall: distributed"]
import copy
from collections.abc import Iterable
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.replicate_with_fsdp import replicate
from torch.distributed.fsdp import FSDPModule
from torch.distributed.tensor import DTensor
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTestMultiThread, get_devtype, MLP
from torch.testing._internal.common_utils import run_tests
c10d_ops = torch.ops.c10d
funcol = torch.ops.c10d_functional
device_type = torch.device(get_devtype())
class TestReplicateForwardInputs(FSDPTestMultiThread):
@property
def world_size(self) -> int:
return 2
@skip_if_lt_x_gpu(1)
def test_root_move_forward_input_to_device(self):
device = torch.device(device_type.type, 0)
class ParamlessModule(nn.Module):
def forward(self, x: torch.Tensor, ys: tuple[torch.Tensor, ...]):
# Check that Replicate moved the inputs to GPU, including recursing
# into the tuple data structure
assert x.device == device, f"Expects {device} but got {x.device}"
assert ys[0].device == device, (
f"Expects {device} but got {ys[0].device}"
)
assert ys[1].device == device, (
f"Expects {device} but got {ys[1].device}"
)
y = ys[0] + ys[1]
return x + y + 1
model = ParamlessModule().to(device)
replicate(model).to(device)
x = torch.randn((3,))
ys = (torch.randn((3,)), torch.randn((3,)))
self.assertEqual(x.device, torch.device("cpu"))
self.assertEqual(ys[0].device, torch.device("cpu"))
self.assertEqual(ys[1].device, torch.device("cpu"))
model(x, ys)
class TestReplicateRegisteredParams(FSDPTestMultiThread):
@property
def world_size(self) -> int:
return 4
@skip_if_lt_x_gpu(1)
def test_param_registration_after_forward(self):
"""Tests the parameter registration after forward."""
device = torch.device(device_type.type, 0)
# Single Replicate group
for reshard_after_forward in (True, False, None):
torch.manual_seed(42)
model = MLP(3, device)
# Since seed is per process, not per thread, we broadcast to ensure
# the same parameters across ranks
for param in model.parameters():
dist.broadcast(param, src=0)
ref_model = copy.deepcopy(model)
replicate(model, reshard_after_forward=reshard_after_forward) # root only
inp = torch.randn((2, 3), device=device_type.type)
self._assert_dtensor_params(model.parameters())
self._assert_same_params(model.parameters(), ref_model.parameters())
model(inp)
if reshard_after_forward:
self._assert_dtensor_params(model.parameters())
else:
self._assert_tensor_params(model.parameters())
self._assert_same_params(model.parameters(), ref_model.parameters())
model.reshard() # however, we can manually reshard
self._assert_dtensor_params(model.parameters())
self._assert_same_params(model.parameters(), ref_model.parameters())
# Multiple Replicate groups
for reshard_after_forward in (True, False, None):
torch.manual_seed(42)
model = nn.Sequential(MLP(3, device), MLP(3, device))
for param in model.parameters():
dist.broadcast(param, src=0)
ref_model = copy.deepcopy(model)
replicate(model[0].in_proj, reshard_after_forward=reshard_after_forward)
replicate(model[0].out_proj, reshard_after_forward=reshard_after_forward)
replicate(model, reshard_after_forward=reshard_after_forward)
self._assert_dtensor_params(model.parameters())
self._assert_same_params(model.parameters(), ref_model.parameters())
model(inp)
non_root_params = list(model[0].in_proj.parameters()) + list(
model[0].out_proj.parameters()
)
root_params = list(set(model.parameters()) - set(non_root_params))
if reshard_after_forward is None:
self._assert_dtensor_params(non_root_params)
self._assert_tensor_params(root_params)
elif reshard_after_forward:
self._assert_dtensor_params(non_root_params)
self._assert_dtensor_params(root_params)
else:
self._assert_tensor_params(non_root_params)
self._assert_tensor_params(root_params)
self._assert_same_params(model.parameters(), ref_model.parameters())
for module in model.modules():
if isinstance(module, FSDPModule):
module.reshard() # however, we can manually reshard
self._assert_dtensor_params(model.parameters())
self._assert_same_params(model.parameters(), ref_model.parameters())
@skip_if_lt_x_gpu(1)
def test_param_registration_after_backward(self):
"""Tests the parameter registration after backward."""
device = torch.device(device_type.type, 0)
# Single Replicate group
for reshard_after_forward in (True, False):
model = MLP(8, device)
replicate(model, reshard_after_forward=reshard_after_forward) # root only
inp = torch.randn((2, 8), device=device_type.type)
self._assert_dtensor_params(model.parameters())
model(inp).sum().backward()
self._assert_dtensor_params(model.parameters())
# Multiple Replicate groups
for reshard_after_forward in (True, False):
model = MLP(8, device)
replicate(model.in_proj, reshard_after_forward=reshard_after_forward)
replicate(model.out_proj, reshard_after_forward=reshard_after_forward)
replicate(model, reshard_after_forward=reshard_after_forward)
self._assert_dtensor_params(model.parameters())
model(inp).sum().backward()
self._assert_dtensor_params(model.parameters())
def _assert_tensor_params(self, params: Iterable[nn.Parameter]):
# need to iterate over the list multiple times
params = list(params)
self.assertGreater(len(params), 0)
for param in params:
self.assertNotIsInstance(param, DTensor)
self.assertIsInstance(param, torch.Tensor)
def _assert_dtensor_params(self, params: Iterable[nn.Parameter]):
params = list(params)
self.assertGreater(len(params), 0)
for param in params:
self.assertIsInstance(param, DTensor)
def _assert_same_params(
self, params: Iterable[nn.Parameter], ref_params: Iterable[nn.Parameter]
):
params, ref_params = list(params), list(ref_params)
self.assertEqual(len(params), len(ref_params))
for param, ref_param in zip(params, ref_params):
if isinstance(param, DTensor):
param = param.full_tensor()
self.assertEqual(param.shape, ref_param.shape)
self.assertEqual(param, ref_param)
if __name__ == "__main__":
run_tests()

View File

@ -47,14 +47,11 @@ _LOGGER = logging.getLogger(__name__)
class TestCoalesce(TestCase):
def helper_test_coalesce(self, layout, coalesced_layout=None):
def helper_test_coalesce(self, layout):
layoutR = coalesce(layout)
_LOGGER.debug(f"{layout} => {layoutR}")
if coalesced_layout:
self.assertEqual(coalesced_layout.shape, layoutR.shape)
self.assertEqual(coalesced_layout.stride, layoutR.stride)
self.assertEqual(size(layoutR), size(layout))
for i in range(size(layout)):
@ -85,17 +82,11 @@ class TestCoalesce(TestCase):
layout = Layout((2, (4, 6)))
self.helper_test_coalesce(layout)
layout = Layout((1, 2), (8, 1))
coalesced_layout = Layout(2, 1)
self.helper_test_coalesce(layout, coalesced_layout)
layout = Layout((2, 4), (4, 1))
coalesced_layout = Layout(8, 1)
self.helper_test_coalesce(layout, coalesced_layout)
self.helper_test_coalesce(layout)
layout = Layout((2, 4, 6), (24, 6, 1))
coalesced_layout = Layout(48, 1)
self.helper_test_coalesce(layout, coalesced_layout)
self.helper_test_coalesce(layout)
layout = Layout((2, 1, 3), (2, 4, 4))
self.helper_test_coalesce(layout)
@ -103,10 +94,6 @@ class TestCoalesce(TestCase):
layout = Layout(((2, 2), (2, 2)), ((1, 4), (8, 32)))
self.helper_test_coalesce(layout)
layout = Layout(((2, 2), (2, 2)), ((32, 8), (4, 1)))
coalesced_layout = Layout((2, 4, 2), (32, 4, 1))
self.helper_test_coalesce(layout, coalesced_layout)
if __name__ == "__main__":
run_tests()

View File

@ -208,26 +208,11 @@ class TestComposition(TestCase):
layoutB = Layout((6), (1))
self.helper_test_composition(layoutA, layoutB)
# Pre-coalesced RHS
layoutA = Layout((8, 6, 4), (7, 4, 1))
layoutB = Layout((6), (1))
self.helper_test_composition(layoutA, layoutB)
# Case when not meet stride divisibility condition
with self.assertRaises(AssertionError):
layoutA = Layout((4, 6, 8, 10), (2, 3, 5, 7))
layoutB = Layout(6, 12)
self.helper_test_composition(layoutA, layoutB)
# Mid-layout truncation
layoutA = Layout((10, 8, 6, 4), (7, 5, 3, 2))
layoutA = Layout((4, 6, 8, 10), (2, 3, 5, 7))
layoutB = Layout(6, 12)
self.helper_test_composition(layoutA, layoutB)
layoutA = Layout((4,), (3,))
layoutB = Layout((6,), (2,))
self.helper_test_composition(layoutA, layoutB)
if __name__ == "__main__":
run_tests()

View File

@ -67,159 +67,20 @@ class TestIntTuple(TestCase):
self.assertEqual(shape_div((6, (3, 4)), 36), (1, (1, 2)))
def test_suffix_product(self):
self.assertEqual(suffix_product(2), 1)
def test_prefix_product(self):
self.assertEqual(prefix_product(2), 1)
self.assertEqual(suffix_product((3, 2)), (2, 1))
self.assertEqual(prefix_product((3, 2)), (1, 3))
self.assertEqual(suffix_product((3, 2, 4)), (8, 4, 1))
self.assertEqual(prefix_product((3, 2, 4)), (1, 3, 6))
self.assertEqual(suffix_product(((2, 3), 4)), ((12, 4), 1))
self.assertEqual(prefix_product(((2, 3), 4)), ((1, 2), 6))
self.assertEqual(
suffix_product(((2, 3), (2, 1, 2), (5, 2, 1))),
((120, 40), (20, 20, 10), (2, 1, 1)),
prefix_product(((2, 3), (2, 1, 2), (5, 2, 1))),
((1, 2), (6, 12, 12), (24, 120, 240)),
)
def test_crd2idx_basic(self):
# Test basic int/int case
self.assertEqual(crd2idx(2, 5, 1), 2)
self.assertEqual(crd2idx(0, 5, 1), 0)
self.assertEqual(crd2idx(4, 5, 1), 4)
# Test with custom stride
self.assertEqual(crd2idx(2, 5, 3), 6)
self.assertEqual(crd2idx(1, 5, 3), 3)
def test_crd2idx_tuple(self):
# Test tuple coordinates with default stride
self.assertEqual(crd2idx((1, 2), (3, 4)), 6) # 1*4 + 2*1 = 6
self.assertEqual(crd2idx((0, 0), (3, 4)), 0)
self.assertEqual(crd2idx((2, 3), (3, 4)), 11) # 2*4 + 3*1 = 11
# Test with custom stride
self.assertEqual(crd2idx((1, 2), (3, 4), (8, 2)), 12) # 1*8 + 2*2 = 12
# Test 3D case
self.assertEqual(crd2idx((1, 0, 2), (2, 3, 4)), 14) # 1*12 + 0*4 + 2*1 = 14
def test_crd2idx_none(self):
# Test None coordinate (should default to 0)
self.assertEqual(crd2idx(None, 5), 0)
self.assertEqual(crd2idx(None, (3, 4)), 0)
def test_crd2idx_int_with_tuple_shape(self):
# Test single integer coordinate with multi-dimensional shape and stride
# When crd is int and shape is tuple, it converts the int to multi-dim coordinate first
self.assertEqual(crd2idx(0, (2, 2), (2, 1)), 0) # 0 -> (0,0) -> 0*2 + 0*1 = 0
self.assertEqual(crd2idx(1, (2, 2), (2, 1)), 1) # 1 -> (0,1) -> 0*2 + 1*1 = 1
self.assertEqual(crd2idx(2, (2, 2), (2, 1)), 2) # 2 -> (1,0) -> 1*2 + 0*1 = 2
self.assertEqual(crd2idx(3, (2, 2), (2, 1)), 3) # 3 -> (1,1) -> 1*2 + 1*1 = 3
# Test with non-trivial strides
self.assertEqual(crd2idx(0, (2, 3), (6, 2)), 0) # 0 -> (0,0) -> 0*6 + 0*2 = 0
self.assertEqual(crd2idx(1, (2, 3), (6, 2)), 2) # 1 -> (0,1) -> 0*6 + 1*2 = 2
self.assertEqual(crd2idx(2, (2, 3), (6, 2)), 4) # 2 -> (0,2) -> 0*6 + 2*2 = 4
self.assertEqual(crd2idx(3, (2, 3), (6, 2)), 6) # 3 -> (1,0) -> 1*6 + 0*2 = 6
self.assertEqual(crd2idx(4, (2, 3), (6, 2)), 8) # 4 -> (1,1) -> 1*6 + 1*2 = 8
self.assertEqual(crd2idx(5, (2, 3), (6, 2)), 10) # 5 -> (1,2) -> 1*6 + 2*2 = 10
# Test with larger strides
self.assertEqual(crd2idx(0, (3, 2), (10, 5)), 0) # 0 -> (0,0) -> 0*10 + 0*5 = 0
self.assertEqual(crd2idx(1, (3, 2), (10, 5)), 5) # 1 -> (0,1) -> 0*10 + 1*5 = 5
self.assertEqual(
crd2idx(2, (3, 2), (10, 5)), 10
) # 2 -> (1,0) -> 1*10 + 0*5 = 10
self.assertEqual(
crd2idx(3, (3, 2), (10, 5)), 15
) # 3 -> (1,1) -> 1*10 + 1*5 = 15
self.assertEqual(
crd2idx(4, (3, 2), (10, 5)), 20
) # 4 -> (2,0) -> 2*10 + 0*5 = 20
self.assertEqual(
crd2idx(5, (3, 2), (10, 5)), 25
) # 5 -> (2,1) -> 2*10 + 1*5 = 25
# Test with 3D shape and various strides
self.assertEqual(
crd2idx(0, (2, 2, 2), (8, 4, 2)), 0
) # 0 -> (0,0,0) -> 0*8 + 0*4 + 0*2 = 0
self.assertEqual(
crd2idx(1, (2, 2, 2), (8, 4, 2)), 2
) # 1 -> (0,0,1) -> 0*8 + 0*4 + 1*2 = 2
self.assertEqual(
crd2idx(2, (2, 2, 2), (8, 4, 2)), 4
) # 2 -> (0,1,0) -> 0*8 + 1*4 + 0*2 = 4
self.assertEqual(
crd2idx(3, (2, 2, 2), (8, 4, 2)), 6
) # 3 -> (0,1,1) -> 0*8 + 1*4 + 1*2 = 6
self.assertEqual(
crd2idx(4, (2, 2, 2), (8, 4, 2)), 8
) # 4 -> (1,0,0) -> 1*8 + 0*4 + 0*2 = 8
self.assertEqual(
crd2idx(7, (2, 2, 2), (8, 4, 2)), 14
) # 7 -> (1,1,1) -> 1*8 + 1*4 + 1*2 = 14
self.assertEqual(
crd2idx(4, ((2, 2, 2), (2, 2, 2)), ((1, 16, 4), (8, 2, 32))), 8
) # 4 -> (1,0,0) -> 1*8 = 8
def test_idx2crd_basic(self):
# Test basic int/int case
self.assertEqual(idx2crd(2, 5, 1), 2)
self.assertEqual(idx2crd(0, 5, 1), 0)
self.assertEqual(idx2crd(4, 5, 1), 4)
# Test with custom stride
self.assertEqual(idx2crd(6, 5, 3), 2) # (6 // 3) % 5 = 2
self.assertEqual(idx2crd(3, 5, 3), 1) # (3 // 3) % 5 = 1
def test_idx2crd_tuple(self):
# Test tuple shape with default stride
self.assertEqual(idx2crd(6, (3, 4)), (1, 2)) # 6 -> (1, 2)
self.assertEqual(idx2crd(0, (3, 4)), (0, 0))
self.assertEqual(idx2crd(11, (3, 4)), (2, 3))
# Test 3D case
self.assertEqual(idx2crd(14, (2, 3, 4)), (1, 0, 2))
def test_crd2idx_idx2crd_roundtrip(self):
# Test that crd2idx and idx2crd are inverse operations
shapes = [
5,
(3, 4),
(2, 3, 4),
(2, 2, 2, 2),
]
for shape in shapes:
size = product(shape)
for idx in range(size):
crd = idx2crd(idx, shape)
recovered_idx = crd2idx(crd, shape)
self.assertEqual(
recovered_idx, idx, f"Failed roundtrip for shape {shape}, idx {idx}"
)
def test_idx2crd_crd2idx_roundtrip(self):
# Test roundtrip starting from coordinates
test_cases = [
(0, 5),
(4, 5),
((0, 0), (3, 4)),
((1, 2), (3, 4)),
((2, 3), (3, 4)),
((0, 0, 0), (2, 3, 4)),
((1, 2, 3), (2, 3, 4)),
]
for crd, shape in test_cases:
idx = crd2idx(crd, shape)
recovered_crd = idx2crd(idx, shape)
self.assertEqual(
recovered_crd, crd, f"Failed roundtrip for crd {crd}, shape {shape}"
)
if __name__ == "__main__":
run_tests()
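For orientation, the helpers exercised above implement ordinary strided index arithmetic; with coordinate c, shape s, and stride d (nested tuples flattened), and with suffix_product / prefix_product producing the row-major and column-major default strides respectively:

\[
\mathrm{crd2idx}(c, s, d) = \sum_i c_i\, d_i,
\qquad
\mathrm{idx2crd}(k, s, d)_i = \left\lfloor \frac{k}{d_i} \right\rfloor \bmod s_i .
\]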

View File

@ -1,4 +1,4 @@
# Owner(s): ["oncall: distributed"]
# Owner(s): ["module: unknown"]
import unittest
import torch

View File

@ -1,4 +1,4 @@
# Owner(s): ["module: fsdp"]
# Owner(s): ["module: unknown"]
import functools
import gc
from typing import Union

View File

@ -1,4 +1,4 @@
# Owner(s): ["oncall: distributed"]
# Owner(s): ["module: unknown"]
import gc
import unittest

View File

@ -1,4 +1,4 @@
# Owner(s): ["oncall: distributed"]
# Owner(s): ["module: unknown"]
from copy import copy

View File

@ -1,4 +1,4 @@
# Owner(s): ["oncall: distributed"]
# Owner(s): ["module: unknown"]
import unittest
from dataclasses import dataclass
from typing import Any, Callable, cast, Union

View File

@ -1,4 +1,4 @@
# Owner(s): ["oncall: distributed"]
# Owner(s): ["module: unknown"]
import unittest
import torch

View File

@ -1,4 +1,4 @@
# Owner(s): ["oncall: distributed"]
# Owner(s): ["module: unknown"]
import copy
import unittest

View File

@ -1,21 +1,21 @@
#!/usr/bin/env python3
# Owner(s): ["oncall: r2p"]
import functools
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import json
# LICENSE file in the root directory of this source tree.
import functools
import os
import signal
import unittest
import uuid
from multiprocessing.pool import ThreadPool
from typing import Any
from unittest.mock import call, MagicMock, patch
from unittest.mock import call, patch
import torch.distributed as dist
import torch.distributed.elastic.rendezvous.registry as rdzv_registry
@ -29,7 +29,6 @@ from torch.distributed.elastic.agent.server.api import (
WorkerSpec,
WorkerState,
)
from torch.distributed.elastic.events import EventSource
from torch.distributed.elastic.multiprocessing import SignalException
from torch.distributed.elastic.multiprocessing.errors import ProcessFailure
from torch.distributed.elastic.rendezvous import RendezvousHandler, RendezvousParameters
@ -158,243 +157,6 @@ def monres(state: WorkerState):
return RunResult(state=state)
class RecordWorkerEventsTest(unittest.TestCase):
def setUp(self):
self.spec = MagicMock()
self.spec.role = "test_role"
self.spec.get_entrypoint_name.return_value = "test_entrypoint"
self.spec.rdzv_handler.get_run_id.return_value = "test_run_id"
self.spec.rdzv_handler.get_backend.return_value = "test_backend"
self.spec.max_restarts = 3
self.agent = TestAgent(self.spec)
# Create a mock worker spec and agent
self.agent._worker_group = MagicMock()
self.agent._worker_group.spec = MagicMock()
self.agent._worker_group.spec.event_log_handler = "test_handler"
# Setup worker group
self.worker_group = WorkerGroup(self.spec)
self.worker_group.group_world_size = 2
self.worker_group.group_rank = 1
self.agent._worker_group = self.worker_group
# Create a test worker
self.workers = [
Worker(
local_rank=0,
global_rank=0,
role_rank=0,
world_size=2,
role_world_size=2,
),
Worker(
local_rank=1,
global_rank=1,
role_rank=1,
world_size=2,
role_world_size=2,
),
]
self.workers[0].id = 0
self.workers[1].id = 1
self.agent._worker_group.workers = self.workers
@patch("torch.distributed.elastic.agent.server.api.record")
def test_record_worker_events_success(self, mock_record):
# Create a RunResult with successful workers
result = RunResult(
state=WorkerState.SUCCEEDED,
return_values={0: "result0", 1: "result1"},
failures={},
)
# Call the method under test
self.agent._record_worker_events(result)
# Verify record was called twice (once for each worker)
self.assertEqual(mock_record.call_count, 2)
# Check that both calls were for SUCCEEDED events
for call_args in mock_record.call_args_list:
event = call_args[0][0]
self.assertEqual(event.source, EventSource.WORKER)
self.assertEqual(event.metadata["state"], "SUCCEEDED")
self.assertIsNone(event.metadata["raw_error"])
md = json.loads(event.metadata["metadata"])
self.assertEqual(md["exit_code"], [None])
self.assertEqual(md["worker_pid"], [None])
@patch("torch.distributed.elastic.agent.server.api.record")
def test_record_worker_events_failure(self, mock_record):
# Create failures with error data
failure0 = ProcessFailure(
local_rank=0, pid=1000, exitcode=1, error_file="error0.json"
)
# Create a RunResult with one failed worker and one terminated worker
result = RunResult(
state=WorkerState.FAILED,
return_values={},
failures={0: failure0}, # Only worker 0 has a specific failure
)
# Call the method under test
self.agent._record_worker_events(result)
# Verify record was called twice (once for each worker)
self.assertEqual(mock_record.call_count, 2)
# Get the calls
calls = mock_record.call_args_list
# Check first call for the failed worker (global_rank=0)
failed_event = calls[0][0][0]
self.assertEqual(failed_event.source, EventSource.WORKER)
self.assertEqual(failed_event.metadata["state"], "FAILED")
self.assertEqual(failed_event.metadata["global_rank"], 0)
md = json.loads(failed_event.metadata["metadata"])
self.assertEqual(failed_event.metadata["raw_error"], '{"message": "<NONE>"}')
self.assertEqual(md["exit_code"], [1])
self.assertEqual(md["worker_pid"], [1000])
# Check second call for the terminated worker (global_rank=1)
terminated_event = calls[1][0][0]
self.assertEqual(terminated_event.source, EventSource.WORKER)
self.assertEqual(terminated_event.metadata["state"], "TERMINATED")
self.assertEqual(terminated_event.metadata["global_rank"], 1)
self.assertIsNone(terminated_event.metadata["raw_error"])
md = json.loads(terminated_event.metadata["metadata"])
self.assertEqual(md["exit_code"], [None])
self.assertEqual(md["worker_pid"], [None])
class ConstructEventTest(unittest.TestCase):
def setUp(self):
# Create minimal spec and agent for testing
self.spec = MagicMock()
self.spec.role = "test_role"
self.spec.get_entrypoint_name.return_value = "test_entrypoint"
self.spec.rdzv_handler.get_run_id.return_value = "test_run_id"
self.spec.rdzv_handler.get_backend.return_value = "test_backend"
self.spec.max_restarts = 3
self.agent = TestAgent(self.spec)
self.agent._remaining_restarts = 2
self.agent._total_execution_time = 42
# Setup worker group
self.worker_group = WorkerGroup(self.spec)
self.worker_group.group_world_size = 2
self.worker_group.group_rank = 1
self.agent._worker_group = self.worker_group
# Create a test worker
self.worker = Worker(
local_rank=0, global_rank=5, role_rank=3, world_size=8, role_world_size=4
)
self.worker.id = 12345
def test_construct_event_agent_success(self):
# Test constructing an agent success event
event = self.agent._construct_event(state="SUCCEEDED", source=EventSource.AGENT)
# Verify basic event properties
self.assertEqual(event.name, "torchelastic.worker.status.SUCCEEDED")
self.assertEqual(event.source, EventSource.AGENT)
# Verify metadata
metadata = event.metadata
self.assertEqual(metadata["run_id"], "test_run_id")
self.assertIsNone(metadata["global_rank"])
self.assertEqual(metadata["group_rank"], 1)
self.assertIsNone(metadata["worker_id"])
self.assertEqual(metadata["role"], "test_role")
self.assertEqual(metadata["state"], "SUCCEEDED")
self.assertEqual(metadata["total_run_time"], 42)
self.assertEqual(metadata["rdzv_backend"], "test_backend")
self.assertIsNone(metadata["raw_error"])
self.assertEqual(
metadata["agent_restarts"], 1
) # max_restarts - remaining_restarts
self.assertIsNone(metadata["duration_ms"])
# Verify JSON metadata
md_dict = json.loads(metadata["metadata"])
self.assertEqual(md_dict["group_world_size"], 2)
self.assertEqual(md_dict["entry_point"], "test_entrypoint")
def test_construct_event_worker_failure(self):
# Test constructing a worker failure event with raw error
raw_error = json.dumps(
{"error_message": "Test error", "traceback": "stack trace"}
)
event = self.agent._construct_event(
state="FAILED",
source=EventSource.WORKER,
worker=self.worker,
raw_error=raw_error,
exit_code=1,
)
# Verify basic event properties
self.assertEqual(event.name, "torchelastic.worker.status.FAILED")
self.assertEqual(event.source, EventSource.WORKER)
# Verify metadata
metadata = event.metadata
self.assertEqual(metadata["run_id"], "test_run_id")
self.assertEqual(metadata["global_rank"], 5)
self.assertEqual(metadata["group_rank"], 1)
self.assertEqual(metadata["worker_id"], "12345")
self.assertEqual(metadata["role"], "test_role")
self.assertEqual(metadata["state"], "FAILED")
self.assertEqual(metadata["total_run_time"], 42)
self.assertEqual(metadata["rdzv_backend"], "test_backend")
self.assertEqual(metadata["raw_error"], raw_error)
self.assertEqual(metadata["agent_restarts"], 1)
# Verify worker-specific metadata
md_dict = json.loads(metadata["metadata"])
self.assertEqual(md_dict["local_rank"], [0])
self.assertEqual(md_dict["role_rank"], [3])
self.assertEqual(md_dict["role_world_size"], [4])
self.assertEqual(md_dict["exit_code"], [1])
def test_construct_event_with_duration(self):
# Test constructing an event with duration_ms
event = self.agent._construct_event(
state="RENDEZVOUS", source=EventSource.AGENT, duration_ms=123.45
)
# Verify duration is set correctly
self.assertEqual(event.metadata["duration_ms"], 123.45)
def test_construct_event_worker_no_error(self):
# Test constructing a worker event without error info
event = self.agent._construct_event(
state="HEALTHY", source=EventSource.WORKER, worker=self.worker
)
# Verify error fields are None
metadata = event.metadata
self.assertIsNone(metadata["raw_error"])
# Check worker info is set
self.assertEqual(metadata["global_rank"], 5)
self.assertEqual(metadata["worker_id"], "12345")
# Check metadata JSON
md_dict = json.loads(metadata["metadata"])
self.assertEqual(md_dict["local_rank"], [0])
self.assertEqual(md_dict["role_rank"], [3])
self.assertEqual(md_dict["role_world_size"], [4])
self.assertNotIn("exit_code", [None])
class SimpleElasticAgentTest(unittest.TestCase):
def _get_worker_spec(
self,

View File

@ -568,8 +568,9 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
)
results = pc.wait(period=0.1)
self.assertTrue(results.is_failed())
self.assertEqual(2, len(results.failures))
self.assertEqual(1, len(results.failures))
failure = results.failures[0]
self.assertEqual(138, failure.exitcode)
@ -582,13 +583,6 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
self.assertTrue(pc._stderr_tail.stopped())
self.assertTrue(pc._stdout_tail.stopped())
failure = results.failures[1]
self.assertEqual(-15, failure.exitcode)
self.assertEqual("SIGTERM", failure.signal_name())
self.assertEqual("<NONE>", failure.error_file_data["message"])
# Assert that the failure message contains expected substrings
self.assertIn("Signal 15 (SIGTERM) received by PID", failure.message)
def test_binary_raises(self):
pc = start_processes(
name="echo",

View File

@ -9,7 +9,6 @@
import argparse
import os
import sys
import time
if __name__ == "__main__":
@ -24,6 +23,5 @@ if __name__ == "__main__":
print(f"exit {exitcode} from {rank}", file=sys.stderr)
sys.exit(exitcode)
else:
time.sleep(1000)
print(f"{args.msg} stdout from {rank}")
print(f"{args.msg} stderr from {rank}", file=sys.stderr)

View File

@ -15,7 +15,6 @@ import uuid
import torch.distributed.elastic.timer as timer
from torch.testing._internal.common_utils import (
IS_ARM64,
IS_MACOS,
IS_WINDOWS,
run_tests,
@ -24,8 +23,8 @@ from torch.testing._internal.common_utils import (
)
# timer is not supported on these platforms
if not (IS_WINDOWS or IS_MACOS or IS_ARM64):
# timer is not supported on windows or macos
if not (IS_WINDOWS or IS_MACOS):
# func2 should time out
def func2(n, file_path):
if file_path is not None:

View File

@ -14,7 +14,6 @@ import time
import torch.distributed.elastic.timer as timer
import torch.multiprocessing as torch_mp
from torch.testing._internal.common_utils import (
IS_ARM64,
IS_MACOS,
IS_WINDOWS,
run_tests,
@ -41,8 +40,8 @@ def _stuck_function(rank, mp_queue):
time.sleep(5)
# timer is not supported on these platforms
if not (IS_WINDOWS or IS_MACOS or IS_ARM64):
# timer is not supported on macos or windows
if not (IS_WINDOWS or IS_MACOS):
class LocalTimerExample(TestCase):
"""

View File

@ -15,7 +15,6 @@ import torch.distributed.elastic.timer as timer
from torch.distributed.elastic.timer.api import TimerRequest
from torch.distributed.elastic.timer.local_timer import MultiprocessingRequestQueue
from torch.testing._internal.common_utils import (
IS_ARM64,
IS_MACOS,
IS_WINDOWS,
run_tests,
@ -25,10 +24,8 @@ from torch.testing._internal.common_utils import (
)
# timer is not supported on these platforms
INVALID_PLATFORMS = IS_WINDOWS or IS_MACOS or TEST_WITH_DEV_DBG_ASAN or IS_ARM64
if not INVALID_PLATFORMS:
# timer is not supported on windows or macos
if not (IS_WINDOWS or IS_MACOS or TEST_WITH_DEV_DBG_ASAN):
# func2 should time out
def func2(n, mp_queue):
if mp_queue is not None:
@ -132,7 +129,8 @@ if not INVALID_PLATFORMS:
time.sleep(interval)
if not INVALID_PLATFORMS:
# timer is not supported on windows or macos
if not (IS_WINDOWS or IS_MACOS or TEST_WITH_DEV_DBG_ASAN):
class MultiprocessingRequestQueueTest(TestCase):
def test_get(self):
@ -199,7 +197,8 @@ if not INVALID_PLATFORMS:
self.assertLessEqual(n / 2, len(requests))
if not INVALID_PLATFORMS:
# timer is not supported on windows or macos
if not (IS_WINDOWS or IS_MACOS or TEST_WITH_DEV_DBG_ASAN):
class LocalTimerServerTest(TestCase):
def setUp(self):

View File

@ -41,7 +41,6 @@ from torch.testing._internal.common_utils import (
parametrize,
run_tests,
TEST_WITH_DEV_DBG_ASAN,
TEST_XPU,
TestCase,
)
from torch.testing._internal.inductor_utils import HAS_GPU
@ -58,8 +57,6 @@ if TEST_WITH_DEV_DBG_ASAN:
)
sys.exit(0)
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
class TestFSDPUseOrigParamsMultipleParamGroups(FSDPTest):
"""Tests multiple parameter groups."""
@ -161,7 +158,7 @@ class TestFSDPUseOrigParamsMultipleParamGroups(FSDPTest):
device_init_mode == DEVICEInitMode.DEVICE_AFTER
and not fsdp_model.cpu_offload.offload_params
):
fsdp_model = fsdp_model.to(device=device_type)
fsdp_model = fsdp_model.cuda()
return fsdp_model, fsdp_optim
def _check_train_parity(
@ -174,7 +171,7 @@ class TestFSDPUseOrigParamsMultipleParamGroups(FSDPTest):
num_iters: int = 10,
):
"""Checks training parity between DDP and FSDP."""
device = torch.device(device_type)
device = torch.device("cuda")
for i in range(num_iters):
iter_losses = []
for model, optim in ((ddp_model, ddp_optim), (fsdp_model, fsdp_optim)):
@ -265,7 +262,7 @@ class TestFSDPUseOrigParamsMultipleParamGroups(FSDPTest):
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
for _ in range(10):
losses = []
inp = ref_model.get_input(torch.device(device_type))
inp = ref_model.get_input(torch.device("cuda"))
for _model, _optim in ((ref_model, ref_optim), (model, optim)):
_optim.zero_grad()
loss = _model(*inp).sum()
@ -473,7 +470,7 @@ class TestFSDPUseOrigParamsMultipleParamGroups(FSDPTest):
):
ddp_optims.append(optim_ctor(ddp_param_group["params"]))
fsdp_optims.append(optim_ctor(fsdp_param_group["params"]))
device = torch.device(device_type)
device = torch.device("cuda")
# Check that there exists a `FlatParameter` that has both a weight and
# a bias in this rank's shard
@ -646,7 +643,7 @@ class TestFSDPUseOrigParamsUnshardReshard(FSDPTest):
fsdp_model_orig_params,
optim_orig_params,
) = self._get_fsdp_models_and_optims(sharding_strategy, cpu_offload)
device = torch.device(device_type)
device = torch.device("cuda")
for _ in range(3):
inp1 = fsdp_model.get_input(device)
_inp2 = fsdp_model.get_input(device)
@ -704,7 +701,7 @@ class TestFSDPUseOrigParamsUnshardReshard(FSDPTest):
fsdp_model_orig_params,
optim_orig_params,
) = self._get_fsdp_models_and_optims(sharding_strategy, cpu_offload)
device = torch.device(device_type)
device = torch.device("cuda")
for _ in range(3):
optim.zero_grad()
optim_orig_params.zero_grad()
@ -831,9 +828,9 @@ class TestFSDPUseOrigParamsParamAccess(FSDPTest):
p1 = p1.flatten()
torch.testing.assert_close(p1, p2)
ddp_model = DDP(Model().to(device=device_type), device_ids=[self.rank])
ddp_model = DDP(Model().cuda(), device_ids=[self.rank])
fsdp_model = FSDP(
Model().to(device=device_type),
Model().cuda(),
sharding_strategy=sharding_strategy,
auto_wrap_policy=always_wrap_policy,
use_orig_params=True,
@ -841,7 +838,7 @@ class TestFSDPUseOrigParamsParamAccess(FSDPTest):
LR = 1e-2
ddp_optim = torch.optim.Adam(ddp_model.parameters(), lr=LR)
fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=LR)
device = torch.device(device_type)
device = torch.device("cuda")
inp = fsdp_model.get_input(device)
ddp_out = ddp_model(*inp)
@ -916,11 +913,11 @@ class TestFSDPUseOrigParamsWriteback(FSDPTest):
# Check that the writeback propagates
ddp_model = DDP(
TestFSDPUseOrigParamsWriteback.Model(torch.device(device_type)),
TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda")),
device_ids=[self.rank],
)
fsdp_model = FSDP(
TestFSDPUseOrigParamsWriteback.Model(torch.device(device_type)),
TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda")),
use_orig_params=True,
)
ddp = ddp_model.module # for brevity
@ -969,11 +966,11 @@ class TestFSDPUseOrigParamsWriteback(FSDPTest):
return None if set_to_none else torch.ones_like(param) * 2
ddp_model = DDP(
TestFSDPUseOrigParamsWriteback.Model(torch.device(device_type)),
TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda")),
device_ids=[self.rank],
)
fsdp_model = FSDP(
TestFSDPUseOrigParamsWriteback.Model(torch.device(device_type)),
TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda")),
use_orig_params=True,
)
LR = 1e-2
@ -984,7 +981,7 @@ class TestFSDPUseOrigParamsWriteback(FSDPTest):
fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=LR)
# Generate an initial gradient
inp = fsdp_model.get_input(torch.device(device_type))
inp = fsdp_model.get_input(torch.device("cuda"))
ddp_out = ddp_model(*inp)
fsdp_out = fsdp_model(*inp)
ddp_out.sum().backward()
@ -1014,7 +1011,7 @@ class TestFSDPUseOrigParamsWriteback(FSDPTest):
self._check_param_parity(ddp_model, fsdp_model) # triggers a writeback
# Intentionally do not zero the gradient to check writeback
inp = fsdp_model.get_input(torch.device(device_type))
inp = fsdp_model.get_input(torch.device("cuda"))
ddp_out = ddp_model(*inp)
fsdp_out = fsdp_model(*inp)
ddp_out.sum().backward()
@ -1026,7 +1023,7 @@ class TestFSDPUseOrigParamsWriteback(FSDPTest):
@skip_if_lt_x_gpu(2)
def test_writeback_shape_mismatch(self):
fsdp_model = FSDP(
TestFSDPUseOrigParamsWriteback.Model(torch.device(device_type)),
TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda")),
use_orig_params=True,
)
# Check that writing back with mismatched shape errors
@ -1076,9 +1073,9 @@ class TestFSDPUseOrigParamsWriteback(FSDPTest):
# Test changing the parameter storage to no longer be a view into the
# flat parameter
fsdp_model = fsdp_wrapper(
TestFSDPUseOrigParamsWriteback.Model(torch.device(device_type))
TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda"))
)
inp = fsdp_model.get_input(torch.device(device_type))
inp = fsdp_model.get_input(torch.device("cuda"))
loss = fsdp_model(*inp).sum()
fsdp_model.lin1.weight.data = fsdp_model.lin1.weight.clone()
assert_msg = (
@ -1089,9 +1086,9 @@ class TestFSDPUseOrigParamsWriteback(FSDPTest):
# Test changing the parameter variable itself
fsdp_model = fsdp_wrapper(
TestFSDPUseOrigParamsWriteback.Model(torch.device(device_type))
TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda"))
)
inp = fsdp_model.get_input(torch.device(device_type))
inp = fsdp_model.get_input(torch.device("cuda"))
loss = fsdp_model(*inp).sum()
fsdp_model.lin1._fsdp_wrapped_module.weight = nn.Parameter(
fsdp_model.lin1.weight.clone()
@ -1125,10 +1122,9 @@ class TestFSDPUseOrigParamsWriteback(FSDPTest):
# Train forward -> full-precision unshard -> train forward
fsdp_model = FSDP(
TestFSDPUseOrigParamsWriteback.Model(torch.device(device_type)),
**fsdp_kwargs,
TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda")), **fsdp_kwargs
)
inp = fsdp_model.get_input(torch.device(device_type))
inp = fsdp_model.get_input(torch.device("cuda"))
fsdp_model(*inp)
with FSDP.summon_full_params(fsdp_model):
...
@ -1187,13 +1183,13 @@ class TestFSDPUseOrigParamsFQNs(FSDPTest):
assert_equal_fn(params[1].shape, param_shapes[1])
return self.lin(x)
model = Model().to(device=device_type)
model = Model().cuda()
# Save the *unsharded* original parameter shapes and check the shapes
# match in the forward pass
param_shapes[0] = model.lin.weight.shape
param_shapes[1] = model.lin.bias.shape
fsdp_model = FSDP(model, use_orig_params=True)
inp = torch.randn((2, 5), device=torch.device(device_type))
inp = torch.randn((2, 5), device=torch.device("cuda"))
fsdp_model(inp)
@ -1220,7 +1216,7 @@ class TestFSDPUseOrigParamsNoSync(FSDPTest):
)
def _test_no_sync_correctness(self, sharding_strategy: ShardingStrategy):
model = nn.Linear(7, 1, bias=False, device=device_type)
model = nn.Linear(7, 1, bias=False, device="cuda")
fsdp_kwargs = {
"sharding_strategy": sharding_strategy,
}
@ -1270,8 +1266,8 @@ class TestFSDPUseOrigParamsNoSync(FSDPTest):
orig_param.grad,
)
inp = torch.randn((2, 7), device=device_type)
grad = torch.randn((2, 1), device=device_type)
inp = torch.randn((2, 7), device="cuda")
grad = torch.randn((2, 1), device="cuda")
# Compute some reference gradients using one forward/backward
out_use_flat_params = model_use_flat_params(inp)
@ -1337,7 +1333,7 @@ class TestFSDPUseOrigParamsNoSync(FSDPTest):
)
def _test_no_sync_mixed_precision(self, sharding_strategy: ShardingStrategy):
model = nn.Linear(3, 3, device=device_type)
model = nn.Linear(3, 3, device="cuda")
mixed_precision = MixedPrecision(
param_dtype=torch.float16,
reduce_dtype=torch.float32,
@ -1348,7 +1344,7 @@ class TestFSDPUseOrigParamsNoSync(FSDPTest):
"use_orig_params": True,
}
fsdp_model = FSDP(model, **fsdp_kwargs)
inp = torch.randn((2, 3), device=device_type)
inp = torch.randn((2, 3), device="cuda")
with fsdp_model.no_sync():
# For each of these `no_sync()` backward passes, check that the
# gradients are in the low precision parameter dtype (FP16)
@ -1372,8 +1368,8 @@ class TestFSDPUseOrigParamsInit(FSDPTest):
@skip_if_lt_x_gpu(2)
def test_non_uniform_requires_grad(self):
model = nn.Sequential(
nn.Linear(3, 3, device=device_type),
nn.Linear(3, 3, device=device_type),
nn.Linear(3, 3, device="cuda"),
nn.Linear(3, 3, device="cuda"),
)
# Freeze biases only and flatten both weights and biases into the same
# `FlatParameter` to exercise non-uniform `requires_grad`
@ -1396,10 +1392,10 @@ class TestMultiTensorApply(TestCase):
# Check that this does not segfault
torch._foreach_mul_(size0_tensors, 0.1)
@unittest.skipIf(not TEST_CUDA and not TEST_XPU, "no cuda and no xpu")
@unittest.skipIf(not TEST_CUDA, "no cuda")
def test_multi_tensor_apply_size0_tensors_cuda(self):
size0_tensors = [
torch.empty(0, device=device_type) for _ in range(NUM_SIZE0_TENSORS)
torch.empty(0, device="cuda") for _ in range(NUM_SIZE0_TENSORS)
]
# Check that this does not segfault
torch._foreach_mul_(size0_tensors, 0.1)
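The hunks in this file toggle between the hard-coded "cuda" spelling and a device-generic idiom built on torch.accelerator. A minimal standalone sketch of that idiom (not part of the diff; assumes a PyTorch build that ships torch.accelerator, falling back to CPU otherwise):

import torch
import torch.nn as nn

# Resolve the active accelerator once and reuse it instead of writing "cuda" inline.
acc = torch.accelerator.current_accelerator()
device_type = acc.type if acc is not None else "cpu"
device = torch.device(device_type)

model = nn.Linear(5, 5).to(device=device)   # rather than nn.Linear(5, 5).cuda()
inp = torch.randn(2, 5, device=device)      # rather than device="cuda"
out = model(inp)
print(out.device)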

View File

@ -15,9 +15,6 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
)
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
class TestShardUtilsDistributed(FSDPTest):
@property
def world_size(self):
@ -26,7 +23,7 @@ class TestShardUtilsDistributed(FSDPTest):
def _create_tensor(self, *size):
# Keep everything deterministic.
torch.manual_seed(0)
return torch.rand(*size).to(device=device_type)
return torch.rand(*size).cuda()
@skip_if_lt_x_gpu(2)
def test_create_chunk_sharded_tensor(self):
@ -37,12 +34,10 @@ class TestShardUtilsDistributed(FSDPTest):
tensor,
self.rank,
self.world_size,
torch.accelerator.device_count(),
torch.cuda.device_count(),
_get_default_group(),
)
output = (
torch.empty(*size).to(device=device_type) if self.rank == 0 else None
)
output = torch.empty(*size).cuda() if self.rank == 0 else None
sharded_tensor.gather(0, output)
if self.rank == 0:
self.assertEqual(tensor, output)
@ -56,7 +51,7 @@ class TestShardUtilsDistributedDTensor(DTensorTestBase):
def _create_tensor(self, *size):
# Keep everything deterministic.
torch.manual_seed(0)
return torch.rand(*size).to(device=device_type)
return torch.rand(*size).cuda()
@with_comms
@skip_if_lt_x_gpu(2)

View File

@ -50,15 +50,10 @@ from torch.testing._internal.common_utils import (
parametrize,
run_tests,
TEST_CUDA,
TEST_XPU,
TestCase,
)
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
backend = torch.distributed.get_default_backend_for_device(device_type)
class BatchNormNet(nn.Module):
def __init__(self) -> None:
super().__init__()
@ -137,14 +132,14 @@ class TestFSDPWrap(FSDPTest):
class NestedSequentialModel:
@staticmethod
def get_model(device=True):
def get_model(cuda=True):
sequential = nn.Sequential(
nn.Linear(5, 5),
nn.Linear(5, 5),
nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5)),
)
if device:
sequential = sequential.to(device=device_type)
if cuda:
sequential = sequential.cuda()
return sequential
@staticmethod
@ -219,7 +214,7 @@ class TestFSDPWrap(FSDPTest):
nested=nested, device_init_mode=device_init_mode
)
if device_init_mode == DEVICEInitMode.DEVICE_AFTER:
wrapped_fsdp = wrapped_fsdp.to(device=device_type)
wrapped_fsdp = wrapped_fsdp.cuda()
wrapped_module_name = "lin1.1" if nested else "lin1"
with self.assertRaisesRegex(
@ -374,7 +369,7 @@ class TestFSDPWrap(FSDPTest):
forward_prefetch=forward_prefetch,
)
if device_init_mode == DEVICEInitMode.DEVICE_AFTER:
wrapped_model = wrapped_model.to(device=device_type)
wrapped_model = wrapped_model.cuda()
modules_in_fsdp_graph_order = [
wrapped_model.module.lin1,
@ -393,7 +388,7 @@ class TestFSDPWrap(FSDPTest):
# Run model a few times for sanity check.
optim = torch.optim.SGD(wrapped_model.parameters(), lr=1e-2, momentum=0.9)
inp = torch.ones(1).to(device=device_type)
inp = torch.ones(1).cuda()
for _ in range(6):
optim.zero_grad()
loss = wrapped_model(inp).sum()
@ -466,13 +461,13 @@ class TestAutoWrap(TestCase):
self.assertEqual(layer.rank, 0)
self.assertEqual(layer.world_size, 2)
@unittest.skipIf(not TEST_CUDA and not TEST_XPU, "Test Requires CUDA or XPU")
@unittest.skipIf(not TEST_CUDA, "Test Requires CUDA")
def test_always_wrap(self):
"""
Test to ensure that if `always_wrap_policy` is
passed into FSDP, all submodules are wrapped.
"""
seq = TestFSDPWrap.NestedSequentialModel.get_model(device=True)
seq = TestFSDPWrap.NestedSequentialModel.get_model(cuda=True)
model = FSDP(
seq, process_group=self.process_group, auto_wrap_policy=always_wrap_policy
)
@ -634,7 +629,7 @@ class TestAutoWrap(TestCase):
Test to ensure with auto wrap, we wrap child modules correctly based on the min_num_params.
``nn.Linear(5, 5)`` does not exceed the bucket size, but combined they do.
"""
sequential = TestFSDPWrap.NestedSequentialModel.get_model(device=False)
sequential = TestFSDPWrap.NestedSequentialModel.get_model(cuda=False)
my_auto_wrap_policy = functools.partial(
size_based_auto_wrap_policy, min_num_params=40
)
@ -731,7 +726,7 @@ class TestAutoWrap(TestCase):
self.assertTrue(isinstance(model.module[0], nn.Linear))
self.assertTrue(isinstance(model.module[1], nn.ModuleList))
@unittest.skipIf(not TEST_CUDA and not TEST_XPU, "Test Requires CUDA or XPU")
@unittest.skipIf(not TEST_CUDA, "Test Requires CUDA")
@parametrize(
"device_init_mode", [DEVICEInitMode.DEVICE_BEFORE, DEVICEInitMode.DEVICE_AFTER]
)
@ -748,12 +743,10 @@ class TestAutoWrap(TestCase):
):
return
device = torch.device(device_type)
torch.accelerator.set_device_index(0)
device = torch.device("cuda")
torch.cuda.set_device(0)
device_id = (
torch.device(device_type, torch.accelerator.current_device_index())
if use_device_id
else None
torch.device("cuda", torch.cuda.current_device()) if use_device_id else None
)
# Random port in case the next test run quickly, same port would cause conflict.
@ -762,18 +755,18 @@ class TestAutoWrap(TestCase):
file_name = tempfile.NamedTemporaryFile(delete=False).name
torch.distributed.init_process_group(
backend=backend,
backend="nccl",
init_method=f"{FILE_SCHEMA}_{file_name}",
rank=0,
world_size=1,
)
# NOTE: We move model to GPU after init with FSDP to simulate real use
# NOTE: We move model to CUDA after init with FSDP to simulate real use
# cases where full model cannot be loaded onto GPU, but their shards can.
device_after_init = device_init_mode == DEVICEInitMode.DEVICE_AFTER
cuda_after_init = device_init_mode == DEVICEInitMode.DEVICE_AFTER
try:
sequential = TestFSDPWrap.NestedSequentialModel.get_model(
device=(not device_after_init)
cuda=(not cuda_after_init)
)
my_auto_wrap_policy = functools.partial(
size_based_auto_wrap_policy, min_num_params=40
@ -785,8 +778,8 @@ class TestAutoWrap(TestCase):
device_id=device_id,
)
TestFSDPWrap.NestedSequentialModel.verify_model(self, model)
if device_after_init:
model = model.to(device=device_type)
if cuda_after_init:
model = model.cuda()
input = torch.rand((1, 5), dtype=torch.float).to(device)
output = model(input)
loss = F.mse_loss(input, output)
@ -802,7 +795,7 @@ class TestAutoWrap(TestCase):
@unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
@parametrize("wrap_method", [WrapMethod.FSDP_CTOR, WrapMethod.WRAP_API])
def test_always_wrap_with_ignored_modules(self, wrap_method: WrapMethod):
sequential = TestFSDPWrap.NestedSequentialModel.get_model(device=False)
sequential = TestFSDPWrap.NestedSequentialModel.get_model(cuda=False)
ignored_modules = [sequential[1], sequential[2][0]]
fsdp_kwargs = {
"process_group": self.process_group,
@ -827,7 +820,7 @@ class TestAutoWrap(TestCase):
@unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
@parametrize("wrap_method", [WrapMethod.FSDP_CTOR, WrapMethod.WRAP_API])
def test_auto_wrap_with_ignored_modules(self, wrap_method: WrapMethod):
sequential = TestFSDPWrap.NestedSequentialModel.get_model(device=False)
sequential = TestFSDPWrap.NestedSequentialModel.get_model(cuda=False)
ignored_modules = [sequential[1], sequential[2][0]]
my_auto_wrap_policy = functools.partial(
size_based_auto_wrap_policy,
@ -890,7 +883,7 @@ class TestAutoWrap(TestCase):
self._test_frozen_params(use_orig_params, policy)
def _test_frozen_params(self, use_orig_params: bool, policy: _Policy):
model = LoraModel().to(device=device_type)
model = LoraModel().cuda()
msg = "layers.0.attn has both parameters with requires_grad=True and False. "
if use_orig_params:
msg += "We do not recommend wrapping such modules"

View File

@ -1,250 +0,0 @@
# Owner(s): ["oncall: distributed"]
import contextlib
import torch
import torch.distributed as dist
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed.tensor import DeviceMesh, DTensor, Partial, Replicate, Shard
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
requires_cuda,
run_tests,
TestCase,
)
from torch.testing._internal.distributed.fake_pg import FakeStore
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils.debug_mode import DebugMode
@requires_cuda
class TestDTensorDebugMode(TestCase):
def tearDown(self):
super().tearDown()
dist.destroy_process_group()
def setUp(self):
super().setUp()
self.world_size = 8
store = FakeStore()
dist.init_process_group(
backend="fake", rank=0, world_size=self.world_size, store=store
)
self.device_type = "cuda"
def test_debug_mode_mm(self):
mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
x = torch.randn(1, 8, requires_grad=False)
y = torch.randn(1, 32, requires_grad=True)
x_dtensor = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
y_dtensor = DTensor.from_local(y, mesh, [Shard(0)], run_check=False)
with DebugMode() as debug_mode:
torch.mm(x_dtensor, y_dtensor).sum()
self.assertExpectedInline(
debug_mode.debug_string(),
"""\
torch.mm(dt: f32[8, 8][S(0)], dt: f32[8, 32][S(0)])
aten::mm(dt: f32[8, 8][S(0)], dt: f32[8, 32][S(0)])
redistribute_input(1, [S(0)] -> [R])
_c10d_functional::all_gather_into_tensor(t: f32[1, 32], 8, 0)
_c10d_functional::wait_tensor(t: f32[8, 32])
aten::mm(t: f32[1, 8], t: f32[8, 32])
<method 'sum' of 'torch._C.TensorBase' objects>(dt: f32[8, 32][S(0)])
aten::sum(dt: f32[8, 32][S(0)])
aten::sum(t: f32[1, 32])""",
)
def test_debug_string_inside_context(self):
mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
x = torch.randn(1, 8, requires_grad=False)
y = torch.randn(1, 32, requires_grad=True)
x_dtensor = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
y_dtensor = DTensor.from_local(y, mesh, [Shard(0)], run_check=False)
with DebugMode() as debug_mode:
torch.mm(x_dtensor, y_dtensor).sum()
s0 = debug_mode.debug_string()
s1 = debug_mode.debug_string()
self.assertEqual(s0, s1)
def test_debug_mode_backward(self):
mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
x = torch.randn(1, 8, requires_grad=True)
y = torch.randn(8, 1, requires_grad=True)
x_dtensor = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
y_dtensor = DTensor.from_local(y, mesh, [Shard(1)], run_check=False)
with DebugMode() as debug_mode:
z = x_dtensor + y_dtensor
z.sum().backward()
self.assertExpectedInline(
debug_mode.debug_string(),
"""\
<method 'add' of 'torch._C.TensorBase' objects>(dt: f32[8, 8][S(0)], dt: f32[8, 8][S(1)])
aten::add.Tensor(dt: f32[8, 8][S(0)], dt: f32[8, 8][S(1)])
redistribute_input(1, [S(1)] -> [S(0)])
_dtensor::shard_dim_alltoall(t: f32[8, 1], 1, 0, 0)
aten::add.Tensor(t: f32[1, 8], t: f32[1, 8])
<method 'sum' of 'torch._C.TensorBase' objects>(dt: f32[8, 8][S(0)])
aten::sum(dt: f32[8, 8][S(0)])
aten::sum(t: f32[1, 8])
torch._tensor.backward(dt: f32[][P], gradient=None, retain_graph=None, create_graph=False, inputs=None)
aten::ones_like(dt: f32[][P], pin_memory=False, memory_format=torch.preserve_format)
aten::ones_like(t: f32[], pin_memory=False, memory_format=torch.preserve_format)
aten::expand(dt: f32[][R], [8, 8])
aten::expand(t: f32[], [8, 8])
aten::split.Tensor(t: f32[8, 8], 1, 1)
aten::clone(t: f32[8, 1])
aten::_to_copy(t: f32[8, 1], dtype=torch.float32, layout=torch.strided, device=cpu)
aten::detach(t: f32[8, 1])
aten::split.Tensor(t: f32[8, 8], 1)
aten::clone(t: f32[1, 8])
aten::_to_copy(t: f32[1, 8], dtype=torch.float32, layout=torch.strided, device=cpu)
aten::detach(t: f32[1, 8])""",
)
def test_debug_mode_einsum(self):
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size).view(4, 2))
# Create test tensors
a = torch.randn(16, 6, 8)
b = torch.randn(8, 4, 4)
a_dt = DTensor.from_local(a, mesh, [Partial(), Replicate()], run_check=False)
b_dt = DTensor.from_local(b, mesh, [Replicate(), Partial()], run_check=False)
# Capture the operator decomposition
with DebugMode() as debug_mode:
torch.einsum("bld,dnh->blnh", a_dt, b_dt)
self.assertExpectedInline(
debug_mode.debug_string(),
"""\
torch.functional.einsum(bld,dnh->blnh, dt: f32[16, 6, 8][P, R], dt: f32[8, 4, 4][R, P])
aten::unsqueeze(dt: f32[16, 6, 8][P, R], 3)
aten::unsqueeze(t: f32[16, 6, 8], 3)
aten::unsqueeze(dt: f32[16, 6, 8, 1][P, R], 4)
aten::unsqueeze(t: f32[16, 6, 8, 1], 4)
aten::permute(dt: f32[16, 6, 8, 1, 1][P, R], [0, 1, 3, 4, 2])
aten::permute(t: f32[16, 6, 8, 1, 1], [0, 1, 3, 4, 2])
aten::unsqueeze(dt: f32[8, 4, 4][R, P], 3)
aten::unsqueeze(t: f32[8, 4, 4], 3)
aten::unsqueeze(dt: f32[8, 4, 4, 1][R, P], 4)
aten::unsqueeze(t: f32[8, 4, 4, 1], 4)
aten::permute(dt: f32[8, 4, 4, 1, 1][R, P], [3, 4, 1, 2, 0])
aten::permute(t: f32[8, 4, 4, 1, 1], [3, 4, 1, 2, 0])
aten::permute(dt: f32[16, 6, 1, 1, 8][P, R], [0, 1, 4, 2, 3])
aten::permute(t: f32[16, 6, 1, 1, 8], [0, 1, 4, 2, 3])
aten::view(dt: f32[16, 6, 8, 1, 1][P, R], [1, 96, 8])
aten::view(t: f32[16, 6, 8, 1, 1], [1, 96, 8])
aten::permute(dt: f32[1, 1, 4, 4, 8][R, P], [4, 2, 3, 0, 1])
aten::permute(t: f32[1, 1, 4, 4, 8], [4, 2, 3, 0, 1])
aten::view(dt: f32[8, 4, 4, 1, 1][R, P], [1, 8, 16])
aten::view(t: f32[8, 4, 4, 1, 1], [1, 8, 16])
aten::bmm(dt: f32[1, 96, 8][P, R], dt: f32[1, 8, 16][R, P])
redistribute_input(0, [P, R] -> [S(2), S(2)])
aten::chunk(t: f32[1, 96, 8], 4, 2)
aten::cat(['t: f32[1, 96, 2]', 't: f32[1, 96, 2]', 't: f32[1, 96, 2]', 't: f32[1, 96, 2]'])
_c10d_functional::reduce_scatter_tensor(t: f32[4, 96, 2], sum, 4, 1)
_c10d_functional::wait_tensor(t: f32[1, 96, 2])
aten::chunk(t: f32[1, 96, 2], 2, 2)
aten::clone(t: f32[1, 96, 1])
redistribute_input(1, [R, P] -> [S(1), S(1)])
aten::chunk(t: f32[1, 8, 16], 4, 1)
aten::clone(t: f32[1, 2, 16])
aten::chunk(t: f32[1, 2, 16], 2, 1)
aten::cat(['t: f32[1, 1, 16]', 't: f32[1, 1, 16]'])
_c10d_functional::reduce_scatter_tensor(t: f32[2, 1, 16], sum, 2, 3)
_c10d_functional::wait_tensor(t: f32[1, 1, 16])
aten::bmm(t: f32[1, 96, 1], t: f32[1, 1, 16])
aten::view(dt: f32[1, 96, 16][P, P], [16, 6, 1, 4, 4])
aten::view(t: f32[1, 96, 16], [16, 6, 1, 4, 4])
aten::permute(dt: f32[16, 6, 1, 4, 4][P, P], [0, 1, 3, 4, 2])
aten::permute(t: f32[16, 6, 1, 4, 4], [0, 1, 3, 4, 2])
aten::view(dt: f32[16, 6, 4, 4, 1][P, P], [16, 6, 4, 4])
aten::view(t: f32[16, 6, 4, 4, 1], [16, 6, 4, 4])""",
)
def test_real_tensor(self):
x = torch.randn(8, 8, 8)
linear = torch.nn.Linear(8, 8)
with DebugMode() as debug_mode:
linear(x).sum()
self.assertExpectedInline(
debug_mode.debug_string(),
"""\
torch._C._nn.linear(t: f32[8, 8, 8], t: f32[8, 8], t: f32[8])
aten::view(t: f32[8, 8, 8], [64, 8])
aten::t(t: f32[8, 8])
aten::addmm(t: f32[8], t: f32[64, 8], t: f32[8, 8])
aten::view(t: f32[64, 8], [8, 8, 8])
<method 'sum' of 'torch._C.TensorBase' objects>(t: f32[8, 8, 8])
aten::sum(t: f32[8, 8, 8])""",
)
def test_fake_tensor(self):
with FakeTensorMode():
x = torch.randn(8, 8)
y = torch.randn(8, 8, 8)
with DebugMode(record_faketensor=True) as debug_mode:
torch.matmul(y, x)
self.assertExpectedInline(
debug_mode.debug_string(),
"""\
torch.matmul(ft: f32[8, 8, 8], ft: f32[8, 8])
aten::view(ft: f32[8, 8, 8], [64, 8])
aten::mm(ft: f32[64, 8], ft: f32[8, 8])
aten::_unsafe_view(ft: f32[64, 8], [8, 8, 8])""",
)
@parametrize("has_inner_mode", [True, False])
@parametrize("has_outer_mode", [True, False])
def test_nested_debug_mode(self, has_inner_mode, has_outer_mode):
class DummyTorchDispatchMode1(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
return func(*args, **kwargs)
class DummyTorchDispatchMode2(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
return func(*args, **kwargs)
mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
x = torch.randn(1, 8, requires_grad=True)
y = torch.randn(1, 32, requires_grad=True)
x_dtensor = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
y_dtensor = DTensor.from_local(y, mesh, [Shard(0)], run_check=False)
inner_mode = (
DummyTorchDispatchMode1() if has_inner_mode else contextlib.nullcontext()
)
outer_mode = (
DummyTorchDispatchMode2() if has_outer_mode else contextlib.nullcontext()
)
with outer_mode:
with DebugMode() as debug_mode:
with inner_mode:
torch.mm(x_dtensor, y_dtensor)
self.assertTrue(
"redistribute_input(1, [S(0)] -> [R])" in debug_mode.debug_string()
)
instantiate_parametrized_tests(TestDTensorDebugMode)
if __name__ == "__main__":
run_tests()
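The file deleted above exercises DebugMode over DTensor dispatch. A distilled sketch of the core pattern it tested (not part of the diff; mirrors the deleted test's setUp, so a fake process group stands in for real communication and device_type "cuda" assumes a CUDA build):

import torch
import torch.distributed as dist
from torch.distributed.tensor import DeviceMesh, DTensor, Shard
from torch.testing._internal.distributed.fake_pg import FakeStore
from torch.utils.debug_mode import DebugMode

world_size = 8
dist.init_process_group(backend="fake", rank=0, world_size=world_size, store=FakeStore())
mesh = DeviceMesh("cuda", list(range(world_size)))

x = DTensor.from_local(torch.randn(1, 8), mesh, [Shard(0)], run_check=False)
y = DTensor.from_local(torch.randn(1, 32), mesh, [Shard(0)], run_check=False)

with DebugMode() as debug_mode:
    torch.mm(x, y)

print(debug_mode.debug_string())  # records redistribute_input(...) and the aten:: calls
dist.destroy_process_group()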

View File

@ -28,7 +28,6 @@ from torch.distributed.tensor.parallel import parallelize_module
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.nn.attention.flex_attention import (
_mask_mod_signature,
AuxRequest,
create_block_mask,
flex_attention,
)
@ -575,8 +574,8 @@ class RingFlexAttentionTest(DTensorTestBase):
device=self.device_type,
)
expect_out, expect_aux = compiled_flex_attention(
q, k, v, block_mask=block_mask, return_aux=AuxRequest(lse=True)
expect_out, expect_lse = compiled_flex_attention(
q, k, v, block_mask=block_mask, return_lse=True
)
expect_out.sum().backward()
@ -636,12 +635,12 @@ class RingFlexAttentionTest(DTensorTestBase):
cp_k.requires_grad = True
cp_v.requires_grad = True
cp_out, cp_aux = compiled_flex_attention(
cp_out, cp_lse = compiled_flex_attention(
cp_q,
cp_k,
cp_v,
block_mask=cp_block_mask,
return_aux=AuxRequest(lse=True),
return_lse=True,
)
# check block_mask rewrite doesn't escape to the outside
@ -658,11 +657,9 @@ class RingFlexAttentionTest(DTensorTestBase):
cp_v.requires_grad = False
# unshard the output
cp_out, cp_lse = context_parallel_unshard(
device_mesh, [cp_out, cp_aux.lse], [2, 2]
)
cp_out, cp_lse = context_parallel_unshard(device_mesh, [cp_out, cp_lse], [2, 2])
torch.testing.assert_close(cp_out, expect_out, atol=atol, rtol=rtol)
torch.testing.assert_close(cp_lse, expect_aux.lse, atol=atol, rtol=rtol)
torch.testing.assert_close(cp_lse, expect_lse, atol=atol, rtol=rtol)
# unshard the gradient
cp_q_grad, cp_k_grad, cp_v_grad = context_parallel_unshard(

View File

@ -211,8 +211,8 @@ def forward(self, b_parametrizations_buffer_original0, x):
_assert_tensor_metadata = torch.ops.aten._assert_tensor_metadata.default(x, None, None, torch.float64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata = None
_to_copy = torch.ops.aten._to_copy.default(x, dtype = torch.float64, layout = torch.strided, device = device(type='cuda', index=0)); x = None
view = torch.ops.aten.view.default(_to_copy, [4, 4]); _to_copy = None
add = torch.ops.aten.add.Tensor(b_parametrizations_buffer_original0, view); b_parametrizations_buffer_original0 = view = None
view_1 = torch.ops.aten.view.default(add, [4, 4]); add = None
add_1 = torch.ops.aten.add.Tensor(b_parametrizations_buffer_original0, view); b_parametrizations_buffer_original0 = view = None
view_1 = torch.ops.aten.view.default(add_1, [4, 4]); add_1 = None
return (view_1,)""", # noqa: B950
)
@ -317,9 +317,6 @@ def forward(self, b_parametrizations_buffer_original0, x):
self.assertEqual(res, ref)
@skipIfHpu
@unittest.skip(
"DTensor + dynamic fails - s77 + 8 is not tracked with proxy .. proxy_tensor.PythonKeyTracer"
)
def test_dtensor_dynamic_slice(self):
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
@ -361,9 +358,6 @@ def forward(self, b_parametrizations_buffer_original0, x):
res = opt_fn(x)
self.assertEqual(res, ref)
@unittest.skip(
"DTensor + dynamic fails - s77 + 8 is not tracked with proxy .. proxy_tensor.PythonKeyTracer"
)
def test_dtensor_dynamic_cat(self):
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

View File

@ -7,7 +7,7 @@ import warnings
import torch
import torch.distributed as dist
import torch.testing._internal.common_methods_invocations as common_ops
from torch.distributed.tensor import DTensor, init_device_mesh
from torch.distributed.tensor import DeviceMesh, DTensor
from torch.overrides import resolve_name
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
@ -159,6 +159,7 @@ dtensor_fails = {
xfail("geometric"),
xfail("geqrf"),
xfail("grid_sampler_2d"),
xfail("gradient"),
xfail("heaviside"),
xfail("histogram"),
xfail("histogramdd"),
@ -417,6 +418,8 @@ dtensor_fails = {
xfail("tensor_split"),
xfail("to_sparse"),
xfail("trace"),
xfail("trapezoid"),
xfail("trapz"),
xfail("triangular_solve"),
xfail("unbind"),
xfail("unbind_copy"),
@ -508,7 +511,7 @@ class TestDTensorOps(DTensorOpTestBase):
def run_opinfo_test(
self, dtype, op, requires_grad=True, sample_inputs_filter=lambda s: True
):
self.mesh = init_device_mesh(DEVICE_TYPE, (self.world_size,))
self.mesh = DeviceMesh(DEVICE_TYPE, torch.arange(self.world_size))
# test each op with dist tensor inputs and normal inputs
def test():
@ -633,7 +636,7 @@ class TestDTensorOps(DTensorOpTestBase):
)
except Exception as e:
raise RuntimeError(
f"{str(e)}\n\nfailed to run: {resolve_name(func)}, with (*{dtensor_args}, **{dtensor_kwargs})"
f"failed to run: {resolve_name(func)}, with (*{dtensor_args}, **{dtensor_kwargs})"
) from e
return rs

View File

@ -43,7 +43,6 @@ from torch.testing._internal.common_utils import (
retry_on_connect_failures,
run_tests,
TEST_WITH_DEV_DBG_ASAN,
TEST_XPU,
TestCase,
)
from torch.utils.checkpoint import checkpoint
@ -64,8 +63,6 @@ else:
torch.backends.cuda.matmul.allow_tf32 = False
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
def gpus_for_rank(world_size):
"""Multigpu tests are designed to simulate the multi nodes with multi
@ -73,9 +70,8 @@ def gpus_for_rank(world_size):
On a single node, all visible GPUs are evenly
divided to subsets, each process only uses a subset.
"""
device_count = torch.accelerator.device_count()
visible_devices = list(range(device_count))
gpus_per_process = device_count // world_size
visible_devices = list(range(torch.cuda.device_count()))
gpus_per_process = torch.cuda.device_count() // world_size
gpus_for_rank = []
for rank in range(world_size):
gpus_for_rank.append(
@ -405,7 +401,7 @@ class CommonDistributedDataParallelTest:
gradient_as_bucket_view=gradient_as_bucket_view,
)
input = torch.randn(global_batch_size, 2).to(devices[0])
input = torch.randn(global_batch_size, 2).cuda(devices[0])
target = torch.randn(global_batch_size, 4)
return model, ddp_model, input, target
@ -439,10 +435,10 @@ class CommonDistributedDataParallelTest:
allow_none_grads=False,
):
# to reproduce the same training results
torch.accelerator.set_device_index(self.rank)
torch.cuda.set_device(self.rank)
torch.manual_seed(31415)
model = copy.deepcopy(input_model).to(device_type)
ddp_model = copy.deepcopy(input_model).to(device_type)
model = copy.deepcopy(input_model).cuda()
ddp_model = copy.deepcopy(input_model).cuda()
ddp_model = nn.parallel.DistributedDataParallel(
ddp_model,
bucket_cap_mb=1,
@ -558,8 +554,8 @@ class CommonDistributedDataParallelTest:
def _prepare_dummy_data(self):
ddp_bs = 16
bs = ddp_bs * self.world_size
input = torch.rand((bs, 20), device=device_type, requires_grad=True)
target = torch.randn((bs, 20), device=device_type)
input = torch.rand((bs, 20), device="cuda", requires_grad=True)
target = torch.randn((bs, 20), device="cuda")
offset = self.rank * ddp_bs
ddp_input = input[offset : offset + ddp_bs]
ddp_target = target[offset : offset + ddp_bs]
@ -719,7 +715,7 @@ class CommonDistributedDataParallelTest:
Test that checkpointing with weight sharing works.
"""
process_group = self._get_process_group()
torch.accelerator.set_device_index(self.rank)
torch.cuda.set_device(self.rank)
for use_bucket_view, static_graph in product((False, True), (False, True)):
torch.manual_seed(31415)
l1 = nn.Linear(20, 20)
@ -742,7 +738,7 @@ class CommonDistributedDataParallelTest:
same layer twice and having weights shared across layers.
"""
process_group = self._get_process_group()
torch.accelerator.set_device_index(self.rank)
torch.cuda.set_device(self.rank)
for use_bucket_view in (True, False):
self._test_ddp_checkpointing(
self.CheckpointTwiceModuleWeightSharing(),
@ -1166,7 +1162,7 @@ class AbstractCommTest:
# Verify sequence numbers are appropriately incremented
for i in range(10):
t = torch.ones(1, device=device_type)
t = torch.ones(1, device=torch.cuda.current_device())
dist.all_reduce(t, group=process_group)
if not c10d._rank_not_in_group(process_group):
seq_num = self._verify_sequence_number_across_pg(
@ -1197,7 +1193,7 @@ class AbstractCommTest:
self.assertEqual(rank_to_seq_num[0] + 1, rank_to_seq_num[1])
def _test_sequence_num_incremented_default_group(self, backend_name):
torch.accelerator.set_device_index(self.rank)
torch.cuda.set_device(self.rank)
store = dist.FileStore(self.file_name, self.world_size)
dist.init_process_group(
backend_name,
@ -1211,7 +1207,7 @@ class AbstractCommTest:
)
def _test_sequence_num_incremented_subgroup(self, backend_name):
torch.accelerator.set_device_index(self.rank)
torch.cuda.set_device(self.rank)
store = dist.FileStore(self.file_name, self.world_size)
dist.init_process_group(
backend_name,
@ -1266,8 +1262,8 @@ class AbstractCommTest:
in_group_ranks = list(filter(lambda x: x % 2 == 0, range(self.world_size)))
group = dist.new_group(in_group_ranks)
x = torch.zeros(2, 2).to(self.rank)
xs = [torch.zeros(2, 2).to(self.rank) for _ in range(len(in_group_ranks))]
x = torch.zeros(2, 2).cuda(self.rank)
xs = [torch.zeros(2, 2).cuda(self.rank) for _ in range(len(in_group_ranks))]
if self.rank not in in_group_ranks:
msg = ".*{}.*does not belong to.*"
with self.assertWarnsOnceRegex(UserWarning, msg.format("all_gather")):
@ -1396,7 +1392,7 @@ class AbstractCommTest:
rank=self.rank,
store=store,
)
device = "cuda" if backend == "nccl" else "xpu" if backend == "xccl" else "cpu"
device = "cuda" if backend == "nccl" else "cpu"
# test alltoall_base
tensor = torch.tensor([1, 0, 0, 1], dtype=torch.bool, device=device)
zeros = torch.tensor([0, 0, 0, 0], dtype=torch.bool, device=device)
@ -1578,8 +1574,8 @@ class CommTest(AbstractCommTest, MultiProcessTestCase):
class DummyWork(dist._Work):
def wait(self, timeout=5.0):
if torch.accelerator.is_available():
torch.accelerator.current_stream().synchronize()
if torch.cuda.is_available():
torch.cuda.current_stream().synchronize()
return True
@ -1794,18 +1790,6 @@ class PythonProcessGroupExtensionTest(MultiProcessTestCase):
("cpu:gloo,cuda:nccl", "cpu:gloo,cuda:nccl"),
]
if TEST_XPU:
# Override backend_config_strings_and_expected_values for Intel GPU.
backend_config_strings_and_expected_values[4:10] = [
(dist.Backend.DUMMY, "cpu:dummy,cuda:dummy,xpu:dummy"),
("DUMMY", "cpu:dummy,cuda:dummy,xpu:dummy"),
("dummy", "cpu:dummy,cuda:dummy,xpu:dummy"),
("cpu:dummy,xpu:dummy", "cpu:dummy,xpu:dummy"),
("cpu:dummy,xpu:xccl", "cpu:dummy,xpu:xccl"),
("cpu:gloo,xpu:dummy", "cpu:gloo,xpu:dummy"),
("cpu:gloo,xpu:xccl", "cpu:gloo,xpu:xccl"),
]
for config_str, expected_value in backend_config_strings_and_expected_values:
with self.subTest(config_str):
# ensures these configs strings are valid and no ValueError is raised
@ -1816,8 +1800,6 @@ class PythonProcessGroupExtensionTest(MultiProcessTestCase):
invalid_backend_config_strings = [
"cpu:gloo,cuda:nccl,", # trailing comma
"cpu:gloo,cuda:nccl,cpu:dummy", # duplicate device
"cpu:gloo,xpu:xccl,", # trailing comma
"cpu:gloo,xpu:xccl,cpu:dummy", # duplicate device
]
for config_str in invalid_backend_config_strings:
with self.subTest(config_str):
@ -1832,7 +1814,7 @@ class PythonProcessGroupExtensionTest(MultiProcessTestCase):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "6789"
dist.init_process_group(
"cpu:dummy,cuda:dummy,xpu:dummy", rank=self.rank, world_size=self.world_size
"cpu:dummy,cuda:dummy", rank=self.rank, world_size=self.world_size
)
# test all_gather
@ -2071,7 +2053,7 @@ dist.init_process_group(rank=0, world_size=1, store=dist.HashStore())
# correctly dispatched
# TODO: this will be updated in the future to not be backend specific
device = "cuda" if backend == "nccl" else "xpu" if backend == "xccl" else "cpu"
device = "cuda" if backend == "nccl" else "cpu"
# ensure supported devices (cpu, cuda) succeeds during dispatch call
tensor = torch.zeros(2, 2, device=torch.device(device))
# multi tensor collectives
@ -2137,7 +2119,7 @@ dist.init_process_group(rank=0, world_size=1, store=dist.HashStore())
rank=self.rank,
store=store,
)
device = "cuda" if backend == "nccl" else "xpu" if backend == "xccl" else "cpu"
device = "cuda" if backend == "nccl" else "cpu"
# test alltoall_base
input_tensor = torch.ones(2, 2, device=torch.device(device))
output_tensor = torch.zeros(2, 2, device=torch.device(device))
@ -2269,9 +2251,8 @@ class LocalRankTest(MultiProcessTestCase):
if __name__ == "__main__":
if device_type != "cpu":
assert not torch.get_device_module()._initialized, (
"test_distributed must not have initialized {device_type} context on main process"
)
assert not torch.cuda._initialized, (
"test_distributed must not have initialized CUDA context on main process"
)
run_tests()
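Most hunks in this file rewrite per-rank device setup between CUDA-specific and accelerator-generic calls. A standalone sketch of the generic form (not part of the diff; rank and world_size are placeholders for the values the test harness provides):

import torch

rank, world_size = 0, 2  # placeholders
if torch.accelerator.is_available():
    torch.accelerator.set_device_index(rank)        # vs. torch.cuda.set_device(rank)
device_count = torch.accelerator.device_count()     # vs. torch.cuda.device_count()
gpus_per_process = device_count // world_size if device_count else 0
acc = torch.accelerator.current_accelerator()
device = torch.device(acc.type if acc is not None else "cpu")
t = torch.ones(1, device=device)                    # vs. device=torch.cuda.current_device()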

View File

@ -21,15 +21,15 @@ from torch.distributed._functional_collectives import (
reduce_scatter_tensor,
reduce_scatter_tensor_coalesced,
)
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8
from torch.testing._internal.common_device_type import e4m3_type
from torch.testing._internal.common_cuda import SM90OrLater
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
requires_accelerator_dist_backend,
requires_nccl,
skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import ( # type: ignore[attr-defined]
run_tests,
skipIfRocm,
TestCase,
)
from torch.testing._internal.distributed.fake_pg import FakeStore
@ -59,7 +59,7 @@ if not dist.is_available():
sys.exit(0)
@requires_accelerator_dist_backend(["nccl", "xccl"])
@requires_nccl()
class TestWithNCCL(MultiProcessTestCase):
def setUp(self) -> None:
super().setUp()
@ -75,15 +75,13 @@ class TestWithNCCL(MultiProcessTestCase):
@property
def device(self) -> torch.device:
return torch.device(self.rank)
return torch.device(f"cuda:{self.rank}")
def _init_process_group(self) -> None:
torch.accelerator.set_device_index(self.rank)
torch.cuda.set_device(self.device)
store = dist.FileStore(self.file_name, self.world_size)
backend = dist.get_default_backend_for_device(self.device.type)
dist.init_process_group(
backend=backend,
backend="nccl",
world_size=self.world_size,
rank=self.rank,
store=store,
@ -275,7 +273,7 @@ class TestWithNCCL(MultiProcessTestCase):
)
# check memory leak
for i in range(1, 10):
mem_usage[i] = torch.accelerator.max_memory_allocated()
mem_usage[i] = torch.cuda.max_memory_allocated()
compiled(arg)
assert mem_usage[9] == mem_usage[8]
@ -372,16 +370,14 @@ class TestWithNCCL(MultiProcessTestCase):
@skip_if_lt_x_gpu(2)
def test_all_to_all_single(self) -> None:
self._init_process_group()
torch.accelerator.set_device_index(self.rank)
torch.cuda.set_device(self.device)
torch.manual_seed(42)
send_sz_matrix = torch.randint(0, 20, (self.world_size, self.world_size))
input_split_sizes = send_sz_matrix[self.rank].tolist()
output_split_sizes = send_sz_matrix[:, self.rank].tolist()
input = torch.full((sum(input_split_sizes),), float(self.rank)).to(
self.device.type
)
input = torch.full((sum(input_split_sizes),), float(self.rank)).cuda()
output = torch.ops._c10d_functional.all_to_all_single(
input,
@ -392,7 +388,7 @@ class TestWithNCCL(MultiProcessTestCase):
output = torch.ops._c10d_functional.wait_tensor(output)
expect = torch.cat(
[
torch.full((sz,), float(rank)).to(self.device.type)
torch.full((sz,), float(rank)).cuda()
for rank, sz in enumerate(output_split_sizes)
]
)
@ -468,7 +464,7 @@ class TestWithNCCL(MultiProcessTestCase):
@fresh_cache()
def test_threading(self):
self._init_process_group()
device = self.device
device = torch.device(f"cuda:{self.rank}")
def func(arg: torch.Tensor) -> torch.Tensor:
buf0 = arg + 42
@ -505,9 +501,10 @@ class TestWithNCCL(MultiProcessTestCase):
t.start()
t.join()
@skipIfRocm
@unittest.skipIf(
not PLATFORM_SUPPORTS_FP8,
"_scaled_mm currently only supports sm>=90 on cuda and gfx94/95 on ROCm",
not SM90OrLater,
"_scaled_mm currently only supports sm>=90",
)
@skip_if_lt_x_gpu(2)
@fresh_cache()
@ -516,9 +513,10 @@ class TestWithNCCL(MultiProcessTestCase):
def scale(t):
scale = (
torch.finfo(e4m3_type).max / t.abs().amax(dim=-1, keepdim=True).float()
torch.finfo(torch.float8_e4m3fn).max
/ t.abs().amax(dim=-1, keepdim=True).float()
)
t = t.mul(scale).to(e4m3_type)
t = t.mul(scale).to(torch.float8_e4m3fn)
return t, scale
def fp8_rowwise_backward(in_, w, out_grad):
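Stepping outside the diff for a moment: the hunk above only swaps the e4m3_type alias for the literal torch.float8_e4m3fn dtype inside a row-wise FP8 scaling helper. A standalone sketch of that helper (the cast itself runs on CPU; the _scaled_mm path in the surrounding test still needs FP8-capable hardware):

import torch

def scale_rowwise(t: torch.Tensor):
    # Per-row scale so each row's max magnitude maps onto the e4m3 max representable value.
    s = torch.finfo(torch.float8_e4m3fn).max / t.abs().amax(dim=-1, keepdim=True).float()
    return t.mul(s).to(torch.float8_e4m3fn), s

x = torch.randn(128, 64, dtype=torch.bfloat16)
x_fp8, x_scale = scale_rowwise(x)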
@ -548,9 +546,9 @@ class TestWithNCCL(MultiProcessTestCase):
return in_grad, w_grad
m, n, k = 128, 256, 64
in_ = torch.randn((m, k), device=self.device.type, dtype=torch.bfloat16)
w = torch.randn((n, k), device=self.device.type, dtype=torch.bfloat16)
out_grad = torch.randn((m, n), device=self.device.type, dtype=torch.bfloat16)
in_ = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
w = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)
out_grad = torch.randn((m, n), device="cuda", dtype=torch.bfloat16)
eager_in_grad, eager_w_grad = fp8_rowwise_backward(in_, w, out_grad)
compile_in_grad, compile_w_grad = torch.compile(fp8_rowwise_backward)(
@ -779,8 +777,7 @@ class CompileTest(TestCase):
self.rank = 0
self.world_size = 2
torch.accelerator.set_device_index(0)
self.device = torch.accelerator.current_accelerator()
torch.cuda.set_device("cuda:0")
store = FakeStore()
dist.init_process_group(
@ -806,7 +803,7 @@ class CompileTest(TestCase):
ar1 = funcol.wait_tensor(ar1)
return ar0, ar1
arg = torch.rand(4, 4, device=self.device)
arg = torch.rand(4, 4, device="cuda")
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, arg)
@ -839,7 +836,7 @@ class CompileTest(TestCase):
# Test aoti
AOTIRunnerUtil.run(func, (arg,))
torch.accelerator.synchronize()
torch.cuda.synchronize()
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@fresh_cache()
@ -854,7 +851,7 @@ class CompileTest(TestCase):
ar1 = [funcol.wait_tensor(out) for out in ar1]
return ar0, ar1
args = [torch.rand(4, 4, device=self.device.type) for _ in range(2)]
args = [torch.rand(4, 4, device="cuda") for _ in range(2)]
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, args)
buf0, buf1, buf2, buf3 = find_buffer_assignments(code)
@ -884,7 +881,7 @@ class CompileTest(TestCase):
# Test aoti
out = AOTIRunnerUtil.run(func, (args,)) # noqa: F841
torch.accelerator.synchronize()
torch.cuda.synchronize()
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@fresh_cache()
@ -895,7 +892,7 @@ class CompileTest(TestCase):
ar0 = funcol.wait_tensor(ar0)
return ar0
arg = torch.rand(4, 4, device=self.device.type)
arg = torch.rand(4, 4, device="cuda")
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, arg)
@ -920,7 +917,7 @@ class CompileTest(TestCase):
# Expect allocation
return ar0
arg = torch.rand(4, 4, device=self.device.type).T
arg = torch.rand(4, 4, device="cuda").T
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, arg)
@ -951,7 +948,7 @@ class CompileTest(TestCase):
buf2 = torch.mm(arg, buf1)
return buf1, buf2
arg = torch.rand(4, 4, device=self.device.type)
arg = torch.rand(4, 4, device="cuda")
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, arg)
buf0, buf1 = find_buffer_assignments(code)
@ -981,7 +978,7 @@ class CompileTest(TestCase):
ag0 = funcol.wait_tensor(ag0)
return ag0
arg = torch.rand(4, 4, device=self.device.type)
arg = torch.rand(4, 4, device="cuda")
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, arg)
(
@ -998,7 +995,7 @@ class CompileTest(TestCase):
# Test aoti
AOTIRunnerUtil.run(func, (arg,))
torch.accelerator.synchronize()
torch.cuda.synchronize()
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@fresh_cache()
@ -1008,7 +1005,7 @@ class CompileTest(TestCase):
ag0 = [funcol.wait_tensor(out) for out in ag0]
return ag0
args = [torch.rand(4, 4, device=self.device.type) for _ in range(4)]
args = [torch.rand(4, 4, device="cuda") for _ in range(4)]
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, args)
(
@ -1032,7 +1029,7 @@ class CompileTest(TestCase):
# Test aoti
out = AOTIRunnerUtil.run(func, (args,)) # noqa: F841
torch.accelerator.synchronize()
torch.cuda.synchronize()
@unittest.skipIf(not HAS_GPU, "This is a GPU test!")
@fresh_cache()
@ -1042,7 +1039,7 @@ class CompileTest(TestCase):
return funcol.wait_tensor(t)
# Test aoti
arg = torch.rand(4, 4, device=self.device.type)
arg = torch.rand(4, 4, device="cuda")
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, arg)
(
@ -1054,7 +1051,7 @@ class CompileTest(TestCase):
# Test aoti
AOTIRunnerUtil.run(func, (arg,))
torch.accelerator.synchronize()
torch.cuda.synchronize()
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@fresh_cache()
@ -1064,7 +1061,7 @@ class CompileTest(TestCase):
rs0 = funcol.wait_tensor(rs0)
return rs0
arg = torch.rand(4, 4, device=self.device.type)
arg = torch.rand(4, 4, device="cuda")
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, arg)
(
@ -1080,7 +1077,7 @@ class CompileTest(TestCase):
# Test aoti
AOTIRunnerUtil.run(func, (arg,))
torch.accelerator.synchronize()
torch.cuda.synchronize()
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@fresh_cache()
@ -1092,7 +1089,7 @@ class CompileTest(TestCase):
rs0 = [funcol.wait_tensor(out) for out in rs0]
return rs0
args = [torch.rand(4, 4, device=self.device.type) for _ in range(4)]
args = [torch.rand(4, 4, device="cuda") for _ in range(4)]
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, args)
(
@ -1116,7 +1113,7 @@ class CompileTest(TestCase):
# Test aoti
AOTIRunnerUtil.run(func, (args,))
torch.accelerator.synchronize()
torch.cuda.synchronize()
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@fresh_cache()
@ -1145,9 +1142,7 @@ class CompileTest(TestCase):
input_split_sizes = send_sz_matrix[self.rank]
output_split_sizes = send_sz_matrix[:, self.rank].contiguous()
input = torch.full((input_split_sizes.sum().item(),), float(self.rank)).to(
self.device.type
)
input = torch.full((input_split_sizes.sum().item(),), float(self.rank)).cuda()
with torch._dynamo.config.patch(
dynamic_shapes=True,
@ -1181,7 +1176,7 @@ class CompileTest(TestCase):
br1 = funcol.wait_tensor(br1)
return br0, br1
arg = torch.rand(4, 4, device=self.device.type)
arg = torch.rand(4, 4, device="cuda")
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, arg)
@ -1204,7 +1199,7 @@ class CompileTest(TestCase):
# Test aoti
AOTIRunnerUtil.run(func, (arg,))
torch.accelerator.synchronize()
torch.cuda.synchronize()
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@fresh_cache()
@ -1219,7 +1214,7 @@ class CompileTest(TestCase):
ar1 = funcol.wait_tensor(ar1)
return ar0, ar1
arg = torch.rand(4, 4, device=self.device.type)
arg = torch.rand(4, 4, device="cuda")
compiled = torch.compile(func, fullgraph=True)
code = run_and_get_triton_code(compiled, arg)
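The tests in this file repeatedly compile the same functional-collectives pattern: issue a collective, then explicitly wait on the returned tensor. A minimal eager sketch of that pattern (not part of the diff; uses a fake process group as CompileTest.setUp does, so it runs without a multi-GPU launch):

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
from torch.testing._internal.distributed.fake_pg import FakeStore

dist.init_process_group(backend="fake", rank=0, world_size=2, store=FakeStore())

def allreduce_then_wait(arg: torch.Tensor) -> torch.Tensor:
    ar = funcol.all_reduce(arg + 42, "avg", dist.group.WORLD)
    return funcol.wait_tensor(ar)   # explicit wait before the result is used

out = allreduce_then_wait(torch.rand(4, 4))
dist.destroy_process_group()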

View File

@ -3145,24 +3145,19 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
class NcclUserBufferRegistrationTest(MultiProcessTestCase):
def setUp(self):
super().setUp()
# TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING hence tests
# that use TORCH_NCCL_BLOCKING_WAIT will test it as expected.
os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
nccl_debug_file = tempfile.NamedTemporaryFile()
nccl_env = {
# TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING hence tests
# that use TORCH_NCCL_BLOCKING_WAIT will test it as expected.
"TORCH_NCCL_ASYNC_ERROR_HANDLING": "1",
"NCCL_ALGO": "NVLS",
"NCCL_DEBUG": "INFO",
"NCCL_DEBUG_SUBSYS": "NVLS",
"NCCL_DEBUG_FILE": nccl_debug_file.name,
}
os.environ["NCCL_ALGO"] = "NVLS"
os.environ["NCCL_DEBUG"] = "INFO"
os.environ["NCCL_DEBUG_SUBSYS"] = "NVLS"
if torch.cuda.nccl.version() >= (2, 24, 3):
nccl_env["NCCL_DEBUG_SUBSYS"] = "REG,TUNING"
self.env_patcher = mock.patch.dict(os.environ, nccl_env)
self.env_patcher.start()
os.environ["NCCL_DEBUG_SUBSYS"] = "REG,TUNING"
os.environ["NCCL_DEBUG_FILE"] = nccl_debug_file.name
self._spawn_processes()
def tearDown(self):
self.env_patcher.stop()
super().tearDown()
try:
os.remove(self.file_name)
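One side of the hunk above stages the NCCL settings with unittest.mock.patch.dict instead of mutating os.environ directly, so tearDown can restore the original environment. A standalone sketch of that pattern (not part of the diff):

import os
import tempfile
from unittest import mock

nccl_debug_file = tempfile.NamedTemporaryFile()
nccl_env = {
    "TORCH_NCCL_ASYNC_ERROR_HANDLING": "1",
    "NCCL_ALGO": "NVLS",
    "NCCL_DEBUG": "INFO",
    "NCCL_DEBUG_SUBSYS": "NVLS",
    "NCCL_DEBUG_FILE": nccl_debug_file.name,
}
env_patcher = mock.patch.dict(os.environ, nccl_env)
env_patcher.start()
# ... spawn test processes / run the NCCL workload ...
env_patcher.stop()  # original environment restored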

View File

@ -1,7 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import os
import unittest
import torch
import torch.distributed as dist
@ -27,7 +26,7 @@ from torch.distributed.tensor._collective_utils import (
)
from torch.distributed.tensor.placement_types import _Partial, Shard
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests, TEST_XPU
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
with_comms,
@ -36,10 +35,6 @@ from torch.testing._internal.distributed.fake_pg import FakeProcessGroup, FakeSt
from torch.utils._typing_utils import not_none
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
device_count = torch.accelerator.device_count()
def _set_env_var(addr="localhost", port="25364", world_size=1, rank=0, local_rank=-1):
os.environ["MASTER_ADDR"] = addr
os.environ["MASTER_PORT"] = port
@ -49,7 +44,6 @@ def _set_env_var(addr="localhost", port="25364", world_size=1, rank=0, local_ran
os.environ["LOCAL_RANK"] = f"{local_rank}"
@unittest.skipIf(TEST_XPU, "XPU does not support gloo backend.")
class DeviceMeshTestGlooBackend(DTensorTestBase):
@property
def backend(self):
@ -79,16 +73,14 @@ class DeviceMeshSetDeviceTest(DTensorTestBase):
# Set the device on each process before DeviceMesh constructor,
# and device to be different than the default world rank
torch.accelerator.set_device_index((self.rank + 2) % self.world_size)
torch.cuda.set_device((self.rank + 2) % self.world_size)
_set_env_var(world_size=self.world_size, rank=self.rank)
DeviceMesh(self.device_type, mesh_tensor)
self.assertTrue(is_initialized())
# check that the device is set to the correct device
# and respect the previous set_device calls
self.assertEqual(
torch.accelerator.current_device_idx(), (self.rank + 2) % self.world_size
)
self.assertEqual(torch.cuda.current_device(), (self.rank + 2) % self.world_size)
self.destroy_pg()
@skip_if_lt_x_gpu(4)
@ -109,7 +101,7 @@ class DeviceMeshSetDeviceTest(DTensorTestBase):
# check that the device is set to the correct device
# and respect the LOCAL_RANK env var
self.assertEqual(torch.accelerator.current_device_idx(), local_rank)
self.assertEqual(torch.cuda.current_device(), local_rank)
self.destroy_pg()
@skip_if_lt_x_gpu(4)
@ -128,7 +120,7 @@ class DeviceMeshSetDeviceTest(DTensorTestBase):
self.assertTrue(is_initialized())
# check that the device is set to the correct device
self.assertEqual(torch.accelerator.current_device_idx(), self.rank)
self.assertEqual(torch.cuda.current_device(), self.rank)
self.destroy_pg()
@ -230,7 +222,7 @@ class DeviceMeshTest(DTensorTestBase):
@with_comms
def test_device_mesh_2d(self):
mesh_tensor = torch.arange(4).reshape(2, 2)
# construct a device mesh for self.device_type
# construct a cuda device mesh
mesh = DeviceMesh(self.device_type, mesh_tensor)
# check all dim groups
@ -268,11 +260,7 @@ class DeviceMeshTest(DTensorTestBase):
def test_fake_pg_device_mesh(self):
fake_store = FakeStore()
init_process_group("fake", store=fake_store, rank=0, world_size=self.world_size)
device_type = (
torch.accelerator.current_accelerator().type
if torch.accelerator.is_available()
else "cpu"
)
device_type = "cuda" if torch.cuda.is_available() else "cpu"
mesh = DeviceMesh(device_type, torch.arange(self.world_size))
local_tensor = torch.randn(2, 8)
@ -312,7 +300,7 @@ class DeviceMeshTest(DTensorTestBase):
regex = r"Invalid mesh \[\[0, 1\], \[2, 3\]\] for ProcessGroup with ranks \[0, 1, 2, 3\]"
with self.assertRaisesRegex(ValueError, regex):
DeviceMesh.from_group(
global_pg, device_type, invalid_mesh, mesh_dim_names=("dim0", "dim1")
global_pg, "cuda", invalid_mesh, mesh_dim_names=("dim0", "dim1")
)
device_mesh = init_device_mesh(self.device_type, (2, 2))
@ -332,16 +320,12 @@ class DeviceMeshTest(DTensorTestBase):
# test init_device_mesh with an invalid device type that contains a GPU index
mesh_shape = (2, self.world_size // 2)
init_device_mesh(
f"{device_type}:0", mesh_shape=mesh_shape, mesh_dim_names=("dp", "tp")
"cuda:0", mesh_shape=mesh_shape, mesh_dim_names=("dp", "tp")
)
@with_comms
def test_set_mesh_dim_group_options(self):
device_type = (
torch.accelerator.current_accelerator().type
if torch.accelerator.is_available()
else "cpu"
)
device_type = "cuda" if torch.cuda.is_available() else "cpu"
_mesh_resources._set_mesh_dim_group_options(1, "fake", None)
mesh_tensor = torch.arange(4).reshape(2, 2)
@ -357,7 +341,7 @@ class DeviceMeshTestNDim(DTensorTestBase):
@with_comms
def test_device_mesh_nd(self):
# construct a device mesh for self.device_type
# construct a cuda device mesh
mesh_tensor = torch.arange(8).reshape(2, 2, 2)
mesh = DeviceMesh(self.device_type, mesh_tensor)
@ -726,9 +710,7 @@ class TestDeviceMeshGetItem(DTensorTestBase):
with self.assertRaisesRegex(KeyError, "Invalid mesh_dim_name"):
mesh_dim_names = ("DP", "TP")
mesh = init_device_mesh(
self.device_type,
(2, 4),
mesh_dim_names=mesh_dim_names,
self.device_type, (2, 4), mesh_dim_names=mesh_dim_names
)
mesh[child_mesh_dim_name]
@ -956,9 +938,7 @@ class TestMeshEnv(DTensorTestBase):
@with_comms
def test_get_root_mesh(self):
mesh_3d = init_device_mesh(
self.device_type,
(2, 2, 2),
mesh_dim_names=("dp", "cp", "tp"),
self.device_type, (2, 2, 2), mesh_dim_names=("dp", "cp", "tp")
)
dp_cp_mesh = mesh_3d["dp", "cp"]
@ -1006,9 +986,7 @@ class TestMeshEnv(DTensorTestBase):
@with_comms
def test_get_all_submeshes(self):
mesh_2d = init_device_mesh(
self.device_type,
(2, 4),
mesh_dim_names=("replicate", "shard"),
self.device_type, (2, 4), mesh_dim_names=("replicate", "shard")
)
all_submeshes = _mesh_resources._get_all_submeshes(mesh_2d, "replicate")
self.assertEqual(len(all_submeshes), 4)
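The last hunks only reflow init_device_mesh calls onto one line. For reference, a sketch of the call shape being reformatted (not part of the diff; assumes an 8-rank job, e.g. launched with torchrun or with an already-initialized process group, and a CUDA build):

import torch
from torch.distributed.device_mesh import init_device_mesh

mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("replicate", "shard"))
shard_mesh = mesh_2d["shard"]   # 1-D submesh sliced by dimension name
print(mesh_2d.shape, shard_mesh.size())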

Some files were not shown because too many files have changed in this diff.