Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-03 23:45:05 +08:00)

Compare commits: 38 commits, mlazos/evt ... starterTas
| SHA1 |
|---|
| 19e09ed25d |
| c54b9f2969 |
| be36bacdaa |
| 6487ea30b3 |
| b0e5402377 |
| ee72c53c88 |
| ed5f4a4fa8 |
| 0ec8fe46d7 |
| dccd19c2ef |
| 7a46f4bde0 |
| c5cba39d46 |
| 3cd5b3b1e7 |
| 02af4e88e4 |
| c45515c2ed |
| 4f1a52fba4 |
| f3daedb263 |
| 5506baa4ed |
| 6f835a4769 |
| 2ade886412 |
| 1bc5762495 |
| 74d0300804 |
| 694748dd9d |
| 6fe5d9215f |
| a2d0ef242d |
| a8986963da |
| 75eb2f3ff6 |
| cb57b19c3a |
| aa84c037f0 |
| 68034198e5 |
| 0e805aad7f |
| 8568dbce1d |
| 27f7b65a69 |
| 7ebea09986 |
| f604732e2e |
| b4fb801b2d |
| 40339c1e99 |
| 9b2a45ac7d |
| e802b29ed4 |
@@ -23,6 +23,12 @@ inputs:
     type: string
     description: 'the job name of the test'
     required: True
+  artifact_prefix:
+    type: string
+    description: |
+      'the prefix of the raw utilization data, for data stored in zip file, this is the prefix of the parent zip file'
+    default: ""
+    required: False

 runs:
   using: composite
@@ -35,6 +41,7 @@ runs:
       echo "workflow_Name: ${{inputs.workflow_name}}"
       echo "job_id: ${{inputs.job_id}}"
       echo "job_name: ${{inputs.job_name}}"
+      echo "artifact_prefix: ${{inputs.artifact_prefix}}"
   - uses: nick-fields/retry@v3.0.0
     name: Setup dependencies
     with:
@@ -53,4 +60,5 @@ runs:
         --workflow-name "${{inputs.workflow_name}}" \
         --workflow-run-attempt "${{inputs.workflow_attempt}}" \
         --job-id "${{inputs.job_id}}" \
-        --job-name "${{inputs.job_name}}"
+        --job-name "${{inputs.job_name}}" \
+        --artifact-prefix "${{inputs.artifact_prefix}}"
2 .github/ci_commit_pins/xla.txt (vendored)
@@ -1 +1 @@
-8d9e34b352af09c81ff8df448fd27f9c4aae1382
+edc1a882d872dd7f1362e4312fd045a1d81b3355
79 .github/workflows/_linux-build.yml (vendored)
@@ -74,6 +74,24 @@ on:
         Overwrite the number of jobs to use for the build
       required: false
       type: string
+    disable-monitor:
+      description: |
+        Disable utilization monitoring for build job
+      required: false
+      type: boolean
+      default: false
+    monitor-log-interval:
+      description: |
+        Set the interval for the monitor script to log utilization.
+      required: false
+      type: number
+      default: 5
+    monitor-data-collect-interval:
+      description: |
+        Set the interval for the monitor script to collect data.
+      required: false
+      type: number
+      default: 1

   secrets:
     HUGGING_FACE_HUB_TOKEN:
@@ -176,6 +194,27 @@ jobs:
           selected-test-configs: ${{ inputs.selected-test-configs }}
           job-name: ${{ steps.get-job-id.outputs.job-name }}

+      - name: Start monitoring script
+        id: monitor-script
+        if: ${{ !inputs.disable-monitor }}
+        shell: bash
+        continue-on-error: true
+        env:
+          JOB_ID: ${{ steps.get-job-id.outputs.job-id }}
+          JOB_NAME: ${{ steps.get-job-id.outputs.job-name }}
+          WORKFLOW_NAME: ${{ github.workflow }}
+          WORKFLOW_RUN_ID: ${{github.run_id}}
+          MONITOR_LOG_INTERVAL: ${{ inputs.monitor-log-interval }}
+          MONITOR_DATA_COLLECT_INTERVAL: ${{ inputs.monitor-data-collect-interval }}
+        run: |
+          mkdir -p ../../usage_logs
+          python3 -m pip install psutil==5.9.1 dataclasses_json==0.6.7
+          python3 -m tools.stats.monitor \
+            --log-interval "$MONITOR_LOG_INTERVAL" \
+            --data-collect-interval "$MONITOR_DATA_COLLECT_INTERVAL" \
+            > "../../usage_logs/usage_log_build_${JOB_ID}.txt" 2>&1 &
+          echo "monitor-script-pid=${!}" >> "${GITHUB_OUTPUT}"
+
       - name: Download pytest cache
         uses: ./.github/actions/pytest-cache-download
         continue-on-error: true
@@ -280,6 +319,15 @@ jobs:
           END_TIME=$(date +%s)
           echo "build_time=$((END_TIME - START_TIME))" >> "$GITHUB_OUTPUT"

+      - name: Stop monitoring script
+        if: ${{ always() && steps.monitor-script.outputs.monitor-script-pid }}
+        shell: bash
+        continue-on-error: true
+        env:
+          MONITOR_SCRIPT_PID: ${{ steps.monitor-script.outputs.monitor-script-pid }}
+        run: |
+          kill "$MONITOR_SCRIPT_PID"
+
       - name: Archive artifacts into zip
         if: inputs.build-generates-artifacts && steps.build.outcome != 'skipped'
         run: |
@@ -304,6 +352,25 @@ jobs:
           if-no-files-found: error
           path: artifacts.zip

+      - name: copy logs
+        shell: bash
+        if: ${{ always() && steps.build.outcome != 'skipped' && !inputs.disable-monitor && inputs.build-environment != 'linux-s390x-binary-manywheel'}}
+        continue-on-error: true
+        run: |
+          rm -f ./usage_logs
+          mkdir -p ./usage_logs
+          cp ../../usage_logs/usage_log_build_*.txt ./usage_logs/
+
+      - name: Upload raw usage log to s3
+        if: ${{ always() && steps.build.outcome != 'skipped' && !inputs.disable-monitor && inputs.build-environment != 'linux-s390x-binary-manywheel'}}
+        uses: seemethere/upload-artifact-s3@v5
+        with:
+          s3-prefix: |
+            ${{ github.repository }}/${{ github.run_id }}/${{ github.run_attempt }}/artifact
+          retention-days: 14
+          if-no-files-found: warn
+          path: usage_logs/usage_log_build_*.txt
+
       - name: Upload sccache stats
         if: steps.build.outcome != 'skipped' && inputs.build-environment != 'linux-s390x-binary-manywheel'
         uses: ./.github/actions/upload-sccache-stats
@@ -311,6 +378,18 @@ jobs:
           github-token: ${{ secrets.GITHUB_TOKEN }}
           build-time: ${{ steps.build.outputs.build_time }}

+      - name: Upload utilization stats
+        if: ${{ always() && steps.build.outcome != 'skipped' && !inputs.disable-monitor && inputs.build-environment != 'linux-s390x-binary-manywheel' }}
+        continue-on-error: true
+        uses: ./.github/actions/upload-utilization-stats
+        with:
+          job_id: ${{ steps.get-job-id.outputs.job-id }}
+          job_name: ${{ steps.get-job-id.outputs.job-name }}
+          workflow_name: ${{ github.workflow }}
+          workflow_run_id: ${{github.run_id}}
+          workflow_attempt: ${{github.run_attempt}}
+          artifact_prefix: usage_log_build_${{ steps.get-job-id.outputs.job-id }}
+
       - name: Teardown Linux
         uses: pytorch/test-infra/.github/actions/teardown-linux@main
         if: always() && inputs.build-environment != 'linux-s390x-binary-manywheel'
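For orientation, a minimal sketch of what a sampler like the `tools.stats.monitor` invocation above does. This is a hypothetical stand-in, not the real module: it assumes only `psutil` (which the step pins) and a loose JSON-lines output; the actual monitor collects more (e.g. per-process and GPU stats via the `dataclasses_json` schemas it installs).

# monitor_sketch.py -- hypothetical, minimal stand-in for tools.stats.monitor.
import json
import time

import psutil  # the workflow pins psutil==5.9.1


def monitor(log_interval: float = 5.0, data_collect_interval: float = 1.0) -> None:
    samples = []
    last_log = time.time()
    while True:
        # Collect one sample every data-collect interval.
        samples.append({
            "cpu_percent": psutil.cpu_percent(interval=None),
            "memory_percent": psutil.virtual_memory().percent,
        })
        if time.time() - last_log >= log_interval:
            # Flush one aggregated JSON line per log interval.
            print(json.dumps({"time": time.time(), "samples": samples}), flush=True)
            samples, last_log = [], time.time()
        time.sleep(data_collect_interval)


if __name__ == "__main__":
    monitor()

The surrounding workflow steps are the interesting part of the pattern: the sampler is backgrounded with `&`, its PID is published through `$GITHUB_OUTPUT`, and a later `always()` step kills it so the usage log is complete before the copy/upload steps run.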
1 .gitignore (vendored)
@@ -47,6 +47,7 @@ docs/source/generated/
 docs/source/compile/generated/
 log
 usage_log.txt
+usage_log*
 test-reports/
 test/*.bak
 test/**/*.bak
@@ -1521,7 +1521,7 @@ code = 'RUFF'
 include_patterns = [
     '**/*.py',
     '**/*.pyi',
-    'torch/utils/data/*.ipynb',
+    '**/*.ipynb',
     'pyproject.toml',
 ]
 exclude_patterns = [
@@ -2,7 +2,9 @@

 ## Demo applications and tutorials

-Demo applications with code walk-through can be find in [this github repo](https://github.com/pytorch/android-demo-app).
+Please refer to [pytorch-labs/executorch-examples](https://github.com/pytorch-labs/executorch-examples/tree/main/dl3/android/DeepLabV3Demo) for the Android demo app based on [ExecuTorch](https://github.com/pytorch/executorch).

+Please join our [Discord](https://discord.com/channels/1334270993966825602/1349854760299270284) for any questions.
+
 ## Publishing

@@ -119,8 +121,6 @@ We also have to add all transitive dependencies of our aars.
 As `pytorch_android` [depends](https://github.com/pytorch/pytorch/blob/master/android/pytorch_android/build.gradle#L76-L77) on `'com.facebook.soloader:nativeloader:0.10.5'` and `'com.facebook.fbjni:fbjni-java-only:0.2.2'`, we need to add them.
 (In case of using maven dependencies they are added automatically from `pom.xml`).

-You can check out [test app example](https://github.com/pytorch/pytorch/blob/master/android/test_app/app/build.gradle) that uses aars directly.
-
 ## Linking to prebuilt libtorch library from gradle dependency

 In some cases, you may want to use libtorch from your android native build.
@@ -202,7 +202,7 @@ find_library(FBJNI_LIBRARY fbjni
   NO_CMAKE_FIND_ROOT_PATH)

 target_link_libraries(${PROJECT_NAME}
-  ${PYTORCH_LIBRARY})
+  ${PYTORCH_LIBRARY}
+  ${FBJNI_LIBRARY})

 ```
@@ -233,8 +233,6 @@ void loadAndForwardModel(const std::string& modelPath) {

 To load torchscript model for mobile we need some special setup which is placed in `struct JITCallGuard` in this example. It may change in future, you can track the latest changes keeping an eye in our [pytorch android jni code](https://github.com/pytorch/pytorch/blob/master/android/pytorch_android/src/main/cpp/pytorch_jni_jit.cpp#L28)

-[Example of linking to libtorch from aar](https://github.com/pytorch/pytorch/tree/master/android/test_app)
-
 ## PyTorch Android API Javadoc

 You can find more details about the PyTorch Android API in the [Javadoc](https://pytorch.org/javadoc/).
@@ -1,30 +0,0 @@
#!/bin/bash
set -eux

PYTORCH_DIR="$(cd $(dirname $0)/..; pwd -P)"
PYTORCH_ANDROID_DIR=$PYTORCH_DIR/android

echo "PYTORCH_DIR:$PYTORCH_DIR"

source "$PYTORCH_ANDROID_DIR/common.sh"

check_android_sdk
check_gradle
parse_abis_list "$@"
build_android

# To set proxy for gradle add following lines to ./gradle/gradle.properties:
# systemProp.http.proxyHost=...
# systemProp.http.proxyPort=8080
# systemProp.https.proxyHost=...
# systemProp.https.proxyPort=8080

if [ "$CUSTOM_ABIS_LIST" = true ]; then
  NDK_DEBUG=1 $GRADLE_PATH -PnativeLibsDoNotStrip=true -PABI_FILTERS=$ABIS_LIST -p $PYTORCH_ANDROID_DIR clean test_app:assembleDebug
else
  NDK_DEBUG=1 $GRADLE_PATH -PnativeLibsDoNotStrip=true -p $PYTORCH_ANDROID_DIR clean test_app:assembleDebug
fi

find $PYTORCH_ANDROID_DIR -type f -name *apk

find $PYTORCH_ANDROID_DIR -type f -name *apk | xargs echo "To install apk run: $ANDROID_HOME/platform-tools/adb install -r "
@@ -1,32 +0,0 @@
#!/bin/bash
###############################################################################
# This script tests the custom selective build flow for PyTorch Android, which
# optimizes library size by only including ops used by a specific model.
###############################################################################

set -eux

PYTORCH_DIR="$(cd $(dirname $0)/..; pwd -P)"
PYTORCH_ANDROID_DIR="${PYTORCH_DIR}/android"
BUILD_ROOT="${PYTORCH_DIR}/build_pytorch_android_custom"

source "${PYTORCH_ANDROID_DIR}/common.sh"

prepare_model_and_dump_root_ops() {
  cd "${BUILD_ROOT}"
  MODEL="${BUILD_ROOT}/MobileNetV2.pt"
  ROOT_OPS="${BUILD_ROOT}/MobileNetV2.yaml"
  python "${PYTORCH_ANDROID_DIR}/test_app/make_assets_custom.py"
  cp "${MODEL}" "${PYTORCH_ANDROID_DIR}/test_app/app/src/main/assets/mobilenet2.pt"
}

# Start building
mkdir -p "${BUILD_ROOT}"
check_android_sdk
check_gradle
parse_abis_list "$@"
prepare_model_and_dump_root_ops
SELECTED_OP_LIST="${ROOT_OPS}" build_android

# TODO: change this to build test_app instead
$GRADLE_PATH -PABI_FILTERS=$ABIS_LIST -p $PYTORCH_ANDROID_DIR clean assembleRelease
@@ -3,4 +3,3 @@ include ':app', ':pytorch_android', ':pytorch_android_torchvision', ':pytorch_ho
 project(':pytorch_android_torchvision').projectDir = file('pytorch_android_torchvision')

 project(':pytorch_host').projectDir = file('pytorch_android/host')
-project(':test_app').projectDir = file('test_app/app')
9 android/test_app/.gitignore (vendored)
@@ -1,9 +0,0 @@
local.properties
**/*.iml
.gradle
gradlew*
gradle/wrapper
.idea/*
.DS_Store
build
.externalNativeBuild
@@ -1,38 +0,0 @@
cmake_minimum_required(VERSION 3.5)
set(PROJECT_NAME pytorch_testapp_jni)
project(${PROJECT_NAME} CXX)
set(CMAKE_CXX_STANDARD 17 CACHE STRING "The C++ standard whose features are requested to build this target.")
set(CMAKE_VERBOSE_MAKEFILE ON)

set(build_DIR ${CMAKE_SOURCE_DIR}/build)

set(pytorch_testapp_cpp_DIR ${CMAKE_CURRENT_LIST_DIR}/src/main/cpp)
message(STATUS "ANDROID_STL:${ANDROID_STL}")
file(GLOB pytorch_testapp_SOURCES
  ${pytorch_testapp_cpp_DIR}/pytorch_testapp_jni.cpp
)

add_library(${PROJECT_NAME} SHARED
  ${pytorch_testapp_SOURCES}
)

file(GLOB PYTORCH_INCLUDE_DIRS "${build_DIR}/pytorch_android*.aar/headers")
file(GLOB PYTORCH_LINK_DIRS "${build_DIR}/pytorch_android*.aar/jni/${ANDROID_ABI}")

target_compile_options(${PROJECT_NAME} PRIVATE
  -fexceptions
)

set(BUILD_SUBDIR ${ANDROID_ABI})

target_include_directories(${PROJECT_NAME} PRIVATE
  ${PYTORCH_INCLUDE_DIRS}
)

find_library(PYTORCH_LIBRARY pytorch_jni
  PATHS ${PYTORCH_LINK_DIRS}
  NO_CMAKE_FIND_ROOT_PATH)

target_link_libraries(${PROJECT_NAME}
  ${PYTORCH_LIBRARY}
  log)
@@ -1,190 +0,0 @@
apply plugin: 'com.android.application'

repositories {
  jcenter()
  maven {
    url "https://oss.sonatype.org/content/repositories/snapshots"
  }
  flatDir {
    dirs 'aars'
  }
}

android {
  configurations {
    extractForNativeBuild
  }
  compileOptions {
    sourceCompatibility 1.8
    targetCompatibility 1.8
  }
  compileSdkVersion rootProject.compileSdkVersion
  buildToolsVersion rootProject.buildToolsVersion
  defaultConfig {
    applicationId "org.pytorch.testapp"
    minSdkVersion rootProject.minSdkVersion
    targetSdkVersion rootProject.targetSdkVersion
    versionCode 1
    versionName "1.0"
    ndk {
      abiFilters ABI_FILTERS.split(",")
    }
    // Commented due to dependency on local copy of pytorch_android aar to aars folder
    //externalNativeBuild {
    //  cmake {
    //    abiFilters ABI_FILTERS.split(",")
    //    arguments "-DANDROID_STL=c++_shared"
    //  }
    //}
    buildConfigField("String", "MODULE_ASSET_NAME", "\"mobilenet2q.pt\"")
    buildConfigField("String", "LOGCAT_TAG", "@string/app_name")
    buildConfigField("long[]", "INPUT_TENSOR_SHAPE", "new long[]{1, 3, 224, 224}")
    buildConfigField("boolean", "NATIVE_BUILD", 'false')
    buildConfigField("boolean", "USE_VULKAN_DEVICE", 'false')
    buildConfigField(
      "int",
      "BUILD_LITE_INTERPRETER",
      System.env.BUILD_LITE_INTERPRETER != null ? System.env.BUILD_LITE_INTERPRETER : "1"
    )
    addManifestPlaceholders([APP_NAME: "@string/app_name", MAIN_ACTIVITY: "org.pytorch.testapp.MainActivity"])
  }
  buildTypes {
    debug {
      minifyEnabled false
      debuggable true
    }
    release {
      minifyEnabled false
    }
  }
  // Commented due to dependency on local copy of pytorch_android aar to aars folder
  //externalNativeBuild {
  //  cmake {
  //    path "CMakeLists.txt"
  //  }
  //}
  flavorDimensions "model", "build", "activity"
  productFlavors {
    mnet {
      dimension "model"
      applicationIdSuffix ".mnet"
      buildConfigField("String", "MODULE_ASSET_NAME", "\"mobilenet_v2.ptl\"")
      addManifestPlaceholders([APP_NAME: "MNET"])
      buildConfigField("String", "LOGCAT_TAG", "\"pytorch-mnet\"")
    }
    // NB: This is not working atm https://github.com/pytorch/pytorch/issues/102966
    mnetVulkan {
      dimension "model"
      applicationIdSuffix ".mnet_vulkan"
      buildConfigField("String", "MODULE_ASSET_NAME", "\"mobilenet_v2_vulkan.ptl\"")
      buildConfigField("boolean", "USE_VULKAN_DEVICE", 'true')
      addManifestPlaceholders([APP_NAME: "MNET_VULKAN"])
      buildConfigField("String", "LOGCAT_TAG", "\"pytorch-mnet-vulkan\"")
    }
    resnet18 {
      dimension "model"
      applicationIdSuffix ".resnet18"
      buildConfigField("String", "MODULE_ASSET_NAME", "\"resnet18.ptl\"")
      addManifestPlaceholders([APP_NAME: "RN18"])
      buildConfigField("String", "LOGCAT_TAG", "\"pytorch-resnet18\"")
    }
    local {
      dimension "build"
    }
    nightly {
      dimension "build"
    }
    aar {
      dimension "build"
    }
    // Commented due to dependency on local copy of pytorch_android aar to aars folder
    //nativeBuild {
    //  dimension "build"
    //  buildConfigField("boolean", "NATIVE_BUILD", "true")
    //}
    camera {
      dimension "activity"
      addManifestPlaceholders([MAIN_ACTIVITY: "org.pytorch.testapp.CameraActivity"])
    }
    base {
      dimension "activity"
      sourceSets {
        main {
          java {
            exclude 'org/pytorch/testapp/CameraActivity.java'
          }
        }
      }
    }
  }
  packagingOptions {
    doNotStrip '**.so'
  }

  // Filtering for CI
  if (!testAppAllVariantsEnabled.toBoolean()) {
    variantFilter { variant ->
      def names = variant.flavors*.name
      if (names.contains("nightly")
          || names.contains("camera")
          || names.contains("aar")
          || names.contains("nativeBuild")) {
        setIgnore(true)
      }
    }
  }
}

tasks.all { task ->
  // Disable externalNativeBuild for all but nativeBuild variant
  if (task.name.startsWith('externalNativeBuild')
      && !task.name.contains('NativeBuild')) {
    task.enabled = false
  }
}

dependencies {
  implementation 'com.android.support:appcompat-v7:28.0.0'
  implementation 'com.facebook.soloader:nativeloader:0.10.5'

  localImplementation project(':pytorch_android')
  localImplementation project(':pytorch_android_torchvision')

  // Commented due to dependency on local copy of pytorch_android aar to aars folder
  //nativeBuildImplementation(name: 'pytorch_android-release', ext: 'aar')
  //nativeBuildImplementation(name: 'pytorch_android_torchvision-release', ext: 'aar')
  //extractForNativeBuild(name: 'pytorch_android-release', ext: 'aar')

  nightlyImplementation 'org.pytorch:pytorch_android:2.2.0-SNAPSHOT'
  nightlyImplementation 'org.pytorch:pytorch_android_torchvision:2.2.0-SNAPSHOT'

  aarImplementation(name:'pytorch_android', ext:'aar')
  aarImplementation(name:'pytorch_android_torchvision', ext:'aar')
  aarImplementation 'com.facebook.soloader:nativeloader:0.10.5'
  aarImplementation 'com.facebook.fbjni:fbjni-java-only:0.2.2'

  def camerax_version = "1.0.0-alpha05"
  cameraImplementation "androidx.camera:camera-core:$camerax_version"
  cameraImplementation "androidx.camera:camera-camera2:$camerax_version"
  cameraImplementation 'com.google.android.material:material:1.0.0-beta01'
}

task extractAARForNativeBuild {
  doLast {
    configurations.extractForNativeBuild.files.each {
      def file = it.absoluteFile
      copy {
        from zipTree(file)
        into "$buildDir/$file.name"
        include "headers/**"
        include "jni/**"
      }
    }
  }
}

tasks.whenTaskAdded { task ->
  if (task.name.contains('externalNativeBuild')) {
    task.dependsOn(extractAARForNativeBuild)
  }
}
@@ -1,27 +0,0 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
  package="org.pytorch.testapp">

  <application
    android:allowBackup="true"
    android:label="${APP_NAME}"
    android:supportsRtl="true"
    android:theme="@style/AppTheme">

    <activity android:name="${MAIN_ACTIVITY}">
      <intent-filter>
        <action android:name="android.intent.action.MAIN" />

        <category android:name="android.intent.category.LAUNCHER" />
      </intent-filter>
    </activity>
  </application>

  <uses-permission android:name="android.permission.CAMERA" />

  <!--
  Permissions required by the Snapdragon Profiler to collect GPU metrics.
  -->
  <uses-permission android:name="android.permission.INTERNET" />
  <uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE" />
</manifest>
@@ -1,3 +0,0 @@
*
*/
!.gitignore
@@ -1,77 +0,0 @@
#include <android/log.h>
#include <pthread.h>
#include <unistd.h>
#include <cassert>
#include <cmath>
#include <vector>
#define ALOGI(...) \
  __android_log_print(ANDROID_LOG_INFO, "PyTorchTestAppJni", __VA_ARGS__)
#define ALOGE(...) \
  __android_log_print(ANDROID_LOG_ERROR, "PyTorchTestAppJni", __VA_ARGS__)

#include "jni.h"

#include <torch/script.h>

namespace pytorch_testapp_jni {
namespace {

template <typename T>
void log(const char* m, T t) {
  std::ostringstream os;
  os << t << std::endl;
  ALOGI("%s %s", m, os.str().c_str());
}

struct JITCallGuard {
  c10::InferenceMode guard;
  torch::jit::GraphOptimizerEnabledGuard no_optimizer_guard{false};
};
} // namespace

static void loadAndForwardModel(JNIEnv* env, jclass, jstring jModelPath) {
  const char* modelPath = env->GetStringUTFChars(jModelPath, 0);
  assert(modelPath);

  // To load torchscript model for mobile we need set these guards,
  // because mobile build doesn't support features like autograd for smaller
  // build size which is placed in `struct JITCallGuard` in this example. It may
  // change in future, you can track the latest changes keeping an eye in
  // android/pytorch_android/src/main/cpp/pytorch_jni_jit.cpp
  JITCallGuard guard;
  torch::jit::Module module = torch::jit::load(modelPath);
  module.eval();
  torch::Tensor t = torch::randn({1, 3, 224, 224});
  log("input tensor:", t);
  c10::IValue t_out = module.forward({t});
  log("output tensor:", t_out);
  env->ReleaseStringUTFChars(jModelPath, modelPath);
}
} // namespace pytorch_testapp_jni

JNIEXPORT jint JNI_OnLoad(JavaVM* vm, void*) {
  JNIEnv* env;
  if (vm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION_1_6) != JNI_OK) {
    return JNI_ERR;
  }

  jclass c =
      env->FindClass("org/pytorch/testapp/LibtorchNativeClient$NativePeer");
  if (c == nullptr) {
    return JNI_ERR;
  }

  static const JNINativeMethod methods[] = {
      {"loadAndForwardModel",
       "(Ljava/lang/String;)V",
       (void*)pytorch_testapp_jni::loadAndForwardModel},
  };
  int rc = env->RegisterNatives(
      c, methods, sizeof(methods) / sizeof(JNINativeMethod));

  if (rc != JNI_OK) {
    return rc;
  }

  return JNI_VERSION_1_6;
}
@@ -1,214 +0,0 @@
package org.pytorch.testapp;

import android.Manifest;
import android.content.pm.PackageManager;
import android.os.Bundle;
import android.os.Handler;
import android.os.HandlerThread;
import android.os.SystemClock;
import android.util.Log;
import android.util.Size;
import android.view.TextureView;
import android.view.ViewStub;
import android.widget.TextView;
import android.widget.Toast;
import androidx.annotation.Nullable;
import androidx.annotation.UiThread;
import androidx.annotation.WorkerThread;
import androidx.appcompat.app.AppCompatActivity;
import androidx.camera.core.CameraX;
import androidx.camera.core.ImageAnalysis;
import androidx.camera.core.ImageAnalysisConfig;
import androidx.camera.core.ImageProxy;
import androidx.camera.core.Preview;
import androidx.camera.core.PreviewConfig;
import androidx.core.app.ActivityCompat;
import java.nio.FloatBuffer;
import org.pytorch.IValue;
import org.pytorch.MemoryFormat;
import org.pytorch.Module;
import org.pytorch.PyTorchAndroid;
import org.pytorch.Tensor;
import org.pytorch.torchvision.TensorImageUtils;

public class CameraActivity extends AppCompatActivity {
  private static final String TAG = BuildConfig.LOGCAT_TAG;
  private static final int TEXT_TRIM_SIZE = 4096;

  private static final int REQUEST_CODE_CAMERA_PERMISSION = 200;
  private static final String[] PERMISSIONS = {Manifest.permission.CAMERA};

  private long mLastAnalysisResultTime;

  protected HandlerThread mBackgroundThread;
  protected Handler mBackgroundHandler;
  protected Handler mUIHandler;

  private TextView mTextView;
  private StringBuilder mTextViewStringBuilder = new StringBuilder();

  @Override
  protected void onCreate(Bundle savedInstanceState) {
    super.onCreate(savedInstanceState);
    setContentView(R.layout.activity_camera);
    mTextView = findViewById(R.id.text);
    mUIHandler = new Handler(getMainLooper());
    startBackgroundThread();

    if (ActivityCompat.checkSelfPermission(this, Manifest.permission.CAMERA)
        != PackageManager.PERMISSION_GRANTED) {
      ActivityCompat.requestPermissions(this, PERMISSIONS, REQUEST_CODE_CAMERA_PERMISSION);
    } else {
      setupCameraX();
    }
  }

  @Override
  protected void onPostCreate(@Nullable Bundle savedInstanceState) {
    super.onPostCreate(savedInstanceState);
    startBackgroundThread();
  }

  protected void startBackgroundThread() {
    mBackgroundThread = new HandlerThread("ModuleActivity");
    mBackgroundThread.start();
    mBackgroundHandler = new Handler(mBackgroundThread.getLooper());
  }

  @Override
  protected void onDestroy() {
    stopBackgroundThread();
    super.onDestroy();
  }

  protected void stopBackgroundThread() {
    mBackgroundThread.quitSafely();
    try {
      mBackgroundThread.join();
      mBackgroundThread = null;
      mBackgroundHandler = null;
    } catch (InterruptedException e) {
      Log.e(TAG, "Error on stopping background thread", e);
    }
  }

  @Override
  public void onRequestPermissionsResult(
      int requestCode, String[] permissions, int[] grantResults) {
    if (requestCode == REQUEST_CODE_CAMERA_PERMISSION) {
      if (grantResults[0] == PackageManager.PERMISSION_DENIED) {
        Toast.makeText(
                this,
                "You can't use image classification example without granting CAMERA permission",
                Toast.LENGTH_LONG)
            .show();
        finish();
      } else {
        setupCameraX();
      }
    }
  }

  private static final int TENSOR_WIDTH = 224;
  private static final int TENSOR_HEIGHT = 224;

  private void setupCameraX() {
    final TextureView textureView =
        ((ViewStub) findViewById(R.id.camera_texture_view_stub))
            .inflate()
            .findViewById(R.id.texture_view);
    final PreviewConfig previewConfig = new PreviewConfig.Builder().build();
    final Preview preview = new Preview(previewConfig);
    preview.setOnPreviewOutputUpdateListener(
        new Preview.OnPreviewOutputUpdateListener() {
          @Override
          public void onUpdated(Preview.PreviewOutput output) {
            textureView.setSurfaceTexture(output.getSurfaceTexture());
          }
        });

    final ImageAnalysisConfig imageAnalysisConfig =
        new ImageAnalysisConfig.Builder()
            .setTargetResolution(new Size(TENSOR_WIDTH, TENSOR_HEIGHT))
            .setCallbackHandler(mBackgroundHandler)
            .setImageReaderMode(ImageAnalysis.ImageReaderMode.ACQUIRE_LATEST_IMAGE)
            .build();
    final ImageAnalysis imageAnalysis = new ImageAnalysis(imageAnalysisConfig);
    imageAnalysis.setAnalyzer(
        new ImageAnalysis.Analyzer() {
          @Override
          public void analyze(ImageProxy image, int rotationDegrees) {
            if (SystemClock.elapsedRealtime() - mLastAnalysisResultTime < 500) {
              return;
            }

            final Result result = CameraActivity.this.analyzeImage(image, rotationDegrees);

            if (result != null) {
              mLastAnalysisResultTime = SystemClock.elapsedRealtime();
              CameraActivity.this.runOnUiThread(
                  new Runnable() {
                    @Override
                    public void run() {
                      CameraActivity.this.handleResult(result);
                    }
                  });
            }
          }
        });

    CameraX.bindToLifecycle(this, preview, imageAnalysis);
  }

  private Module mModule;
  private FloatBuffer mInputTensorBuffer;
  private Tensor mInputTensor;

  @WorkerThread
  @Nullable
  protected Result analyzeImage(ImageProxy image, int rotationDegrees) {
    Log.i(TAG, String.format("analyzeImage(%s, %d)", image, rotationDegrees));
    if (mModule == null) {
      Log.i(TAG, "Loading module from asset '" + BuildConfig.MODULE_ASSET_NAME + "'");
      mModule = PyTorchAndroid.loadModuleFromAsset(getAssets(), BuildConfig.MODULE_ASSET_NAME);
      mInputTensorBuffer = Tensor.allocateFloatBuffer(3 * TENSOR_WIDTH * TENSOR_HEIGHT);
      mInputTensor =
          Tensor.fromBlob(mInputTensorBuffer, new long[] {1, 3, TENSOR_WIDTH, TENSOR_HEIGHT});
    }

    final long startTime = SystemClock.elapsedRealtime();
    TensorImageUtils.imageYUV420CenterCropToFloatBuffer(
        image.getImage(),
        rotationDegrees,
        TENSOR_WIDTH,
        TENSOR_HEIGHT,
        TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
        TensorImageUtils.TORCHVISION_NORM_STD_RGB,
        mInputTensorBuffer,
        0,
        MemoryFormat.CHANNELS_LAST);
    final long moduleForwardStartTime = SystemClock.elapsedRealtime();
    final Tensor outputTensor = mModule.forward(IValue.from(mInputTensor)).toTensor();
    final long moduleForwardDuration = SystemClock.elapsedRealtime() - moduleForwardStartTime;

    final float[] scores = outputTensor.getDataAsFloatArray();
    final long analysisDuration = SystemClock.elapsedRealtime() - startTime;

    return new Result(scores, moduleForwardDuration, analysisDuration);
  }

  @UiThread
  protected void handleResult(Result result) {
    int ixs[] = Utils.topK(result.scores, 1);
    String message =
        String.format(
            "forwardDuration:%d class:%s",
            result.moduleForwardDuration, Constants.IMAGENET_CLASSES[ixs[0]]);
    Log.i(TAG, message);
    mTextViewStringBuilder.insert(0, '\n').insert(0, message);
    if (mTextViewStringBuilder.length() > TEXT_TRIM_SIZE) {
      mTextViewStringBuilder.delete(TEXT_TRIM_SIZE, mTextViewStringBuilder.length());
    }
    mTextView.setText(mTextViewStringBuilder.toString());
  }
}
File diff suppressed because it is too large
@@ -1,22 +0,0 @@
package org.pytorch.testapp;

import com.facebook.soloader.nativeloader.NativeLoader;
import com.facebook.soloader.nativeloader.SystemDelegate;

public final class LibtorchNativeClient {

  public static void loadAndForwardModel(final String modelPath) {
    NativePeer.loadAndForwardModel(modelPath);
  }

  private static class NativePeer {
    static {
      if (!NativeLoader.isInitialized()) {
        NativeLoader.init(new SystemDelegate());
      }
      NativeLoader.loadLibrary("pytorch_testapp_jni");
    }

    private static native void loadAndForwardModel(final String modelPath);
  }
}
@@ -1,171 +0,0 @@
package org.pytorch.testapp;

import android.content.Context;
import android.os.Bundle;
import android.os.Handler;
import android.os.HandlerThread;
import android.os.SystemClock;
import android.util.Log;
import android.widget.TextView;
import androidx.annotation.Nullable;
import androidx.annotation.UiThread;
import androidx.annotation.WorkerThread;
import androidx.appcompat.app.AppCompatActivity;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.FloatBuffer;
import org.pytorch.Device;
import org.pytorch.IValue;
import org.pytorch.MemoryFormat;
import org.pytorch.Module;
import org.pytorch.PyTorchAndroid;
import org.pytorch.Tensor;

public class MainActivity extends AppCompatActivity {

  private static final String TAG = BuildConfig.LOGCAT_TAG;
  private static final int TEXT_TRIM_SIZE = 4096;

  private TextView mTextView;

  protected HandlerThread mBackgroundThread;
  protected Handler mBackgroundHandler;
  private Module mModule;
  private FloatBuffer mInputTensorBuffer;
  private Tensor mInputTensor;
  private StringBuilder mTextViewStringBuilder = new StringBuilder();

  private final Runnable mModuleForwardRunnable =
      new Runnable() {
        @Override
        public void run() {
          final Result result = doModuleForward();
          runOnUiThread(
              new Runnable() {
                @Override
                public void run() {
                  handleResult(result);
                  if (mBackgroundHandler != null) {
                    mBackgroundHandler.post(mModuleForwardRunnable);
                  }
                }
              });
        }
      };

  public static String assetFilePath(Context context, String assetName) {
    File file = new File(context.getFilesDir(), assetName);
    if (file.exists() && file.length() > 0) {
      return file.getAbsolutePath();
    }

    try (InputStream is = context.getAssets().open(assetName)) {
      try (OutputStream os = new FileOutputStream(file)) {
        byte[] buffer = new byte[4 * 1024];
        int read;
        while ((read = is.read(buffer)) != -1) {
          os.write(buffer, 0, read);
        }
        os.flush();
      }
      return file.getAbsolutePath();
    } catch (IOException e) {
      Log.e(TAG, "Error process asset " + assetName + " to file path");
    }
    return null;
  }

  @Override
  protected void onCreate(Bundle savedInstanceState) {
    super.onCreate(savedInstanceState);
    if (BuildConfig.NATIVE_BUILD) {
      final String modelFileAbsoluteFilePath =
          new File(assetFilePath(this, BuildConfig.MODULE_ASSET_NAME)).getAbsolutePath();
      LibtorchNativeClient.loadAndForwardModel(modelFileAbsoluteFilePath);
      return;
    }
    setContentView(R.layout.activity_main);
    mTextView = findViewById(R.id.text);
    startBackgroundThread();
    mBackgroundHandler.post(mModuleForwardRunnable);
  }

  protected void startBackgroundThread() {
    mBackgroundThread = new HandlerThread(TAG + "_bg");
    mBackgroundThread.start();
    mBackgroundHandler = new Handler(mBackgroundThread.getLooper());
  }

  @Override
  protected void onDestroy() {
    stopBackgroundThread();
    super.onDestroy();
  }

  protected void stopBackgroundThread() {
    mBackgroundThread.quitSafely();
    try {
      mBackgroundThread.join();
      mBackgroundThread = null;
      mBackgroundHandler = null;
    } catch (InterruptedException e) {
      Log.e(TAG, "Error stopping background thread", e);
    }
  }

  @WorkerThread
  @Nullable
  protected Result doModuleForward() {
    if (mModule == null) {
      final long[] shape = BuildConfig.INPUT_TENSOR_SHAPE;
      long numElements = 1;
      for (int i = 0; i < shape.length; i++) {
        numElements *= shape[i];
      }
      mInputTensorBuffer = Tensor.allocateFloatBuffer((int) numElements);
      mInputTensor =
          Tensor.fromBlob(
              mInputTensorBuffer, BuildConfig.INPUT_TENSOR_SHAPE, MemoryFormat.CHANNELS_LAST);
      PyTorchAndroid.setNumThreads(1);
      mModule =
          BuildConfig.USE_VULKAN_DEVICE
              ? PyTorchAndroid.loadModuleFromAsset(
                  getAssets(), BuildConfig.MODULE_ASSET_NAME, Device.VULKAN)
              : PyTorchAndroid.loadModuleFromAsset(getAssets(), BuildConfig.MODULE_ASSET_NAME);
    }

    final long startTime = SystemClock.elapsedRealtime();
    final long moduleForwardStartTime = SystemClock.elapsedRealtime();
    final Tensor outputTensor = mModule.forward(IValue.from(mInputTensor)).toTensor();
    final long moduleForwardDuration = SystemClock.elapsedRealtime() - moduleForwardStartTime;
    final float[] scores = outputTensor.getDataAsFloatArray();
    final long analysisDuration = SystemClock.elapsedRealtime() - startTime;
    return new Result(scores, moduleForwardDuration, analysisDuration);
  }

  static class Result {

    private final float[] scores;
    private final long totalDuration;
    private final long moduleForwardDuration;

    public Result(float[] scores, long moduleForwardDuration, long totalDuration) {
      this.scores = scores;
      this.moduleForwardDuration = moduleForwardDuration;
      this.totalDuration = totalDuration;
    }
  }

  @UiThread
  protected void handleResult(Result result) {
    String message = String.format("forwardDuration:%d", result.moduleForwardDuration);
    mTextViewStringBuilder.insert(0, '\n').insert(0, message);
    if (mTextViewStringBuilder.length() > TEXT_TRIM_SIZE) {
      mTextViewStringBuilder.delete(TEXT_TRIM_SIZE, mTextViewStringBuilder.length());
    }
    mTextView.setText(mTextViewStringBuilder.toString());
  }
}
@@ -1,14 +0,0 @@
package org.pytorch.testapp;

class Result {

  public final float[] scores;
  public final long totalDuration;
  public final long moduleForwardDuration;

  public Result(float[] scores, long moduleForwardDuration, long totalDuration) {
    this.scores = scores;
    this.moduleForwardDuration = moduleForwardDuration;
    this.totalDuration = totalDuration;
  }
}
@@ -1,28 +0,0 @@
package org.pytorch.testapp;

import java.util.Arrays;

public class Utils {

  public static int[] topK(float[] a, final int topk) {
    float values[] = new float[topk];
    Arrays.fill(values, -Float.MAX_VALUE);
    int ixs[] = new int[topk];
    Arrays.fill(ixs, -1);

    for (int i = 0; i < a.length; i++) {
      for (int j = 0; j < topk; j++) {
        if (a[i] > values[j]) {
          for (int k = topk - 1; k >= j + 1; k--) {
            values[k] = values[k - 1];
            ixs[k] = ixs[k - 1];
          }
          values[j] = a[i];
          ixs[j] = i;
          break;
        }
      }
    }
    return ixs;
  }
}
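The removed Utils.topK selects the indices of the k highest scores by insertion into a sorted buffer. An illustrative Python equivalent (same result up to tie-breaking), not part of the original change:

import heapq


def top_k_indices(scores, k):
    # Indices of the k largest scores, highest first.
    return heapq.nlargest(k, range(len(scores)), key=lambda i: scores[i])


assert top_k_indices([0.1, 0.9, 0.4], 2) == [1, 2]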
@@ -1,23 +0,0 @@
<?xml version="1.0" encoding="utf-8"?>
<FrameLayout
  xmlns:android="http://schemas.android.com/apk/res/android"
  xmlns:tools="http://schemas.android.com/tools"
  android:layout_width="match_parent"
  android:layout_height="match_parent"
  tools:context=".CameraActivity">

  <ViewStub
    android:id="@+id/camera_texture_view_stub"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    android:layout="@layout/texture_view"/>

  <TextView
    android:id="@+id/text"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    android:layout_gravity="top"
    android:textSize="16sp"
    android:textStyle="bold"
    android:textColor="#ff0000"/>
</FrameLayout>
@@ -1,17 +0,0 @@
<?xml version="1.0" encoding="utf-8"?>
<FrameLayout xmlns:android="http://schemas.android.com/apk/res/android"
  xmlns:tools="http://schemas.android.com/tools"
  android:layout_width="match_parent"
  android:layout_height="match_parent"
  tools:context=".MainActivity">

  <TextView
    android:id="@+id/text"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    android:layout_gravity="top"
    android:textSize="14sp"
    android:background="@android:color/black"
    android:textColor="@android:color/white" />

</FrameLayout>
@@ -1,5 +0,0 @@
<?xml version="1.0" encoding="utf-8"?>
<TextureView xmlns:android="http://schemas.android.com/apk/res/android"
  android:id="@+id/texture_view"
  android:layout_width="match_parent"
  android:layout_height="0dp" />
Binary file not shown. (Before: 2.0 KiB)
Binary file not shown. (Before: 2.7 KiB)
@@ -1,6 +0,0 @@
<?xml version="1.0" encoding="utf-8"?>
<resources>
  <color name="colorPrimary">#008577</color>
  <color name="colorPrimaryDark">#00574B</color>
  <color name="colorAccent">#D81B60</color>
</resources>
@@ -1,3 +0,0 @@
<resources>
  <string name="app_name">PyTest</string>
</resources>
@@ -1,11 +0,0 @@
<resources>

  <!-- Base application theme. -->
  <style name="AppTheme" parent="Theme.AppCompat.Light.DarkActionBar">
    <!-- Customize your theme here. -->
    <item name="colorPrimary">@color/colorPrimary</item>
    <item name="colorPrimaryDark">@color/colorPrimaryDark</item>
    <item name="colorAccent">@color/colorAccent</item>
  </style>

</resources>
@@ -1,24 +0,0 @@
from torchvision import models

import torch


print(torch.version.__version__)

resnet18 = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
resnet18.eval()
resnet18_traced = torch.jit.trace(resnet18, torch.rand(1, 3, 224, 224)).save(
    "app/src/main/assets/resnet18.pt"
)

resnet50 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
resnet50.eval()
torch.jit.trace(resnet50, torch.rand(1, 3, 224, 224)).save(
    "app/src/main/assets/resnet50.pt"
)

mobilenet2q = models.quantization.mobilenet_v2(pretrained=True, quantize=True)
mobilenet2q.eval()
torch.jit.trace(mobilenet2q, torch.rand(1, 3, 224, 224)).save(
    "app/src/main/assets/mobilenet2q.pt"
)
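A hypothetical sanity check for assets produced by the removed script — reload one traced file and run a dummy forward pass (the path is the one the script writes; this snippet is not from the original change):

import torch

# Load the traced classifier back and run a dummy forward pass.
model = torch.jit.load("app/src/main/assets/resnet18.pt")
model.eval()
with torch.no_grad():
    out = model(torch.rand(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000]) for an ImageNet classifier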
@@ -1,27 +0,0 @@
"""
This is a script for PyTorch Android custom selective build test. It prepares
MobileNetV2 TorchScript model, and dumps root ops used by the model for custom
build script to create a tailored build which only contains these used ops.
"""

import yaml
from torchvision import models

import torch


# Download and trace the model.
model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)
model.eval()
example = torch.rand(1, 3, 224, 224)
# TODO: create script model with `torch.jit.script`
traced_script_module = torch.jit.trace(model, example)

# Save traced TorchScript model.
traced_script_module.save("MobileNetV2.pt")

# Dump root ops used by the model (for custom build optimization).
ops = torch.jit.export_opnames(traced_script_module)

with open("MobileNetV2.yaml", "w") as output:
    yaml.dump(ops, output)
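Illustrative read-back of the ops list the removed script dumps, which the selective build consumes via SELECTED_OP_LIST; the example op name is an assumption, not verified output:

import yaml

with open("MobileNetV2.yaml") as f:
    ops = yaml.safe_load(f)
# `ops` is a plain list of root operator names, e.g. "aten::conv2d" (hypothetical).
print(len(ops), ops[:3])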
@ -471,7 +471,7 @@ struct CachingHostAllocatorImpl {
|
||||
virtual B* get_free_block(size_t size) {
|
||||
auto index = size_index(size);
|
||||
std::lock_guard<std::mutex> g(free_list_[index].mutex_);
|
||||
if (free_list_[index].list_.size() > 0) {
|
||||
if (!free_list_[index].list_.empty()) {
|
||||
B* block = free_list_[index].list_.back();
|
||||
free_list_[index].list_.pop_back();
|
||||
block->allocated_ = true;
|
||||
|
||||
@ -639,7 +639,7 @@ IntArrayRef MPSHeapAllocatorImpl::getBufferShape(const void* ptr) {
|
||||
std::lock_guard<std::recursive_mutex> lock(m_mutex);
|
||||
|
||||
BufferBlock* buffer_block = get_allocated_buffer_block(ptr);
|
||||
if (buffer_block && buffer_block->shape.size() > 0) {
|
||||
if (buffer_block && !buffer_block->shape.empty()) {
|
||||
return IntArrayRef{buffer_block->shape};
|
||||
}
|
||||
return IntArrayRef();
|
||||
|
||||
@ -1456,7 +1456,6 @@ static inline at::MemoryFormat determine_backend_memory_format(
|
||||
}
|
||||
break;
|
||||
case ConvBackend::Mps:
|
||||
case ConvBackend::MpsTranspose:
|
||||
if (mps_conv_use_channels_last(input, weight)) {
|
||||
#ifdef USE_MPS
|
||||
if (!mps::is_macos_13_or_newer(mps::MacOSVersion::MACOS_VER_15_0_PLUS)) {
|
||||
|
||||
@ -880,7 +880,7 @@ struct FullBidirectionalLayer
|
||||
step_inputs = input_w.unbind(0);
|
||||
auto fw_result = layer_(
|
||||
step_inputs, input_hidden.first, params.first, true);
|
||||
TORCH_CHECK(fw_result.outputs.size() > 0, "Expected sequence length to be larger than 0 in RNN");
|
||||
TORCH_CHECK(!fw_result.outputs.empty(), "Expected sequence length to be larger than 0 in RNN");
|
||||
auto fw_output = at::stack(fw_result.outputs, 0);
|
||||
input_w = params.second.linear_ih(input);
|
||||
step_inputs = input_w.unbind(0);
|
||||
@ -895,7 +895,7 @@ struct FullBidirectionalLayer
|
||||
|
||||
step_inputs = input.unbind(0);
|
||||
auto fw_result = layer_(step_inputs, input_hidden.first, params.first);
|
||||
TORCH_CHECK(fw_result.outputs.size() > 0, "Expected sequence length to be larger than 0 in RNN");
|
||||
TORCH_CHECK(!fw_result.outputs.empty(), "Expected sequence length to be larger than 0 in RNN");
|
||||
auto fw_output = at::stack(fw_result.outputs, 0);
|
||||
auto rev_step_inputs = reverse(std::move(step_inputs));
|
||||
auto rev_result =
|
||||
|
||||
@ -485,13 +485,13 @@ void _assert_async_cpu(const Tensor& self) {
|
||||
void _assert_async_msg_cpu(const Tensor& self, std::string_view assert_msg) {
|
||||
TORCH_CHECK(
|
||||
native::is_nonzero(self),
|
||||
assert_msg != "" ? assert_msg : "Assertion is failed");
|
||||
!assert_msg.empty() ? assert_msg : "Assertion is failed");
|
||||
}
|
||||
|
||||
void _assert_scalar(const Scalar& scalar, std::string_view assert_msg) {
|
||||
TORCH_SYM_CHECK(
|
||||
scalar.toSymBool(),
|
||||
assert_msg != "" ? assert_msg : "Assertion is failed");
|
||||
!assert_msg.empty() ? assert_msg : "Assertion is failed");
|
||||
}
|
||||
|
||||
Tensor _functional_assert_scalar(
|
||||
|
||||
@ -5,6 +5,7 @@
|
||||
#ifdef _WIN32
|
||||
#define _USE_MATH_DEFINES
|
||||
#include <cmath>
|
||||
#include <math.h>
|
||||
#endif // _WIN32
|
||||
|
||||
#include <ATen/cpu/vec/vec.h>
|
||||
|
||||
@ -67,7 +67,7 @@ void stack_serial_kernel_impl(Tensor& result, TensorListType tensors, int64_t di
|
||||
// - tensors dtype is Double or Float
|
||||
template <typename TensorListType>
|
||||
bool can_use_native_serial_stack_impl(Tensor& result, TensorListType tensors, int64_t dim) {
|
||||
TORCH_CHECK(tensors.size() > 0, "expected a non-empty list of Tensors");
|
||||
TORCH_CHECK(!tensors.empty(), "expected a non-empty list of Tensors");
|
||||
const Tensor& first_tensor = tensors[0];
|
||||
// stack dimension should be in range [0,firstTensor.dim())
|
||||
// dim == firstTensor.dim() is a valid input, but it is handled by default code path
|
||||
|
||||
@ -467,9 +467,8 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
|
||||
alpha,
|
||||
(&result != &self) ? self.const_data_ptr<scalar_t>() : nullptr,
|
||||
activation_to_gemm_and_blas_arg(activation));
|
||||
}
|
||||
|
||||
okay = at::cuda::blas::gemm_and_bias<scalar_t>(
|
||||
} else {
|
||||
okay = at::cuda::blas::gemm_and_bias<scalar_t>(
|
||||
args.transa == 't',
|
||||
args.transb == 't',
|
||||
args.m,
|
||||
@ -486,7 +485,8 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
|
||||
args.result->data_ptr<scalar_t>(),
|
||||
args.result_ld,
|
||||
activation_to_gemm_and_blas_arg(activation)
|
||||
);
|
||||
);
|
||||
}
|
||||
});
|
||||
}
|
||||
if (!okay) {
|
||||
|
||||
@ -433,7 +433,7 @@ struct TensorDescriptorListParams {
|
||||
int64_t batch_sizes_sum; // == sum(batch_sizes)
|
||||
|
||||
bool is_input_packed() const {
|
||||
return batch_sizes.size() != 0;
|
||||
return !batch_sizes.empty();
|
||||
}
|
||||
|
||||
void set(
|
||||
@ -465,8 +465,7 @@ struct TensorDescriptorListParams {
|
||||
#ifndef USE_CUDNN_RNN_V8_API
|
||||
// TODO: check x for consistency with input_size?
|
||||
std::vector<TensorDescriptor> descriptors(Tensor x) const {
|
||||
auto is_input_packed = batch_sizes.size() != 0;
|
||||
if (is_input_packed) {
|
||||
if (is_input_packed()) {
|
||||
return rnn_descriptor_sequence(x, batch_sizes);
|
||||
} else {
|
||||
return rnn_descriptor(x[0], seq_length);
|
||||
@ -474,8 +473,7 @@ struct TensorDescriptorListParams {
|
||||
}
|
||||
#else
|
||||
auto descriptors(Tensor x) const {
|
||||
auto is_input_packed = batch_sizes.size() != 0;
|
||||
if (is_input_packed) {
|
||||
if (is_input_packed()) {
|
||||
return rnn_descriptor_sequence(
|
||||
x, mini_batch, batch_sizes, seq_length, x.size(-1));
|
||||
} else {
|
||||
@ -1253,7 +1251,7 @@ int64_t _cudnn_rnn_flatten_weight_prologue(
|
||||
// typeMetaToScalarType is a surprisingly nontrivial function. We should
|
||||
// avoid it if we can.
|
||||
TORCH_CHECK(
|
||||
weight_arr.size() > 0,
|
||||
!weight_arr.empty(),
|
||||
"copy_weights_to_flat_buf_views: cannot flatten empty weight list");
|
||||
|
||||
rnn.set(
|
||||
@ -1306,7 +1304,7 @@ copy_weights_to_flat_buf_views(
|
||||
bool set_orig_weights_to_flat_buf,
|
||||
bool allow_type_change /*=false*/,
|
||||
bool include_bias /*=true*/) {
|
||||
TORCH_CHECK(weight_arr.size() > 0, "empty weight list");
|
||||
TORCH_CHECK(!weight_arr.empty(), "empty weight list");
|
||||
auto handle = getCudnnHandle();
|
||||
RNNDescriptorParams rnn;
|
||||
RNNDescriptor rnn_desc;
|
||||
@ -1390,7 +1388,7 @@ Tensor _cudnn_rnn_flatten_weight(
|
||||
int64_t fn_num_layers,
|
||||
bool batch_first,
|
||||
bool fn_bidirectional) {
|
||||
TORCH_CHECK(weight_arr.size() > 0, "empty weight list");
|
||||
TORCH_CHECK(!weight_arr.empty(), "empty weight list");
|
||||
// returns flat weight_buf
|
||||
return std::get<0>(copy_weights_to_flat_buf_views(
|
||||
weight_arr,
|
||||
@ -1417,7 +1415,7 @@ Tensor _cudnn_rnn_flatten_weight_meta(
|
||||
int64_t num_layers,
|
||||
bool batch_first,
|
||||
bool bidirectional) {
|
||||
TORCH_CHECK(weight_arr.size() > 0, "empty weight list");
|
||||
TORCH_CHECK(!weight_arr.empty(), "empty weight list");
|
||||
auto handle = getCudnnHandle();
|
||||
RNNDescriptorParams rnn;
|
||||
RNNDescriptor rnn_desc;
|
||||
@ -1498,7 +1496,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _cudnn_rnn(
|
||||
datatype);
|
||||
#else
|
||||
auto input_size = input_r.size(-1);
|
||||
auto packed = fn_batch_sizes.size() != 0;
|
||||
auto packed = !fn_batch_sizes.empty();
|
||||
fn.rnn.set(
|
||||
fn_mode,
|
||||
input_size,
|
||||
@ -1520,7 +1518,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _cudnn_rnn(
|
||||
}
|
||||
|
||||
// TODO: can batch_first be a wrapper around this function?
|
||||
auto is_input_packed = fn.tensors.batch_sizes.size() != 0;
|
||||
auto is_input_packed = !fn.tensors.batch_sizes.empty();
|
||||
if (batch_first && !is_input_packed) {
|
||||
input = input.transpose(0, 1);
|
||||
}
|
||||
@ -1775,7 +1773,7 @@ std::tuple<Tensor, Tensor, Tensor> _cudnn_rnn_backward_input(
|
||||
datatype);
|
||||
#else
|
||||
auto cudnn_input_size = input_r.size(-1);
|
||||
auto packed = fn_batch_sizes.size() != 0;
|
||||
auto packed = !fn_batch_sizes.empty();
|
||||
fn.rnn.set(
|
||||
fn_mode,
|
||||
cudnn_input_size,
|
||||
@ -1797,7 +1795,7 @@ std::tuple<Tensor, Tensor, Tensor> _cudnn_rnn_backward_input(
|
||||
     TORCH_CHECK(!cx.defined(), "rnn: illegal defined cx for non-LSTM RNN");
   }

-  auto is_input_packed = fn_batch_sizes.size() != 0;
+  auto is_input_packed = !fn_batch_sizes.empty();
   if (batch_first && !is_input_packed) {
     input = input.transpose(0, 1);
     grad_output = grad_output.transpose(0, 1);
@@ -2004,7 +2002,7 @@ std::vector<Tensor> _cudnn_rnn_backward_weight(
       datatype);
 #else
   auto cudnn_input_size = input_r.size(-1);
-  auto packed = fn_batch_sizes.size() != 0;
+  auto packed = !fn_batch_sizes.empty();
   fn.rnn.set(
       fn_mode,
       cudnn_input_size,
@@ -2025,7 +2023,7 @@ std::vector<Tensor> _cudnn_rnn_backward_weight(
     TORCH_CHECK(!cx.defined(), "rnn: illegal defined cx for non-LSTM RNN");
   }

-  auto is_input_packed = fn_batch_sizes.size() != 0;
+  auto is_input_packed = !fn_batch_sizes.empty();
   if (batch_first && !is_input_packed) {
     input = input.transpose(0, 1);
     output = output.transpose(0, 1);

@@ -203,7 +203,7 @@ struct TensorDescriptorListParams {
   int64_t batch_sizes_sum;

   bool is_input_packed() const {
-    return batch_sizes.size() != 0;
+    return !batch_sizes.empty();
   }

   void set(IntArrayRef input_sizes, IntArrayRef batch_sizes_, bool batch_first) {
@@ -227,8 +227,7 @@ struct TensorDescriptorListParams {
   }

   std::vector<TensorDescriptor> descriptors(Tensor x) const {
-    auto is_input_packed = batch_sizes.size() != 0;
-    if (is_input_packed) {
+    if (is_input_packed()) {
       return rnn_descriptor_sequence(x, batch_sizes);
     } else {
       return rnn_descriptor(x[0], seq_length);
@@ -545,7 +544,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> miopen_rnn(
     TORCH_CHECK(!cx.defined(), "miopen_rnn: illegal defined cx for non-LSTM RNN.");
   }

-  auto is_input_packed = fn.tensors.batch_sizes.size() != 0;
+  auto is_input_packed = !fn.tensors.batch_sizes.empty();
   if (batch_first && !is_input_packed) {
     input = input.transpose(0, 1);
   }
@@ -656,7 +655,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> miopen_rnn_backward_input(
     TORCH_CHECK(!cx.defined(), "rnn: illegal defined cx for non-LSTM RNN");
   }

-  auto is_input_packed = fn_batch_sizes.size() != 0;
+  auto is_input_packed = !fn_batch_sizes.empty();
   if (batch_first && !is_input_packed) {
     input = input.transpose(0, 1);
     grad_output = grad_output.transpose(0, 1);
@@ -773,7 +772,7 @@ std::vector<Tensor> miopen_rnn_backward_weight(
     TORCH_CHECK(!cx.defined(), "rnn: illegal defined cx for non-LSTM RNN");
   }

-  auto is_input_packed = fn_batch_sizes.size() != 0;
+  auto is_input_packed = !fn_batch_sizes.empty();
   if (batch_first && !is_input_packed) {
     input = input.transpose(0, 1);
     output = output.transpose(0, 1);

@@ -457,7 +457,7 @@ static std::tuple<Tensor, Tensor, Tensor> mkldnn_rnn(
     int64_t mode, int64_t hidden_size,
     int64_t num_layers, bool has_biases, bool batch_first, double dropout_p,
     bool train, bool bidirectional, IntArrayRef batch_sizes) {
-  TORCH_CHECK(batch_sizes.size() == 0, "mkldnn_rnn doesn't support packed input");
+  TORCH_CHECK(batch_sizes.empty(), "mkldnn_rnn doesn't support packed input");
   if (static_cast<ideep::rnn_kind>(mode) != ideep::rnn_kind::LSTM) {
     TORCH_CHECK(!cx_.defined(), "mkldnn_rnn: illegal defined cx for non-LSTM RNN");
   }
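Note: the hunks above all make the same mechanical cleanup, replacing `size() != 0` / `size() == 0` comparisons with `empty()`, the change that clang-tidy's `readability-container-size-empty` check suggests. A minimal standalone sketch of the idiom (illustrative only, not PyTorch code):

```cpp
#include <cassert>
#include <vector>

int main() {
  std::vector<int> batch_sizes;
  // Before: the reader has to decode a numeric comparison.
  bool packed_old = batch_sizes.size() != 0;
  // After: states the intent directly; empty() is O(1) for every standard
  // container, which size() historically was not (e.g. pre-C++11 std::list).
  bool packed_new = !batch_sizes.empty();
  assert(packed_old == packed_new);
  return 0;
}
```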
@@ -1160,7 +1160,7 @@ void MetalKernelFunction::dispatch(uint64_t length, std::optional<uint64_t> grou
 }

 void MetalKernelFunction::dispatch(c10::ArrayRef<uint64_t> length, c10::OptionalArrayRef<uint64_t> group_size) {
-  TORCH_CHECK(length.size() > 0 && length.size() < 4, "Dispatch dimentions must be less than 3 and non-empty");
+  TORCH_CHECK(!length.empty() && length.size() < 4, "Dispatch dimentions must be less than 3 and non-empty");
   TORCH_CHECK(!group_size.has_value() || group_size->size() == length.size(),
               "size and group_size must have same number of dimentions");
   const auto max_tg_size = getMaxThreadsPerThreadgroup();

@@ -180,7 +180,7 @@ static void multi_tensor_apply_for_fused_optimizer(const std::string& kernel_nam
       [computeEncoder useResource:getMTLBufferStorage(tensor_lists[d][tensor_index])
                             usage:MTLResourceUsageRead | MTLResourceUsageWrite];
     }
-    if (state_steps.size() > 0) {
+    if (!state_steps.empty()) {
       mtl_setBuffer(tensorArgumentEncoder, state_steps[tensor_index], depth * kmaxTensors + tensor_loc);
       [computeEncoder useResource:getMTLBufferStorage(state_steps[tensor_index]) usage:MTLResourceUsageRead];
     }
@@ -230,7 +230,7 @@ static void multi_tensor_apply_for_fused_optimizer(const std::string& kernel_nam
       [computeEncoder useResource:getMTLBufferStorage(tensor_lists[d][tensor_index])
                             usage:MTLResourceUsageWrite | MTLResourceUsageRead];
     }
-    if (state_steps.size() > 0) {
+    if (!state_steps.empty()) {
       mtl_setBuffer(tensorArgumentEncoder, state_steps[tensor_index], depth * kmaxTensors);
       [computeEncoder useResource:getMTLBufferStorage(state_steps[tensor_index]) usage:MTLResourceUsageRead];
     }

@@ -493,7 +493,7 @@ static Tensor std_var_common_impl_mps(const Tensor& input_t,
     // Reduction axes
     axes = [NSMutableArray<NSNumber*> arrayWithCapacity:1];
     axes[0] = @0;
-  } else if (!keepdim && use_dim && dim_value.size() > 0) {
+  } else if (!keepdim && use_dim && !dim_value.empty()) {
     int64_t num_reduce_dims = dim_value.size();
     num_output_dims = num_input_dims;

@@ -528,7 +528,7 @@ static Tensor std_var_common_impl_mps(const Tensor& input_t,
       correction_n *= input_shape[wrap_dim];
     }
     // (3, 4, 5) --> (3, 5)
-  } else if ((keepdim && !use_dim) || (keepdim && use_dim && dim_value.size() <= 0)) {
+  } else if ((keepdim && !use_dim) || (keepdim && use_dim && dim_value.empty())) {
     num_output_dims = 0;
     int64_t num_reduce_dims = 0;
     set_axes(axes, num_reduce_dims, dim_value, input_shape.size());
@@ -540,7 +540,7 @@ static Tensor std_var_common_impl_mps(const Tensor& input_t,
       correction_n *= input_shape[i];
     }
     // scalar --> vector case [[1.0034567]]
-  } else if (keepdim && use_dim && dim_value.size() > 0) {
+  } else if (keepdim && use_dim && !dim_value.empty()) {
     int64_t num_reduce_dims = dim_value.size();
     num_output_dims = num_input_dims;

@@ -168,7 +168,7 @@ TORCH_IMPL_FUNC(cat_out_mps)
   TORCH_CHECK(canCast(out_dtype, out.scalar_type()),
               "torch.cat(): input types can't be cast to the desired output type ",
               out.scalar_type());
-  TORCH_CHECK(inputs.size() > 0, "torch.cat(): invalid number of inputs ", inputs.size());
+  TORCH_CHECK(!inputs.empty(), "torch.cat(): invalid number of inputs ", inputs.size());

   dimension = legacy_cat_wrap_dim(dimension, materialized_inputs);
   TORCH_CHECK(dimension >= 0, "torch.cat(): invalid dimension ", dimension);

@@ -272,7 +272,7 @@ static std::tuple<Tensor, Tensor, Tensor> _unique_impl_mps(const Tensor& self,
   }

   int64_t lengthScalar = length.item<int64_t>() + 1; // length actually holds max index, add 1
-  if (output.sizes().size() != 0) {
+  if (!output.sizes().empty()) {
     output = at::slice(output, dim, 0, lengthScalar);
   }
   if (return_counts)

@@ -215,7 +215,7 @@ Tensor as_strided_tensorimpl_mps(const Tensor& self,
   // when we create/run the view graph.
   IntArrayRef base_shape = mps::updateTensorBaseShape(self);
   TORCH_INTERNAL_ASSERT(
-      base_shape.size() > 0, "Failed to update the base shape of tensor's buffer at ", self.storage().data());
+      !base_shape.empty(), "Failed to update the base shape of tensor's buffer at ", self.storage().data());

   return result;
 }

@@ -6981,7 +6981,7 @@
   device_check: NoCheck # TensorIterator
   variants: function
   dispatch:
-    CPU, CUDA: rsub
+    CPU, CUDA, MPS: rsub
   autogen: rsub.Tensor_out

 - func: heaviside.out(Tensor self, Tensor values, *, Tensor(a!) out) -> Tensor(a!)

@@ -429,7 +429,7 @@ Tensor sparse_compressed_tensor_with_dims(
     compressed_indices_size.push_back(compressed_size / blocksize[d0] + 1);
     values_size.append(DimVector(blocksize));
   } else {
-    TORCH_CHECK(blocksize.size() == 0, "sparse_compressed_tensor_with_dims: blocksize cannot be specified for non-block layout ", layout_);
+    TORCH_CHECK(blocksize.empty(), "sparse_compressed_tensor_with_dims: blocksize cannot be specified for non-block layout ", layout_);
     compressed_indices_size.push_back(size[compressedDimension(layout_, size, dense_dim)] + 1);
   }

@@ -573,7 +573,7 @@ Tensor _sparse_sum_backward_cuda(const Tensor& grad_, const SparseTensor& input_
   }

   const bool sum_all_sparse_dim = (input_sparse_dim == sparse_dims_to_sum_size);
-  const bool sum_dense_dim = (dense_dims_to_sum_v.size() > 0);
+  const bool sum_dense_dim = !dense_dims_to_sum_v.empty();
   const bool sum_sparse_dim = (sparse_dims_to_sum_size > 0);

   if (sum_all_sparse_dim) {
@@ -479,7 +479,7 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head

   // Otherwise the kernel will be launched from cuda:0 device
   // Cast to char to avoid compiler warning about narrowing
-  at::cuda::CUDAGuard device_guard{(char)q.get_device()};
+  at::cuda::CUDAGuard device_guard{static_cast<signed char>(q.get_device())};

   auto opts = q.options();

@@ -705,7 +705,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q

   // Otherwise the kernel will be launched from cuda:0 device
   // Cast to char to avoid compiler warning about narrowing
-  at::cuda::CUDAGuard device_guard{(char)q.get_device()};
+  at::cuda::CUDAGuard device_guard{static_cast<signed char>(q.get_device())};

   auto opts = q.options();

@@ -940,7 +940,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si

   // Otherwise the kernel will be launched from cuda:0 device
   // Cast to char to avoid compiler warning about narrowing
-  at::cuda::CUDAGuard device_guard{(char)q.get_device()};
+  at::cuda::CUDAGuard device_guard{static_cast<signed char>(q.get_device())};

   auto opts = q.options();
   auto softmax_d = at::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
@@ -1163,7 +1163,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size

   // Otherwise the kernel will be launched from cuda:0 device
   // Cast to char to avoid compiler warning about narrowing
-  at::cuda::CUDAGuard device_guard{(char)q.get_device()};
+  at::cuda::CUDAGuard device_guard{static_cast<signed char>(q.get_device())};

   auto opts = q.options();
   auto softmax_d = at::empty({num_heads, total_q + 128 * batch_size}, opts.dtype(at::kFloat));
@@ -1393,7 +1393,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he

   // Otherwise the kernel will be launched from cuda:0 device
   // Cast to char to avoid compiler warning about narrowing
-  at::cuda::CUDAGuard device_guard{(char)q.get_device()};
+  at::cuda::CUDAGuard device_guard{static_cast<signed char>(q.get_device())};

   auto opts = q.options();
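Note: the five flash-attention hunks above make one repeated change, swapping a C-style `(char)` cast for `static_cast<signed char>`. A minimal standalone sketch of the distinction (illustrative; `to_device_index` is a hypothetical helper, not a PyTorch API):

```cpp
#include <cstdint>

// A C-style (char) cast narrows silently, and whether plain char is signed
// is implementation-defined. static_cast<signed char> keeps the narrowing
// visible and pins down the signedness expected by the guard's constructor.
signed char to_device_index(std::int64_t device) {
  return static_cast<signed char>(device);
}

int main() {
  return to_device_index(0);
}
```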
@@ -40,7 +40,7 @@ class BatchNormPackedContext final : virtual public VulkanPackedContext,
   static BatchNormPackedContext pack(c10::impl::GenericList);

   const c10::impl::GenericList unpack() const override {
-    TORCH_CHECK(unpacked_.size() > 0u, "unpacked_ does not have any elements!");
+    TORCH_CHECK(!unpacked_.empty(), "unpacked_ does not have any elements!");

     return unpacked_;
   }

@@ -1283,7 +1283,7 @@ Tensor Conv2dOpContext::run(const Tensor& input_arg) const {
 Conv2dOpContext::State Conv2dOpContext::unpack() const {
   const c10::impl::GenericList unpacked_ = conv_context_.unpack();

-  TORCH_CHECK(unpacked_.size() > 0u, "unpacked_ does not have any elements!");
+  TORCH_CHECK(!unpacked_.empty(), "unpacked_ does not have any elements!");

   return Conv2dOpContext::State(
       unpacked_.get(Conv2dPackedContext::Unpacked::Weight).toTensor(),

@@ -115,7 +115,7 @@ class Conv2dPackedContext final : virtual public VulkanPackedContext,
   static Conv2dPackedContext pack(c10::impl::GenericList);

   const c10::impl::GenericList unpack() const override {
-    TORCH_CHECK(unpacked_.size() > 0u, "unpacked_ does not have any elements!");
+    TORCH_CHECK(!unpacked_.empty(), "unpacked_ does not have any elements!");

     return unpacked_;
   }
@@ -275,7 +275,7 @@ class Conv1dPackedContext final : virtual public VulkanPackedContext,
   static Conv1dPackedContext pack(c10::impl::GenericList);

   const c10::impl::GenericList unpack() const override {
-    TORCH_CHECK(unpacked_.size() > 0u, "unpacked_ does not have any elements!");
+    TORCH_CHECK(!unpacked_.empty(), "unpacked_ does not have any elements!");

     return unpacked_;
   }

@@ -240,7 +240,7 @@ const c10::impl::GenericList GruPackedContext::unpack() const {
       packed_linear_context.toCustomClass<LinearPackedContext>()->unpack();

   TORCH_CHECK(
-      unpacked_linear_context.size() > 0u,
+      !unpacked_linear_context.empty(),
       "unpacked_linear_context does not have any elements!");

   params_cpu.emplace_back(

@@ -282,7 +282,7 @@ const c10::impl::GenericList LstmPackedContext::unpack() const {
       packed_linear_context.toCustomClass<LinearPackedContext>()->unpack();

   TORCH_CHECK(
-      unpacked_linear_context.size() > 0u,
+      !unpacked_linear_context.empty(),
       "unpacked_linear_context does not have any elements!");

   params_cpu.emplace_back(

@@ -89,7 +89,7 @@ class LinearPackedContext final : virtual public VulkanPackedContext,
   static LinearPackedContext pack(c10::impl::GenericList);

   const c10::impl::GenericList unpack() const override {
-    TORCH_CHECK(unpacked_.size() > 0u, "unpacked_ does not have any elements!");
+    TORCH_CHECK(!unpacked_.empty(), "unpacked_ does not have any elements!");

     return unpacked_;
   }
@@ -17,10 +17,13 @@ if(BUILD_TEST)
     if(WIN32 AND test_src MATCHES "^.*\.hip$")
       set_source_files_properties(${test_src} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)
       hip_add_executable(${test_name} "${test_src}")
-      set_target_properties(${test_name} PROPERTIES LINKER_LANGUAGE CXX HIP_ARCHITECTURES ${PYTORCH_ROCM_ARCH})
+      set_target_properties(${test_name} PROPERTIES HIP_ARCHITECTURES ${PYTORCH_ROCM_ARCH})
     else()
       add_executable(${test_name} "${test_src}")
     endif()
+    if(test_src MATCHES "^.*\.hip$")
+      set_target_properties(${test_name} PROPERTIES LINKER_LANGUAGE CXX)
+    endif()
     target_link_libraries(${test_name} ${C10_CUDA_LIB} ${C10_LIB} gmock gtest gtest_main)
     add_test(NAME ${test_name} COMMAND $<TARGET_FILE:${test_name}>)
     if(INSTALL_TEST)

@@ -1939,10 +1939,13 @@ if(BUILD_TEST)
       set(HIP_HIPCC_FLAGS ${BASE_HIPCC_FLAGS})
      set_source_files_properties(${test_src} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)
       hip_add_executable(${test_name} "${test_src}")
-      set_target_properties(${test_name} PROPERTIES LINKER_LANGUAGE CXX HIP_ARCHITECTURES ${PYTORCH_ROCM_ARCH})
+      set_target_properties(${test_name} PROPERTIES HIP_ARCHITECTURES ${PYTORCH_ROCM_ARCH})
     else()
       add_executable(${test_name} "${test_src}")
     endif()
+    if(test_src MATCHES "^.*\.hip$")
+      set_target_properties(${test_name} PROPERTIES LINKER_LANGUAGE CXX)
+    endif()
     target_link_libraries(${test_name} torch_library gtest_main)
     target_include_directories(${test_name} PRIVATE $<INSTALL_INTERFACE:include>)
     target_include_directories(${test_name} PRIVATE ${Caffe2_CPU_INCLUDE} ${Caffe2_HIP_INCLUDE})

@@ -51,6 +51,10 @@ else()
   set(XPU_ENABLE_KINETO FALSE)
 endif()

-if(NOT WIN32)
+if(WIN32)
+  if(${SYCL_COMPILER_VERSION} GREATER_EQUAL 20250101)
+    set(XPU_ENABLE_KINETO TRUE)
+  endif()
+else()
   set(XPU_ENABLE_KINETO TRUE)
 endif()

@@ -859,3 +859,5 @@ API Reference
 .. automodule:: torch.export.experimental
 .. automodule:: torch.export.passes
 .. autofunction:: torch.export.passes.move_to_device_pass
+.. automodule:: torch.export.pt2_archive
+.. automodule:: torch.export.pt2_archive.constants
@@ -1,5 +1,4 @@
-from torch import cond  # noqa: F401
 from torch._higher_order_ops.cond import UnsupportedAliasMutationException  # noqa: F401
 from torch._higher_order_ops.map import (  # noqa: F401
     _stack_pytree,
     _unstack_pytree,

@@ -2384,15 +2384,6 @@
   "torch.utils.collect_env": [
     "namedtuple"
   ],
-  "torch.utils.data.datapipes.gen_pyi": [
-    "Any",
-    "Dict",
-    "List",
-    "Set",
-    "Tuple",
-    "Union",
-    "defaultdict"
-  ],
   "torch.utils.data.datapipes.utils.snapshot": [
     "IterDataPipe",
     "apply_random_seed"
@@ -19,7 +19,7 @@ target_compile_definitions(test_aoti_abi_check PRIVATE USE_GTEST)

 # WARNING: DO NOT LINK torch!!!
 # The purpose is to check if the used aten/c10 headers are writtern in a header-only way
-target_link_libraries(test_aoti_abi_check PRIVATE gtest)
+target_link_libraries(test_aoti_abi_check PRIVATE gtest_main)
 target_include_directories(test_aoti_abi_check PRIVATE ${ATen_CPU_INCLUDE})

 if(INSTALL_TEST)

@@ -55,7 +55,7 @@ add_custom_command(

 target_link_libraries(test_aoti_inference PRIVATE
     torch
-    gtest
+    gtest_main
     -Wl,--no-as-needed aoti_custom_class
 )

@@ -50,7 +50,7 @@ endif()

 add_executable(test_api ${TORCH_API_TEST_SOURCES})
 target_include_directories(test_api PRIVATE ${ATen_CPU_INCLUDE})
-target_link_libraries(test_api PRIVATE torch gtest gmock)
+target_link_libraries(test_api PRIVATE torch gtest_main gmock)

 if(USE_CUDA)
   target_compile_definitions(test_api PRIVATE "USE_CUDA")

@@ -7,7 +7,7 @@ if(USE_DISTRIBUTED AND NOT WIN32)

 add_executable(test_dist_autograd ${DIST_AUTOGRAD_TEST_SOURCES})
 target_include_directories(test_dist_autograd PRIVATE ${ATen_CPU_INCLUDE})
-target_link_libraries(test_dist_autograd PRIVATE torch gtest)
+target_link_libraries(test_dist_autograd PRIVATE torch gtest_main)

 if(USE_CUDA)
   target_compile_definitions(test_dist_autograd PRIVATE "USE_CUDA")

@@ -125,7 +125,7 @@ if(USE_MKLDNN)
   target_link_libraries(test_jit PRIVATE caffe2::mkldnn)
 endif()

-set(JIT_TEST_DEPENDENCIES torch gtest jitbackend_test backend_with_compiler gmock)
+set(JIT_TEST_DEPENDENCIES torch gtest_main jitbackend_test backend_with_compiler gmock)

 if(MSVC)
   list(APPEND JIT_TEST_DEPENDENCIES onnx_library)

@@ -28,7 +28,7 @@ add_executable(test_lazy
 # TODO temporary until we can delete the old gtest polyfills.
 target_compile_definitions(test_lazy PRIVATE USE_GTEST)

-set(LAZY_TEST_DEPENDENCIES torch gtest)
+set(LAZY_TEST_DEPENDENCIES torch gtest_main)

 target_link_libraries(test_lazy PRIVATE ${LAZY_TEST_DEPENDENCIES})
 target_include_directories(test_lazy PRIVATE ${ATen_CPU_INCLUDE})

@@ -21,7 +21,7 @@ target_include_directories(
   ${ATen_CPU_INCLUDE}
 )

-target_link_libraries(test_lite_interpreter_runtime PRIVATE torch gtest backend_with_compiler_runtime)
+target_link_libraries(test_lite_interpreter_runtime PRIVATE torch gtest_main backend_with_compiler_runtime)

 if(LINUX)
   target_link_libraries(test_lite_interpreter_runtime PRIVATE "-Wl,--no-as-needed,$<TARGET_FILE:backend_with_compiler_runtime>,--as-needed")

@@ -17,7 +17,7 @@ add_executable(test_nativert
 # TODO temporary until we can delete the old gtest polyfills.
 target_compile_definitions(test_nativert PRIVATE USE_GTEST)

-set(NATIVERT_TEST_DEPENDENCIES torch gtest)
+set(NATIVERT_TEST_DEPENDENCIES torch gtest_main)

 target_link_libraries(test_nativert PRIVATE ${NATIVERT_TEST_DEPENDENCIES})
 target_link_libraries(test_nativert PRIVATE fmt::fmt-header-only)

@@ -5,7 +5,7 @@ set(TORCH_RPC_TEST_SOURCES
   ${TORCH_RPC_TEST_DIR}/test_wire_serialization.cpp
 )
 set(TORCH_RPC_TEST_DEPENDENCY_LIBS
-  torch gtest
+  torch gtest_main
 )

 if(USE_GLOO)

@@ -39,7 +39,7 @@ add_executable(test_tensorexpr
   ${TENSOREXPR_TEST_ROOT}/padded_buffer.cpp
   ${TENSOREXPR_TEST_SRCS})

-target_link_libraries(test_tensorexpr PRIVATE torch gtest)
+target_link_libraries(test_tensorexpr PRIVATE torch gtest_main)
 target_include_directories(test_tensorexpr PRIVATE ${ATen_CPU_INCLUDE})
 target_compile_definitions(test_tensorexpr PRIVATE USE_GTEST)
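Note: the CMake hunks above uniformly relink C++ test targets from `gtest` to `gtest_main`, letting GoogleTest supply the executable's entry point. A minimal sketch of the effect (illustrative, a hypothetical test file rather than one from this diff):

```cpp
#include <gtest/gtest.h>

// With gtest_main linked, this file is already a complete test binary:
// gtest_main provides main(), which runs InitGoogleTest and RUN_ALL_TESTS.
TEST(SmokeTest, TruthHolds) {
  EXPECT_TRUE(true);
}

// Linking plain gtest instead would require a hand-written entry point:
//   int main(int argc, char** argv) {
//     ::testing::InitGoogleTest(&argc, argv);
//     return RUN_ALL_TESTS();
//   }
```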
@@ -4,7 +4,6 @@ import collections
 import copy
 import functools
 import itertools
-import unittest
 from typing import Any, Optional, Union

 import torch
@@ -12,13 +11,13 @@ import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed.fsdp import fully_shard
 from torch.nn.parallel.scatter_gather import _is_namedtuple
-from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import (
     check_sharded_parity,
     DoubleLinear,
     FSDPTest,
     FSDPTestMultiThread,
+    get_devtype,
     MLP,
 )
 from torch.testing._internal.common_utils import run_tests
@@ -28,10 +27,13 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
 )


+device_type = torch.device(get_devtype())
+
+
 class TestFullyShardAutograd(FSDPTest):
     @property
     def world_size(self) -> int:
-        return min(4, torch.cuda.device_count())
+        return min(4, torch.get_device_module(device_type).device_count())

     def _reduce_1d_partial_grads(
         self, module: nn.Module, group: Optional[dist.ProcessGroup] = None
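Note: the FSDP test hunks from here on apply one device-generic pattern: resolve the accelerator once via `get_devtype()` and derive placement and device counts from it instead of hard-coding `"cuda"`. A minimal standalone sketch (the `"cuda"`-or-`"cpu"` fallback below stands in for `torch.testing._internal.common_fsdp.get_devtype`, whose definition is not part of this diff):

```python
import torch

# Stand-in for get_devtype(): pick whatever accelerator is present.
device_type = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = torch.nn.Linear(8, 8).to(device_type)  # instead of .cuda()
inp = torch.rand(4, 8, device=device_type)     # instead of device="cuda"

# torch.get_device_module maps a device to its backend module
# (torch.cuda, torch.xpu, ...), replacing direct torch.cuda.device_count().
n_devices = torch.get_device_module(device_type).device_count()
print(model(inp).shape, n_devices)
```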
@@ -58,7 +60,7 @@ class TestFullyShardAutograd(FSDPTest):
         local_batch_size = 2
         global_batch_size, dim = (self.world_size * local_batch_size, 24)
         model = DoubleLinear(dim=dim, use_second_linear=True)
-        ref_model = copy.deepcopy(model).cuda()
+        ref_model = copy.deepcopy(model).to(device_type)
         fully_shard(model.lin1, reshard_after_forward=reshard_after_forward)
         fully_shard(model, reshard_after_forward=reshard_after_forward)
         ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
@@ -68,7 +70,7 @@ class TestFullyShardAutograd(FSDPTest):
         for iter_idx in range(10):
             # Use all forward outputs in the loss/backward for the first half
             # of the iterations and only the 1st forward output for the rest
-            global_inp = torch.rand((global_batch_size, dim), device="cuda")
+            global_inp = torch.rand((global_batch_size, dim), device=device_type)
             local_inp = global_inp[
                 self.rank * local_batch_size : (self.rank + 1) * local_batch_size
             ].detach()
@@ -104,7 +106,7 @@ class TestFullyShardAutograd(FSDPTest):
         local_batch_size, dim = (2, 24)
         global_batch_size = self.world_size * local_batch_size
         model = DoubleLinear(dim=dim, use_second_linear=False)
-        ref_model = copy.deepcopy(model).cuda()
+        ref_model = copy.deepcopy(model).to(device_type)
         fully_shard(model.lin1, reshard_after_forward=reshard_after_forward)
         fully_shard(model.lin2, reshard_after_forward=reshard_after_forward)
         fully_shard(model, reshard_after_forward=reshard_after_forward)
@@ -113,7 +115,7 @@ class TestFullyShardAutograd(FSDPTest):

         torch.manual_seed(1)  # same on all ranks
         for iter_idx in range(10):
-            global_inp = torch.rand((global_batch_size, dim), device="cuda")
+            global_inp = torch.rand((global_batch_size, dim), device=device_type)
             local_inp = global_inp[
                 self.rank * local_batch_size : (self.rank + 1) * local_batch_size
             ].detach()
@@ -214,7 +216,7 @@ class TestFullyShardAutograd(FSDPTest):
             Module(dim),
             FromContainerType(container_type),
         )
-        ref_model = copy.deepcopy(model).cuda()
+        ref_model = copy.deepcopy(model).to(device_type)
         for module in model:
             fully_shard(module)
         fully_shard(model)
@@ -223,7 +225,7 @@ class TestFullyShardAutograd(FSDPTest):

         torch.manual_seed(1)  # same on all ranks
         for iter_idx in range(10):
-            global_inp = torch.rand((global_batch_size, dim), device="cuda")
+            global_inp = torch.rand((global_batch_size, dim), device=device_type)
             local_inp = global_inp[
                 self.rank * local_batch_size : (self.rank + 1) * local_batch_size
             ].detach()
@@ -245,7 +247,7 @@ class TestFullyShardPostAccGradHookMultiThread(FSDPTestMultiThread):
     def world_size(self) -> int:
         return 2

-    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    @skip_if_lt_x_gpu(1)
     def test_post_acc_grad_hook_runs(self):
         param_name_to_hook_count = collections.defaultdict(int)

@@ -260,7 +262,7 @@ class TestFullyShardPostAccGradHookMultiThread(FSDPTestMultiThread):
             param_hook = functools.partial(hook, param_name)
             param.register_post_accumulate_grad_hook(param_hook)

-        inp = torch.randn((2, 8), device="cuda")
+        inp = torch.randn((2, 8), device=device_type)
         model(inp).sum().backward()
         param_names = {param_name for param_name, _ in model.named_parameters()}
         self.assertEqual(param_names, set(param_name_to_hook_count.keys()))
@@ -271,7 +273,7 @@ class TestFullyShardPostAccGradHookMultiThread(FSDPTestMultiThread):
 class TestFullyShardPostAccGradHookMultiProcess(FSDPTest):
     @property
     def world_size(self) -> int:
-        return min(torch.cuda.device_count(), 2)
+        return min(torch.get_device_module(device_type).device_count(), 2)

     @skip_if_lt_x_gpu(2)
     def test_post_acc_grad_hook_optim_parity(self):
@@ -283,7 +285,7 @@ class TestFullyShardPostAccGradHookMultiProcess(FSDPTest):
         model_args = ModelArgs(dropout_p=0.0)
         model = Transformer(model_args)

-        ref_model = copy.deepcopy(model).cuda()
+        ref_model = copy.deepcopy(model).to(device_type)
         for module in itertools.chain(ref_model.layers, [ref_model]):
             fully_shard(module)
         optim_kwargs = {"lr": 1e-2, "foreach": False}
@@ -312,7 +314,7 @@ class TestFullyShardPostAccGradHookMultiProcess(FSDPTest):
             param.register_post_accumulate_grad_hook(optim_hook)

         torch.manual_seed(42 + self.rank)
-        inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda")
+        inp = torch.randint(0, model_args.vocab_size, (2, 16), device=device_type)
         for _ in range(10):
             ref_loss = ref_model(inp).sum()
             ref_loss.backward()
|
||||
from torch.distributed.fsdp import fully_shard
|
||||
from torch.distributed.tensor.debug import CommDebugMode
|
||||
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
|
||||
from torch.testing._internal.common_fsdp import FSDPTest, MLPStack
|
||||
from torch.testing._internal.common_fsdp import FSDPTest, get_devtype, MLPStack
|
||||
from torch.testing._internal.common_utils import run_tests
|
||||
from torch.testing._internal.distributed._tensor.common_dtensor import (
|
||||
ModelArgs,
|
||||
@ -20,6 +20,9 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
|
||||
)
|
||||
|
||||
|
||||
device_type = torch.device(get_devtype())
|
||||
|
||||
|
||||
class _TestClipGradNormBase(FSDPTest):
|
||||
def _test_clip_grad_norm(
|
||||
self,
|
||||
@ -33,7 +36,7 @@ class _TestClipGradNormBase(FSDPTest):
|
||||
dp_mesh: Optional[DeviceMesh] = None,
|
||||
):
|
||||
vector_norm_fn = functools.partial(torch.linalg.vector_norm, ord=norm_type)
|
||||
dp_mesh = dp_mesh or init_device_mesh("cuda", (self.world_size,))
|
||||
dp_mesh = dp_mesh or init_device_mesh(device_type.type, (self.world_size,))
|
||||
torch.manual_seed(42 + dp_mesh.get_local_rank() + 1)
|
||||
for _ in range(10):
|
||||
ref_optim.zero_grad()
|
||||
@ -91,7 +94,7 @@ class _TestClipGradNormBase(FSDPTest):
|
||||
class TestClipGradNormWorldSize2(_TestClipGradNormBase):
|
||||
@property
|
||||
def world_size(self) -> int:
|
||||
return min(torch.cuda.device_count(), 2)
|
||||
return min(torch.get_device_module(device_type).device_count(), 2)
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_clip_grad_norm_1d(self):
|
||||
@ -99,14 +102,16 @@ class TestClipGradNormWorldSize2(_TestClipGradNormBase):
|
||||
torch.manual_seed(42)
|
||||
model_args = ModelArgs(dropout_p=0.0)
|
||||
model = Transformer(model_args)
|
||||
ref_model = replicate(copy.deepcopy(model).cuda())
|
||||
ref_model = replicate(copy.deepcopy(model).to(device_type))
|
||||
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
|
||||
for module in model.modules():
|
||||
if isinstance(module, TransformerBlock):
|
||||
fully_shard(module)
|
||||
fully_shard(model)
|
||||
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
|
||||
inp = torch.randint(0, model.model_args.vocab_size, (3, 16), device="cuda")
|
||||
inp = torch.randint(
|
||||
0, model.model_args.vocab_size, (3, 16), device=device_type
|
||||
)
|
||||
self._test_clip_grad_norm(
|
||||
1, norm_type, ref_model, ref_optim, model, optim, inp
|
||||
)
|
||||
@ -115,14 +120,14 @@ class TestClipGradNormWorldSize2(_TestClipGradNormBase):
|
||||
class TestClipGradNormWorldSize4(_TestClipGradNormBase):
|
||||
@property
|
||||
def world_size(self) -> int:
|
||||
return min(torch.cuda.device_count(), 4)
|
||||
return min(torch.get_device_module(device_type).device_count(), 4)
|
||||
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_clip_grad_norm_2d(self):
|
||||
for norm_type in (2, 1, 3, float("inf")):
|
||||
dp_size = 2
|
||||
global_mesh = init_device_mesh(
|
||||
"cuda",
|
||||
device_type.type,
|
||||
(dp_size, self.world_size // dp_size),
|
||||
mesh_dim_names=("dp", "tp"),
|
||||
)
|
||||
@ -132,7 +137,7 @@ class TestClipGradNormWorldSize4(_TestClipGradNormBase):
|
||||
# has some more significant numeric differences from the TP
|
||||
model = MLPStack(16, with_seq_parallel=True)
|
||||
ref_model = replicate(
|
||||
copy.deepcopy(model).cuda(), process_group=dp_mesh.get_group()
|
||||
copy.deepcopy(model).to(device_type), process_group=dp_mesh.get_group()
|
||||
)
|
||||
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
|
||||
model.parallelize(
|
||||
@ -142,7 +147,7 @@ class TestClipGradNormWorldSize4(_TestClipGradNormBase):
|
||||
reshard_after_forward=True,
|
||||
)
|
||||
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
|
||||
inp = torch.randn(2, 16, device="cuda")
|
||||
inp = torch.randn(2, 16, device=device_type)
|
||||
self._test_clip_grad_norm(
|
||||
0.5, norm_type, ref_model, ref_optim, model, optim, inp, dp_mesh
|
||||
)
|
||||
|
||||
@@ -3,7 +3,6 @@
 import copy
 import functools
 import itertools
-import unittest
 from typing import Callable, Optional, Union

 import torch
@@ -35,7 +34,6 @@ from torch.distributed.fsdp._fully_shard._fsdp_param_group import FSDPParamGroup
 from torch.distributed.tensor import DTensor
 from torch.distributed.tensor.debug import CommDebugMode
 from torch.distributed.tensor.experimental import implicit_replication
-from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import (
     check_sharded_parity,
@@ -60,6 +58,12 @@ c10d_ops = torch.ops.c10d
 # For recording FSDP events like unshard or post-backward
 EventType = tuple[str, str, TrainingState]

+from torch.testing._internal.common_fsdp import get_devtype
+
+
+device_type = torch.device(get_devtype())
+device_module = torch.get_device_module(device_type)
+

 class TestFullyShardCollectiveOps(FSDPTestMultiThread):
     @property
@@ -68,7 +72,7 @@ class TestFullyShardCollectiveOps(FSDPTestMultiThread):

     @property
     def device(self) -> torch.device:
-        return torch.device("cuda:0")
+        return torch.device(device_type.type, 0)

     def _get_param_sizes(self) -> list[torch.Size]:
         # For world size 128, the fp32 all-gather and reduce-scatter testing
@@ -116,11 +120,14 @@ class TestFullyShardCollectiveOps(FSDPTestMultiThread):
         fsdp_param_group.lazy_init()
         return fsdp_param_group

-    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    @skip_if_lt_x_gpu(1)
     def test_all_gather_fp32(self):
         param_sizes = self._get_param_sizes()
-        default_stream = torch.cuda.current_stream()
-        stream1, stream2 = torch.cuda.Stream(), torch.cuda.Stream()
+        default_stream = device_module.current_stream()
+        stream1, stream2 = (
+            device_module.Stream(),
+            device_module.Stream(),
+        )
         for async_op, streams, reshard_after_forward in itertools.product(
             (False, True),
             ((default_stream, default_stream), (stream1, stream2)),
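Note: this file additionally routes stream handling through `torch.get_device_module`, which returns the backend module (`torch.cuda`, `torch.xpu`, ...) for a device, so `current_stream()` and `Stream()` no longer name CUDA explicitly. A minimal standalone sketch of the pattern (illustrative; guarded to run only where an accelerator exists):

```python
import torch

if torch.cuda.is_available():
    device = torch.device("cuda")
    device_module = torch.get_device_module(device)  # here: torch.cuda

    default_stream = device_module.current_stream()
    side_stream = device_module.Stream()

    # Enqueue work on the side stream, then order the default stream after
    # it, exactly as the torch.cuda.* calls being replaced would.
    with device_module.stream(side_stream):
        t = torch.ones(4, device=device) * 2
    default_stream.wait_stream(side_stream)
    print(t)
```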
@@ -146,8 +153,8 @@
         param_sizes: list[torch.Size],
         reshard_after_forward: Union[bool, int],
         async_op: bool,
-        all_gather_copy_in_stream: torch.cuda.Stream,
-        all_gather_stream: torch.cuda.Stream,
+        all_gather_copy_in_stream,
+        all_gather_stream,
     ):
         def all_gather(fsdp_param_group: FSDPParamGroup, group: dist.ProcessGroup):
             all_gather_result = foreach_all_gather(
@@ -202,11 +209,11 @@ class TestFullyShardCollectiveOps(FSDPTestMultiThread):
             )
             check_all_gathered_params(orig_params, module)

-    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    @skip_if_lt_x_gpu(1)
     def test_reduce_scatter_fp32(self):
         param_sizes = self._get_param_sizes()
-        default_stream = torch.cuda.current_stream()
-        stream = torch.cuda.Stream()
+        default_stream = device_module.current_stream()
+        stream = device_module.Stream()
         for reduce_scatter_stream in (default_stream, stream):
             self._test_reduce_scatter(
                 param_sizes,
@@ -214,11 +221,11 @@ class TestFullyShardCollectiveOps(FSDPTestMultiThread):
                 reduce_scatter_dtype=torch.float32,
             )

-    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    @skip_if_lt_x_gpu(1)
     def test_reduce_scatter_fp16(self):
         param_sizes = self._get_param_sizes()
-        default_stream = torch.cuda.current_stream()
-        stream = torch.cuda.Stream()
+        default_stream = torch.get_device_module(device_type).current_stream()
+        stream = device_module.Stream()
         for reduce_scatter_stream in (default_stream, stream):
             self._test_reduce_scatter(
                 param_sizes,
@@ -229,7 +236,7 @@ class TestFullyShardCollectiveOps(FSDPTestMultiThread):
     def _test_reduce_scatter(
         self,
         param_sizes: list[torch.Size],
-        reduce_scatter_stream: torch.cuda.Stream,
+        reduce_scatter_stream,
         reduce_scatter_dtype: torch.dtype,
     ):
         # Set up the reference parameters and construct the FSDP group
@@ -248,7 +255,7 @@ class TestFullyShardCollectiveOps(FSDPTestMultiThread):
         unsharded_grads = [torch.ones_like(param) * self.rank for param in orig_params]
         group = fsdp_param_group.mesh_info.shard_process_group
         self.assertEqual(group.size(), self.world_size)
-        all_reduce_stream = torch.cuda.Stream()
+        all_reduce_stream = device_module.Stream()
         (
             _,
             _,
@@ -271,7 +278,9 @@ class TestFullyShardCollectiveOps(FSDPTestMultiThread):
             all_reduce_grads=True,
             partial_reduce_output=None,
         )
-        torch.cuda.current_stream().wait_event(post_reduce_event)
+        torch.get_device_module(device_type).current_stream().wait_event(
+            post_reduce_event
+        )

         # Check reduce-scatter correctness
         predivide_factor, postdivide_factor = _get_gradient_divide_factors(
@@ -295,7 +304,7 @@ class TestFullyShardCommunication(FSDPTest):
 class TestFullyShardCommunication(FSDPTest):
     @property
     def world_size(self) -> int:
-        return min(4, torch.cuda.device_count())
+        return min(4, torch.get_device_module(device_type).device_count())

     @skip_if_lt_x_gpu(2)
     def test_fully_shard_communication_count(self):
@@ -327,7 +336,7 @@ class TestFullyShardCommunication(FSDPTest):
         # We construct `num_blocks` plus 1 FSDP states/communication groups

         torch.manual_seed(42 + self.rank)
-        inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda")
+        inp = torch.randint(0, model_args.vocab_size, (2, 16), device=device_type.type)
         with CommDebugMode() as fwd_comm_mode:
             loss = model(inp)
         fwd_comm_counts = fwd_comm_mode.get_comm_counts()
@@ -364,7 +373,7 @@ class TestFullyShardCommunication(FSDPTest):
         )

         torch.manual_seed(42 + self.rank)
-        inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda")
+        inp = torch.randint(0, model_args.vocab_size, (2, 16), device=device_type.type)
         with CommDebugMode() as fwd_comm_mode:
             loss = model(inp)
         fwd_comm_counts = fwd_comm_mode.get_comm_counts()
@@ -395,7 +404,7 @@ class TestFullyShardCommunication(FSDPTest):
         torch.manual_seed(42)
         model_args = ModelArgs(dropout_p=0.0, weight_tying=False)
         model = Transformer(model_args)
-        ref_model = copy.deepcopy(model).cuda()
+        ref_model = copy.deepcopy(model).to(device_type)
         ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2)
         for module in model.modules():
             if isinstance(module, TransformerBlock):
@@ -405,7 +414,7 @@ class TestFullyShardCommunication(FSDPTest):
         model.set_reduce_scatter_divide_factor(divide_factor)

         torch.manual_seed(42 + self.rank)
-        inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda")
+        inp = torch.randint(0, model_args.vocab_size, (2, 16), device=device_type.type)

         for _ in range(10):
             ref_loss = ref_model(inp).sum()
@@ -441,7 +450,7 @@ class TestFullyShardCommunication(FSDPTest):
     ):
         torch.manual_seed(42)
         model_args = ModelArgs()
-        model = Transformer(model_args)
+        model = Transformer(model_args).to(device_type)
         fully_shard_fn = functools.partial(
             fully_shard, reshard_after_forward=not set_reshard_after_forward
         )
@@ -459,7 +468,7 @@ class TestFullyShardCommunication(FSDPTest):
         )

         torch.manual_seed(42 + self.rank)
-        inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda")
+        inp = torch.randint(0, model_args.vocab_size, (2, 16), device=device_type.type)
         with CommDebugMode() as fwd_comm_mode:
             loss = model(inp)
         fwd_comm_counts = fwd_comm_mode.get_comm_counts()
@@ -484,7 +493,7 @@ class TestFullyShardCommunication(FSDPTest):
 class TestFullyShardPrefetch(FSDPTest):
     @property
     def world_size(self) -> int:
-        return min(4, torch.cuda.device_count())
+        return min(4, torch.get_device_module(device_type).device_count())

     @skip_if_lt_x_gpu(2)
     def test_fully_shard_backward_prefetch(self):
@@ -640,7 +649,7 @@ class TestFullyShardPrefetch(FSDPTest):
         fully_shard(model[1].lin1, reshard_after_forward=reshard_after_forward)
         fully_shard(model[1].lin2, reshard_after_forward=reshard_after_forward)
         fully_shard(model, reshard_after_forward=reshard_after_forward)
-        inp = torch.randn((4, dim), device="cuda")
+        inp = torch.randn((4, dim), device=device_type.type)
         events: list[EventType] = []
         unshard_with_record = self._get_unshard_with_record(
             FSDPParamGroup.unshard, events
@@ -901,7 +910,10 @@ class TestFullyShardPrefetch(FSDPTest):
             FSDPParamGroup.post_backward, events
         )
         inp = torch.randint(
-            0, model_args.vocab_size, (2, model_args.max_seq_len), device="cuda"
+            0,
+            model_args.vocab_size,
+            (2, model_args.max_seq_len),
+            device=device_type.type,
         )
         with patch_unshard(unshard_with_record), patch_post_backward(
             post_backward_with_record
@@ -981,7 +993,7 @@ class TestFullyShardPrefetch(FSDPTest):
         post_backward_with_record = self._get_post_backward_with_record(
             FSDPParamGroup.post_backward, events
         )
-        inp = torch.randn((2, 16), device="cuda")
+        inp = torch.randn((2, 16), device=device_type.type)
         with patch_unshard(unshard_with_record), patch_post_backward(
             post_backward_with_record
         ):
@@ -1019,7 +1031,7 @@ class TestFullyShardPrefetch(FSDPTest):
     @skip_if_lt_x_gpu(2)
     def test_backward_misprefetch(self):
         torch.manual_seed(42)
-        model = MLP(dim=16, device="cuda")
+        model = MLP(dim=16, device=device_type)
         ref_model = copy.deepcopy(model)
         ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
         fully_shard(model.in_proj)
@@ -1033,7 +1045,7 @@ class TestFullyShardPrefetch(FSDPTest):
         model.in_proj.set_modules_to_backward_prefetch([model.out_proj])

         torch.manual_seed(self.rank + 1)
-        inp = torch.randn((2, 16), device="cuda")
+        inp = torch.randn((2, 16), device=device_type.type)
         for _ in range(3):
             ref_optim.zero_grad()
             ref_loss = ref_model(inp).sum()
@@ -1065,7 +1077,10 @@ class TestFullyShardPrefetch(FSDPTest):
         fully_shard(model, reshard_after_forward=reshard_after_forward)
         optim = torch.optim.Adam(model.parameters(), lr=1e-2)
         inp = torch.randint(
-            0, model_args.vocab_size, (2, model_args.max_seq_len), device="cuda"
+            0,
+            model_args.vocab_size,
+            (2, model_args.max_seq_len),
+            device=device_type.type,
         )
         return model, optim, inp

@@ -1115,7 +1130,7 @@ class TestFullyShardUnshardMultiProcess(FSDPTest):
 class TestFullyShardUnshardMultiProcess(FSDPTest):
     @property
     def world_size(self) -> int:
-        return min(torch.cuda.device_count(), 2)
+        return min(torch.get_device_module(device_type).device_count(), 2)

     @skip_if_lt_x_gpu(2)
     def test_unshard_async(self):
@@ -1169,10 +1184,10 @@ class TestFullyShardUnshardMultiProcess(FSDPTest):
                 self.mlps.mlp3.unshard(async_op=True)
                 return self.mlps([y1, y2, y3], [work1, work2, work3])

-        mesh = init_device_mesh("cuda", (self.world_size,))
+        mesh = init_device_mesh(device_type.type, (self.world_size,))
         batch_size, dim = 2, 8
         torch.manual_seed(42)
-        ref_model = replicate(ReduceModel(dim, mesh).cuda())
+        ref_model = replicate(ReduceModel(dim, mesh).to(device_type))
         ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
         torch.manual_seed(42)
         model = ReduceModel(dim, mesh)
@@ -1180,10 +1195,10 @@ class TestFullyShardUnshardMultiProcess(FSDPTest):
         fully_shard(model.mlps.mlp2, reshard_after_forward=False)
         fully_shard(model.mlps.mlp3, reshard_after_forward=False)
         fully_shard(model.mlps)
-        replicate(model.cuda())
+        replicate(model.to(device_type))
         optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=True)
         torch.manual_seed(42 + self.rank + 1)
-        inp = torch.randn((batch_size, dim), device="cuda")
+        inp = torch.randn((batch_size, dim), device=device_type.type)
         for _ in range(10):
             losses: list[torch.Tensor] = []
             for _model, _optim in ((ref_model, ref_optim), (model, optim)):
@@ -1200,7 +1215,7 @@ class TestFullyShardUnshardMultiThread(FSDPTestMultiThread):
     def world_size(self) -> int:
         return 2

-    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    @skip_if_lt_x_gpu(1)
     def test_unshard_no_param_group(self):
         # Check that we can call `unshard()` on a module with no parameter
         # group / no managed parameters without erroring
@@ -1211,7 +1226,7 @@ class TestFullyShardUnshardMultiThread(FSDPTestMultiThread):
         handle = model.unshard(async_op=True)
         handle.wait()

-    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    @skip_if_lt_x_gpu(1)
     def test_unshard_without_lazy_init(self):
         torch.manual_seed(42)
         model = MLP(4)

@@ -31,7 +31,7 @@ from torch.testing._internal.common_distributed import (
     skip_if_lt_x_gpu,
     sm_is_or_higher_than,
 )
-from torch.testing._internal.common_fsdp import FSDPTest, MLP
+from torch.testing._internal.common_fsdp import FSDPTest, get_devtype, MLP
 from torch.testing._internal.common_utils import run_tests, skipIfRocm
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     ModelArgs,
@@ -40,6 +40,8 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
 from torch.testing._internal.inductor_utils import HAS_GPU


+device_type = torch.device(get_devtype())
+
 log = logging.getLogger(__name__)


@@ -59,9 +61,9 @@ class Mod(torch.nn.Module):
         super().__init__()

         self.encoder = torch.nn.Sequential(
-            torch.nn.Linear(28 * 28, 1024, device="cuda"),
-            torch.nn.Linear(1024, 1024, device="cuda"),
-            torch.nn.Linear(1024, 4096, device="cuda"),
+            torch.nn.Linear(28 * 28, 1024, device=device_type),
+            torch.nn.Linear(1024, 1024, device=device_type),
+            torch.nn.Linear(1024, 4096, device=device_type),
         )

     def forward(self, x):
@@ -104,10 +106,10 @@ class TestFullyShardCompileCompute(FSDPTest):
             torch.distributed.barrier()
             torch._dynamo.config.skip_fsdp_hooks = skip_fsdp_hooks
             torch._dynamo.trace_rules.check = patched_trace_rules_check
-            model = MLP(4)
+            model = MLP(4).to(device_type)
             fully_shard(model)
             model.compile()
-            model(torch.randn((4, 4), device="cuda"))
+            model(torch.randn((4, 4), device=device_type))
             torch.distributed.barrier()
             torch._dynamo.config.skip_fsdp_hooks = original_skip_fsdp_hooks
             torch._dynamo.trace_rules.check = orig_trace_rules_check
@@ -127,7 +129,10 @@ class TestFullyShardCompile(FSDPTest):
     def skipTestForOldSm(self):
        # Assumption: This test class is only run on GPU. See `HAS_GPU` check at
        # the top of the class.
-        device = torch.device("cuda", self.rank % torch.cuda.device_count())
+        device = torch.device(
+            device_type.type,
+            self.rank % torch.get_device_module(device_type).device_count(),
+        )
         if not sm_is_or_higher_than(device, 8, 0):
             self.skipTest("bf16 requires sm >= 8.0")

@@ -139,7 +144,7 @@ class TestFullyShardCompile(FSDPTest):
             (torch.nn.Linear(1, 1),),  # module: Tuple[nn.Module, ...],
             None,  # mesh_info: FSDPMeshInfo,
             None,  # post_forward_mesh_info: Optional[FSDPMeshInfo],
-            torch.device("cuda"),  # device: torch.device,
+            device_type,  # device: torch.device,
             None,  # shard_placement_fn: Optional[Callable],
             None,  # mp_policy: MixedPrecisionPolicy,
             None,  # offload_policy: OffloadPolicy,
@@ -592,11 +597,11 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
         torch.manual_seed(self.rank)
         fsdp_config = {}
         model = nn.Sequential(
-            nn.Linear(hidden_dim, hidden_dim, device="cuda"),
+            nn.Linear(hidden_dim, hidden_dim, device=device_type),
             nn.ReLU(),
-            nn.Linear(hidden_dim, hidden_dim, device="cuda"),
+            nn.Linear(hidden_dim, hidden_dim, device=device_type),
             nn.ReLU(),
-            nn.Linear(hidden_dim, hidden_dim, device="cuda"),
+            nn.Linear(hidden_dim, hidden_dim, device=device_type),
         )
         fully_shard(model, reshard_after_forward=True, **fsdp_config)
         optim = torch.optim.SGD(model.parameters(), lr=1e-4)
@@ -604,7 +609,7 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},

         def input_creation_fn():
             torch.manual_seed(self.rank)
-            inp = torch.randn((2, hidden_dim), device="cuda", requires_grad=False)
+            inp = torch.randn((2, hidden_dim), device=device_type, requires_grad=False)
             return inp

         return model_init_fn, input_creation_fn
@@ -641,11 +646,11 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
                 super().__init__()
                 self.param1 = nn.Parameter(
                     torch.zeros(
-                        hidden_dim, hidden_dim, dtype=torch.float, device="cuda"
+                        hidden_dim, hidden_dim, dtype=torch.float, device=device_type
                     )
                 )
                 self.param2 = nn.Parameter(
-                    torch.zeros(hidden_dim, dtype=torch.float, device="cuda")
+                    torch.zeros(hidden_dim, dtype=torch.float, device=device_type)
                 )

             def forward(self, x):
@@ -680,7 +685,7 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
         def model_init_fn():
             torch.manual_seed(self.rank)
             fsdp_config = {}
-            mesh = init_device_mesh("cuda", (self.world_size,))
+            mesh = init_device_mesh(device_type.type, (self.world_size,))
             model = TestModule(n_layers=3)
             for mod in model.layers:
                 fully_shard(mod, mesh=mesh, reshard_after_forward=True, **fsdp_config)
@@ -692,7 +697,7 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},

         def input_creation_fn():
             torch.manual_seed(self.rank)
-            inp = torch.randn((2, hidden_dim), device="cuda", requires_grad=False)
+            inp = torch.randn((2, hidden_dim), device=device_type, requires_grad=False)
             return inp

         return model_init_fn, input_creation_fn
@@ -854,7 +859,7 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
         def model_init_fn():
             torch.manual_seed(self.rank)
             fsdp_config = {}
-            mesh = init_device_mesh("cuda", (self.world_size,))
+            mesh = init_device_mesh(device_type.type, (self.world_size,))
             model_args = ModelArgs(
                 vocab_size=vocab_size,
                 n_layers=n_layers,
@@ -883,7 +888,7 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
         def input_creation_fn():
             torch.manual_seed(self.rank)
             inp = torch.randint(
-                0, vocab_size, (2, seq_len), device="cuda", requires_grad=False
+                0, vocab_size, (2, seq_len), device=device_type, requires_grad=False
             )
             return inp

@@ -1092,7 +1097,7 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
             new_child = torch.compile(child)
             setattr(m.encoder, name, new_child)
     m = FSDP(m, sharding_strategy=ShardingStrategy.FULL_SHARD, use_orig_params=True)
-    inp = torch.randn(32, 784, device="cuda")
+    inp = torch.randn(32, 784, device=device_type)
     m(inp)
@@ -5,7 +5,6 @@ import copy
 import functools
 import math
 import threading
-import unittest
 from typing import Any, Optional, Union

 import torch
@@ -15,18 +14,21 @@ import torch.utils._pytree as pytree
 from torch.autograd.grad_mode import _unsafe_preserve_version_counter
 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
 from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy
-from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import (
     check_sharded_parity,
     FSDPTest,
     FSDPTestMultiThread,
+    get_devtype,
     MLP,
 )
 from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.two_tensor import TwoTensor


+device_type = torch.device(get_devtype())
+
+
 def two_tensor_fsdp_pre_all_gather_v1(
     self, mesh: DeviceMesh
 ) -> tuple[tuple[torch.Tensor, ...], Any]:
@@ -222,7 +224,7 @@ class TestFullyShardAllGatherExtensionsMultiProcess(
     def _test_all_gather_extensions_train_parity(self, reshard_after_forward: bool):
         torch.manual_seed(42)
         model = self._init_two_tensor_mlp()
-        ref_model = copy.deepcopy(model).cuda()
+        ref_model = copy.deepcopy(model).to(device_type)
         ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=True)
         fully_shard_fn = functools.partial(
             fully_shard, reshard_after_forward=reshard_after_forward
@@ -234,7 +236,7 @@ class TestFullyShardAllGatherExtensionsMultiProcess(
         check_sharded_parity(self, ref_model, model)

         torch.manual_seed(42 + self.rank + 1)
-        inp = torch.randn((2, 8), device="cuda")
+        inp = torch.randn((2, 8), device=device_type)
         for iter_idx in range(10):
             losses: list[torch.Tensor] = []
             for _model in (ref_model, model):
@@ -261,9 +263,9 @@ class TestFullyShardAllGatherExtensionsMultiThread(

     @property
     def device(self) -> torch.device:
-        return torch.device("cuda:0")
+        return torch.device(device_type)

-    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    @skip_if_lt_x_gpu(1)
     def test_all_gather_extensions_end_to_end(self):
         with self._patch_two_tensor_fsdp_all_gather(pre_all_gather_version=1):
             self.run_subtests(
@@ -297,13 +299,13 @@ class TestFullyShardAllGatherExtensionsMultiThread(

         # Run a few iterations to check for errors
         torch.manual_seed(42 + self.rank + 1)
-        inp = torch.randn((2, 8), device="cuda")
+        inp = torch.randn((2, 8), device=device_type)
         for _ in range(3):
             model(inp).sum().backward()
             optim.step()
             optim.zero_grad()

-    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    @skip_if_lt_x_gpu(1)
     def test_all_gather_extensions_monkey_patch(self):
         tls = threading.local()
         tls.ran_pre_all_gather = False
@@ -368,14 +370,14 @@ class TestFullyShardAllGatherExtensionsMultiThread(

         # Run a few iterations to check for errors
         torch.manual_seed(42 + self.rank + 1)
-        inp = torch.randn((2, 8), device="cuda")
+        inp = torch.randn((2, 8), device=device_type)
         for _ in range(3):
             model(inp).sum().backward()
             optim.step()
             optim.zero_grad()
         assert tls.ran_pre_all_gather

-    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    @skip_if_lt_x_gpu(1)
     def test_all_gather_extension_outer_size_stride(self):
         """
         NOTE: We cannot easily test the incorrect case where the user-defined
@@ -395,19 +397,19 @@ class TestFullyShardAllGatherExtensionsMultiThread(
         fully_shard(model)
         optim = torch.optim.AdamW(model.parameters(), lr=1e-2, fused=True)
         torch.manual_seed(42 + self.rank + 1)
-        inp = torch.randn((2, 3), device="cuda")
+        inp = torch.randn((2, 3), device=device_type)
         loss = model(inp).sum()
         loss.backward()
         optim.step()
         optim.zero_grad()

-    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    @skip_if_lt_x_gpu(1)
     def test_all_gather_extension_hsdp_mesh(self):
         tls = threading.local()
         replicate_size = 2
         shard_size = self.world_size // replicate_size
         mesh = init_device_mesh(
-            "cuda",
+            device_type.type,
             (replicate_size, shard_size),
             mesh_dim_names=("dp_replicate", "dp_shard"),
         )
@@ -456,7 +458,7 @@ class TestFullyShardAllGatherExtensionsMultiThread(
                 local_param
             )

-        inp = torch.randn((2, 8), device="cuda")
+        inp = torch.randn((2, 8), device=device_type)
         model(inp)
         # Check that FSDP passes only the shard mesh to the pre-all-gather
         self.assertEqual(tls.mesh.ndim, 1)
@@ -18,6 +18,7 @@ from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import (
     check_sharded_parity,
     FSDPTest,
+    get_devtype,
     MLP,
     patch_reduce_scatter,
     patch_register_post_backward_hook_backward,
@@ -26,10 +27,13 @@ from torch.testing._internal.common_fsdp import (
 from torch.testing._internal.common_utils import run_tests


+device_type = torch.device(get_devtype())
+
+
 class TestFullyShardFrozen(FSDPTest):
     @property
     def world_size(self) -> int:
-        return min(4, torch.cuda.device_count())
+        return min(4, torch.get_device_module(device_type).device_count())

     @skip_if_lt_x_gpu(2)
     def test_train_mixed_requires_grad_per_group(self):
@@ -66,7 +70,7 @@ class TestFullyShardFrozen(FSDPTest):
             if "bias" not in param_name:
                 param.requires_grad_(False)
         ref_model = replicate(
-            copy.deepcopy(model).cuda(),
+            copy.deepcopy(model).to(device_type),
             device_ids=[self.rank],
             find_unused_parameters=freeze_after_init,
         )
@@ -110,7 +114,7 @@ class TestFullyShardFrozen(FSDPTest):
             return orig_backward(*args, **kwargs)

         torch.manual_seed(42 + self.rank + 1)
-        device = torch.device("cuda")
+        device = device_type
         with patch_reduce_scatter(
             reduce_scatter
         ), patch_register_post_backward_hook_backward(backward_with_count):
@@ -156,7 +160,7 @@ class TestFullyShardFrozen(FSDPTest):
             modules += [nn.Linear(lin_dim, lin_dim), nn.ReLU()]
         model = nn.Sequential(*modules)
         ref_model = replicate(
-            copy.deepcopy(model).cuda(),
+            copy.deepcopy(model).to(device_type),
             device_ids=[self.rank],
             find_unused_parameters=True,
         )
@@ -184,7 +188,7 @@ class TestFullyShardFrozen(FSDPTest):
         _set_requires_grad(ref_model, False)
         num_iters, no_grad_iter_idx = (3, 1)
         torch.manual_seed(42 + self.rank)
-        inp = torch.randn((8, lin_dim), device="cuda")
+        inp = torch.randn((8, lin_dim), device=device_type)
         with patch_register_post_backward_hook_backward(backward_with_count):
             for iter_idx in range(num_iters):
                 losses: list[torch.Tensor] = []
@@ -242,7 +246,9 @@ class TestFullyShardFrozen(FSDPTest):

         torch.manual_seed(42)
         model = MultiForwardModule(torch.device("cpu"))
-        ref_model = replicate(copy.deepcopy(model).cuda(), device_ids=[self.rank])
+        ref_model = replicate(
+            copy.deepcopy(model).to(device_type), device_ids=[self.rank]
+        )
         ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
         for module in model.modules():
             if isinstance(module, nn.Linear):
@@ -250,7 +256,7 @@ class TestFullyShardFrozen(FSDPTest):
         fully_shard(model, reshard_after_forward=reshard_after_forward)
         optim = torch.optim.Adam(model.parameters(), lr=1e-2)
         for iter_idx in range(10):
-            inp = torch.randn((8, 5), device="cuda")
+            inp = torch.randn((8, 5), device=device_type)
             losses: list[torch.Tensor] = []
             for _model, _optim in ((ref_model, ref_optim), (model, optim)):
                 _optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
@ -12,10 +12,13 @@ from torch.distributed.tensor.parallel import (
RowwiseParallel,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest, MLP
from torch.testing._internal.common_fsdp import FSDPTest, get_devtype, MLP
from torch.testing._internal.common_utils import run_tests


device_type = torch.device(get_devtype())


class TestFullyShardGradientScaler(FSDPTest):
@skip_if_lt_x_gpu(4)
def test_gradient_scaler(self):
@ -27,16 +30,16 @@ class TestFullyShardGradientScaler(FSDPTest):
def _test_gradient_scaler(self, has_inf: bool, test_2d: bool):
torch.manual_seed(0)
model = nn.Sequential(
*[nn.Linear(4, 4, device="cuda", bias=False) for _ in range(2)]
*[nn.Linear(4, 4, device=device_type, bias=False) for _ in range(2)]
)
for layer in model:
fully_shard(layer)
fully_shard(model)
input = torch.randn([4, 4], device="cuda")
input = torch.randn([4, 4], device=device_type)

if test_2d:
mesh_2d = init_device_mesh(
"cuda", (2, self.world_size // 2), mesh_dim_names=("dp", "tp")
device_type.type, (2, self.world_size // 2), mesh_dim_names=("dp", "tp")
)
dp_mesh, tp_mesh = mesh_2d["dp"], mesh_2d["tp"]
model = nn.Sequential(MLP(2), MLP(2), MLP(2))
@ -56,10 +59,10 @@ class TestFullyShardGradientScaler(FSDPTest):
for module in model:
fully_shard(module, mesh=dp_mesh)
fully_shard(model, mesh=dp_mesh)
input = torch.randn((2,), device="cuda")
input = torch.randn((2,), device=device_type)

loss = model(input).sum()
scaler = GradScaler(init_scale=2.0, enabled=True)
scaler = GradScaler(init_scale=2.0, enabled=True, device=device_type.type)
opt = torch.optim.Adam(model.parameters(), lr=1e-2)
scaler.scale(loss).backward()
inv_scale = scaler._scale.double().reciprocal().float()

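The hunks above and below all apply one idiom: resolve the accelerator once, then derive devices, device modules, and meshes from it instead of hard-coding "cuda". A minimal self-contained sketch of that idiom, where this `get_devtype` is a stand-in for the helper of the same name imported from `torch.testing._internal.common_fsdp`, and the CUDA-or-CPU fallback is an assumption for illustration only:

    import torch

    def get_devtype() -> str:
        # Stand-in helper: prefer CUDA if present, else fall back to CPU.
        return "cuda" if torch.cuda.is_available() else "cpu"

    device_type = torch.device(get_devtype())
    device_module = torch.get_device_module(device_type)  # e.g. torch.cuda

    model = torch.nn.Linear(4, 4).to(device_type)   # replaces .cuda()
    inp = torch.randn(4, 4, device=device_type)     # replaces device="cuda"
    n_devices = device_module.device_count()        # replaces torch.cuda.device_count()
    scaler = torch.amp.GradScaler(device=device_type.type, init_scale=2.0)

Mesh constructors such as `init_device_mesh` and `GradScaler` take the device type as a string, hence the `device_type.type` spelling in those call sites.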
@ -12,7 +12,7 @@ from torch.distributed.tensor import DTensor
from torch.distributed.tensor.experimental import implicit_replication
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_fsdp import FSDPTest, get_devtype
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
run_tests,
@ -20,6 +20,8 @@ from torch.testing._internal.common_utils import (
)


device_type = torch.device(get_devtype())

if not dist.is_available():
print("Distributed not available, skipping tests", file=sys.stderr)
sys.exit(0)
@ -72,14 +74,14 @@ class A(nn.Module):
class Y(nn.Module):
def __init__(self) -> None:
super().__init__()
p = torch.randn(10, device="cuda")
p = torch.randn(10, device=device_type)
self.p = nn.Parameter(p)


class X(nn.Module):
def __init__(self) -> None:
super().__init__()
q = torch.randn(10, device="cuda")
q = torch.randn(10, device=device_type)
self.q = nn.Parameter(q)
self.y = Y()

@ -95,15 +97,15 @@ def _generate_model_and_input() -> nn.Module:
dim = 8

torch.manual_seed(42)
addend = torch.randn((dim, dim), device="cuda")
addend = torch.randn((dim, dim), device=device_type)

torch.manual_seed(70)
subend = torch.randn((dim, dim), device="cuda")
subend = torch.randn((dim, dim), device=device_type)

model = A(dim, addend, subend).cuda()
model = A(dim, addend, subend).to(device_type)

torch.manual_seed(84)
inp = torch.randn((dim, dim), device="cuda")
inp = torch.randn((dim, dim), device=device_type)

return model, inp

@ -229,7 +231,7 @@ class TestFullyShardIgnoreParams(FSDPTest):
@skip_if_lt_x_gpu(2)
def test_ddp_A_fsdp_B_ddp_C(self):
default_pg = dist.distributed_c10d._get_default_group()
mesh = init_device_mesh("cuda", mesh_shape=(default_pg.size(),))
mesh = init_device_mesh(device_type.type, mesh_shape=(default_pg.size(),))

ref_model, ref_inp = _generate_model_and_input()


@ -2,7 +2,6 @@

import copy
import itertools
import unittest
from typing import cast, Optional

import torch
@ -37,8 +36,8 @@ from torch.distributed.tensor.parallel import (
RowwiseParallel,
)
from torch.distributed.tensor.placement_types import _StridedShard
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_fsdp import FSDPTestMultiThread, MLP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTestMultiThread, get_devtype, MLP
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
ModelArgs,
@ -47,6 +46,9 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
)


device_type = torch.device(get_devtype())


class TestFullyShardDeviceTensor(FSDPTestMultiThread):
"""Tests that tensor parameters are moved to the expected device."""

@ -54,17 +56,19 @@ class TestFullyShardDeviceTensor(FSDPTestMultiThread):
def world_size(self) -> int:
return 1

@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_move_states_to_device_tensor(self):
model = MLP(8, torch.device("cpu"), with_buffer=True)
for tensor in itertools.chain(model.parameters(), model.buffers()):
self.assertEqual(tensor.device, torch.device("cpu"))
fully_shard(model)
cuda_device = torch.device("cuda", torch.cuda.current_device())
accelerator_device = torch.device(
device_type.type, torch.get_device_module(device_type).current_device()
)
for tensor in itertools.chain(model.parameters(), model.buffers()):
self.assertEqual(tensor.device, cuda_device)
self.assertEqual(tensor.device, accelerator_device)

@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_move_states_to_device_ignored_param_device(self):
cpu_device = torch.device("cpu")
model = MLP(8, cpu_device, with_buffer=True)
@ -72,10 +76,12 @@ class TestFullyShardDeviceTensor(FSDPTestMultiThread):
fully_shard(model, ignored_params=set(ignored_params))
for tensor in ignored_params:
self.assertEqual(tensor.device, cpu_device)
cuda_device = torch.device("cuda", torch.cuda.current_device())
model.to(torch.device("cuda"))
accelerator_device = torch.device(
device_type.type, torch.get_device_module(device_type).current_device()
)
model.to(device_type)
for tensor in ignored_params:
self.assertEqual(tensor.device, cuda_device)
self.assertEqual(tensor.device, accelerator_device)


class TestFullyShardDeviceDTensor(FSDPTestMultiThread):
@ -85,12 +91,14 @@ class TestFullyShardDeviceDTensor(FSDPTestMultiThread):
def world_size(self) -> int:
return 4

@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_move_states_to_device_dtensor_valid(self):
assert self.world_size >= 4, f"{self.world_size}"
dp_size = 2
global_mesh = init_device_mesh(
"cuda", (dp_size, self.world_size // dp_size), mesh_dim_names=("dp", "tp")
device_type.type,
(dp_size, self.world_size // dp_size),
mesh_dim_names=("dp", "tp"),
)
dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]
model = MLP(8, torch.device("cpu"), with_buffer=True)
@ -99,31 +107,35 @@ class TestFullyShardDeviceDTensor(FSDPTestMultiThread):
tp_mesh,
{"in_proj": ColwiseParallel(), "out_proj": RowwiseParallel()},
)
cuda_device = torch.device("cuda", torch.cuda.current_device())
accelerator_device = torch.device(
device_type.type, torch.get_device_module(device_type).current_device()
)
for tensor in itertools.chain(model.parameters(), model.buffers()):
if isinstance(tensor, DTensor):
# DTensor constructor moves to the mesh's device
self.assertEqual(tensor.device, cuda_device)
self.assertEqual(tensor._local_tensor.device, cuda_device)
self.assertEqual(tensor.device, accelerator_device)
self.assertEqual(tensor._local_tensor.device, accelerator_device)
else:
self.assertEqual(tensor.device, torch.device("cpu"))
fully_shard(model, mesh=dp_mesh)
for tensor in itertools.chain(model.parameters(), model.buffers()):
self.assertEqual(tensor.device, cuda_device)
self.assertEqual(tensor.device, accelerator_device)
if isinstance(tensor, DTensor):
self.assertEqual(tensor._local_tensor.device, cuda_device)
self.assertEqual(tensor._local_tensor.device, accelerator_device)

@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_move_states_to_device_dtensor_invalid(self):
assert self.world_size >= 4, f"{self.world_size}"
dp_size = 2
global_cuda_mesh = init_device_mesh(
"cuda", (dp_size, self.world_size // dp_size), mesh_dim_names=("dp", "tp")
global_accelerator_mesh = init_device_mesh(
device_type.type,
(dp_size, self.world_size // dp_size),
mesh_dim_names=("dp", "tp"),
)
global_cpu_mesh = init_device_mesh(
"cpu", (dp_size, self.world_size // dp_size), mesh_dim_names=("dp", "tp")
)
dp_mesh = global_cuda_mesh["dp"]
dp_mesh = global_accelerator_mesh["dp"]
tp_mesh = global_cpu_mesh["tp"]  # mismatched meshes!
model = MLP(8, torch.device("cpu"), with_buffer=True)
parallelize_module(
@ -135,7 +147,10 @@ class TestFullyShardDeviceDTensor(FSDPTestMultiThread):
self.assertEqual(tensor.device, torch.device("cpu"))
if isinstance(tensor, DTensor):
self.assertEqual(tensor._local_tensor.device, torch.device("cpu"))
regex = r"Requires DTensor to have mesh of the same type as the FSDP mesh but got cpu for DTensor and cuda for FSDP"
regex = (
rf"Requires DTensor to have mesh of the same type as the FSDP mesh but got "
rf"cpu for DTensor and {device_type.type} for FSDP"
)
with self.assertRaisesRegex(ValueError, regex):
fully_shard(model, mesh=dp_mesh)

@ -147,17 +162,17 @@ class TestFullyShardMeshArg(FSDPTestMultiThread):
def world_size(self) -> int:
return 4

@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_invalid_mesh_ndim(self):
mesh = init_device_mesh("cuda", (self.world_size, 1, 1))
mesh = init_device_mesh(device_type.type, (self.world_size, 1, 1))
model = MLP(8)
regex = r"fully\_shard expects a 1D or 2D DeviceMesh but got DeviceMesh"
with self.assertRaisesRegex(ValueError, regex):
fully_shard(model, mesh=mesh)

@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_2d_mesh_without_mesh_dim_names(self):
mesh = init_device_mesh("cuda", (self.world_size // 2, 2))
mesh = init_device_mesh(device_type.type, (self.world_size // 2, 2))
model = MLP(8)
regex = "Please init the 2D mesh for HSDP with mesh_dim_names specified"
with self.assertRaisesRegex(AssertionError, regex):
@ -171,7 +186,7 @@ class TestFullyShardManagedModulesAndStates(FSDPTestMultiThread):
def world_size(self) -> int:
return 1

@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_managed_modules_single(self):
model = MLP(8)
# Assume calling `fully_shard` on `model`
@ -179,7 +194,7 @@ class TestFullyShardManagedModulesAndStates(FSDPTestMultiThread):
expected_managed_modules = list(model.modules())
self._check_managed_modules(managed_modules, expected_managed_modules)

@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_managed_modules_nested(self):
model = nn.Sequential(*[MLP(8) for _ in range(2)])
fully_shard(model[0])
@ -188,7 +203,7 @@ class TestFullyShardManagedModulesAndStates(FSDPTestMultiThread):
expected_managed_modules = list(model[1].modules()) + [model]
self._check_managed_modules(managed_modules, expected_managed_modules)

@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_managed_modules_nested_fully_shard_and_replicate(self):
model = nn.Sequential(*[MLP(8) for _ in range(3)])
replicate(model[0])
@ -198,7 +213,7 @@ class TestFullyShardManagedModulesAndStates(FSDPTestMultiThread):
expected_managed_modules = list(model[1].modules()) + [model]
self._check_managed_modules(managed_modules, expected_managed_modules)

@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_managed_modules_duplicate(self):
mlp = MLP(8)
model = nn.Sequential(mlp, mlp)  # duplicate MLP
@ -208,7 +223,7 @@ class TestFullyShardManagedModulesAndStates(FSDPTestMultiThread):
expected_managed_modules = list(mlp.modules()) + [model]
self._check_managed_modules(managed_modules, expected_managed_modules)

@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_managed_modules_list_of_mlps(self):
model = nn.Sequential(*[MLP(8) for _ in range(5)])
# Assume calling `fully_shard` on `[model[0], model[1], model[2]]`
@ -232,7 +247,7 @@ class TestFullyShardManagedModulesAndStates(FSDPTestMultiThread):
# Check set comparison since we do not require anything about the order
self.assertEqual(set(managed_modules), set(expected_managed_modules))

@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_managed_states_shared_params_and_buffers(self):
model = nn.Sequential(*[MLP(8, with_buffer=True) for _ in range(3)])
model[0].in_proj.weight = model[1].in_proj.weight
@ -245,7 +260,7 @@ class TestFullyShardManagedModulesAndStates(FSDPTestMultiThread):
expected_buffers = list(model.buffers())  # de-dups shared
self._check_managed_states(params, buffers, expected_params, expected_buffers)

@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_managed_states_nested_fully_shard(self):
model = nn.Sequential(*[MLP(8, with_buffer=True) for _ in range(2)])
fully_shard(model[0])
@ -256,7 +271,7 @@ class TestFullyShardManagedModulesAndStates(FSDPTestMultiThread):
expected_buffers = list(model[1].buffers())
self._check_managed_states(params, buffers, expected_params, expected_buffers)

@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_managed_states_list_of_mlps(self):
model = nn.Sequential(*[MLP(8, with_buffer=True) for _ in range(5)])
# Assume calling `fully_shard` on `[model[0], model[1], model[2]]`
@ -292,7 +307,7 @@ class TestFullyShardParamModuleInfos(FSDPTestMultiThread):
def world_size(self) -> int:
return 2

@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_get_param_module_infos_shared_params(self):
model = nn.Sequential(*[MLP(8) for _ in range(2)])
model[0].in_proj.weight = model[1].in_proj.weight
@ -313,7 +328,7 @@ class TestFullyShardParamModuleInfos(FSDPTestMultiThread):
self.assertEqual(len(param_module_infos), len(expected_param_module_infos))
self.assertEqual(param_module_infos, expected_param_module_infos)

@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_get_param_module_infos_duplicates(self):
mlp = MLP(8)
model = nn.Sequential(mlp, mlp)  # shared MLP
@ -341,7 +356,7 @@ class TestFullyShardParamModuleInfos(FSDPTestMultiThread):
ParamModuleInfo(mlp.out_proj, "bias", [], []),
]

@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_get_param_module_infos_list_of_mlps(self):
model = nn.Sequential(*[MLP(8) for _ in range(2)])
managed_modules = _get_managed_modules((model[0], model[1]))
@ -367,7 +382,7 @@ class TestFullyShardShardedParameterTensor(FSDPTestMultiThread):
def world_size(self) -> int:
return 2

@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_shard_tensor_parameters(self):
# Use odd dim sizes to test uneven shards
model = nn.Sequential(*[MLP(3, dim_multiplier=3) for _ in range(3)])
@ -387,7 +402,7 @@ class TestFullyShardShardedParameterTensor(FSDPTestMultiThread):
self, orig_params: list[nn.Parameter], sharded_params: list[nn.Parameter]
):
self.assertEqual(len(orig_params), len(sharded_params))
global_mesh = init_device_mesh("cuda", (self.world_size,))
global_mesh = init_device_mesh(device_type.type, (self.world_size,))
for orig_param, sharded_param in zip(orig_params, sharded_params):
self.assertIsInstance(sharded_param, DTensor)
self.assertEqual(sharded_param.device_mesh, global_mesh)
@ -397,17 +412,19 @@ class TestFullyShardShardedParameterTensor(FSDPTestMultiThread):
chunks = torch.chunk(orig_param, self.world_size, dim=0)
self.assertEqual(sharded_param._local_tensor, chunks[self.rank])

@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_raise_scalar_parameter(self):
"""Tests raising an exception when the model has scalar parameters."""
model = nn.Sequential(*[MLP(3, dim_multiplier=3) for _ in range(3)])
model.register_parameter("scalar_p", nn.Parameter(torch.tensor(1.0).cuda()))
model.register_parameter(
"scalar_p", nn.Parameter(torch.tensor(1.0).to(device_type))
)
with self.assertRaisesRegex(
ValueError, "Change scalar_p to a 1D tensor with numel equal to 1."
):
fully_shard(model)

@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_raise_noncontiguous_parameter(self):
"""
Tests raising an exception when the model has non-contiguous
@ -425,11 +442,13 @@ class TestFullyShardShardedParameterDTensor(FSDPTestMultiThread):
def world_size(self) -> int:
return 4

@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_shard_dtensor_parameters(self):
dp_size = 2 if self.world_size > 2 else 1
global_mesh = init_device_mesh(
"cuda", (dp_size, self.world_size // dp_size), mesh_dim_names=("dp", "tp")
device_type.type,
(dp_size, self.world_size // dp_size),
mesh_dim_names=("dp", "tp"),
)
dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]
# Use odd dim sizes to test uneven shards
@ -468,7 +487,7 @@ class TestFullyShardLazyInit(FSDPTestMultiThread):
def world_size(self) -> int:
return 2

@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_fully_shard_is_root(self):
"""
Tests that ``_is_root`` is set correctly after lazy initialization.
@ -497,7 +516,7 @@ class TestFullyShardLazyInit(FSDPTestMultiThread):
all_states, [root_state, model0_in_proj_state, model0_out_proj_state]
)

@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_fully_shard_module_and_param_fqns(self):
"""
Tests that the module and parameter FQNs are computed correctly after
@ -555,7 +574,7 @@ class TestFullyShardLazyInit(FSDPTestMultiThread):
model0_out_proj_param_fqns, {"0.out_proj.weight", "0.out_proj.bias"}
)

@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_fully_shard_double_lazy_init(self):
model = nn.Sequential(MLP(8), MLP(8))
fully_shard(model[0].in_proj)
@ -571,7 +590,7 @@ class TestFullyShardLazyInit(FSDPTestMultiThread):
with self.assertRaisesRegex(RuntimeError, regex):
root_state._lazy_init()

@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_fully_shard_multi_module_root(self):
model = nn.Sequential(MLP(8), MLP(8))
fully_shard([model[0], model[1]])
@ -580,7 +599,7 @@ class TestFullyShardLazyInit(FSDPTestMultiThread):
with self.assertRaisesRegex(RuntimeError, regex):
root_state._lazy_init()

@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_reset_sharded_param_in_lazy_init(self):
class MyModel(nn.Module):
def __init__(self):
@ -607,11 +626,11 @@ class TestFullyShardLazyInit(FSDPTestMultiThread):
fully_shard(model.layer2)
fully_shard(model)

model.layer1.to_empty(device="cuda")
model.layer2.to_empty(device="cuda")
model.layer1.to_empty(device=device_type.type)
model.layer2.to_empty(device=device_type.type)
model.init_weight_norm()

inp = torch.randn(3, 3, device="cuda")
inp = torch.randn(3, 3, device=device_type.type)
loss = model(inp).sum()
loss.backward()

@ -621,10 +640,10 @@ class TestFullyShardMetaDeviceInit(FSDPTestMultiThread):
def world_size(self) -> int:
return 4

@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_meta_device_1d_init(self):
default_pg = torch.distributed.distributed_c10d._get_default_group()
mesh = init_device_mesh("cuda", mesh_shape=(default_pg.size(),))
mesh = init_device_mesh(device_type.type, mesh_shape=(default_pg.size(),))

# Test both even sharding (8) and uneven sharding (3)
for mlp_dim in (8, 3):
@ -652,12 +671,14 @@ class TestFullyShardMetaDeviceInit(FSDPTestMultiThread):
self.assertEqual(param.device, torch.device("meta"))
self._test_to_empty_and_reset_parameters(model, mesh, mlp_dim)

@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_meta_device_2d_init(self):
assert self.world_size >= 4, f"{self.world_size}"
dp_size = 2
global_mesh = init_device_mesh(
"cuda", (dp_size, self.world_size // dp_size), mesh_dim_names=("dp", "tp")
device_type.type,
(dp_size, self.world_size // dp_size),
mesh_dim_names=("dp", "tp"),
)
dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]

@ -685,7 +706,9 @@ class TestFullyShardMetaDeviceInit(FSDPTestMultiThread):
self, model: nn.Module, mesh: DeviceMesh, mlp_dim: int
):
# Check that we can materialize it on GPU with empty values
device = torch.device("cuda", torch.cuda.current_device())
device = torch.device(
device_type.type, torch.get_device_module(device_type).current_device()
)
model.to_empty(device=device)
for param in model.parameters():
self.assertEqual(param.device, device)
@ -706,14 +729,14 @@ class TestFullyShardMetaDeviceInit(FSDPTestMultiThread):
self.assertNotEqual(buffer, torch.ones_like(buffer) * const)

# Check that we can run an iteration without erroring
inp = torch.randn((4, mlp_dim), device="cuda")
inp = torch.randn((4, mlp_dim), device=device_type.type)
model(inp).sum().backward()
optim.step()

@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_invalid_meta_device_init(self):
default_pg = torch.distributed.distributed_c10d._get_default_group()
mesh = init_device_mesh("cuda", mesh_shape=(default_pg.size(),))
mesh = init_device_mesh(device_type.type, mesh_shape=(default_pg.size(),))
mlp_dim = 8
with torch.device("meta"):
model = nn.Sequential(MLP(mlp_dim, with_buffer=True), MLP(mlp_dim))
@ -722,7 +745,7 @@ class TestFullyShardMetaDeviceInit(FSDPTestMultiThread):
fully_shard(model[0], mesh=mesh)
fully_shard(model[1], mesh=mesh)
fully_shard(model, mesh=mesh)
inp = torch.randn((4, mlp_dim), device="cuda")
inp = torch.randn((4, mlp_dim), device=device_type.type)
error_regex = (
"FSDP parameters should be materialized from meta device before training, "
"but the following were still on meta device: "
@ -731,7 +754,7 @@ class TestFullyShardMetaDeviceInit(FSDPTestMultiThread):
with self.assertRaisesRegex(RuntimeError, error_regex):
model(inp)

@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_rank0_broadcast_meta_device_init(self):
model_args = ModelArgs(dropout_p=0.0)
# Assume we have a CPU full state dict on rank 0
@ -743,7 +766,7 @@ class TestFullyShardMetaDeviceInit(FSDPTestMultiThread):
self.assertEqual(param.device, torch.device("cpu"))

# Initialize the sharded model on meta device
fsdp_mesh = init_device_mesh("cuda", (self.world_size,))
fsdp_mesh = init_device_mesh(device_type.type, (self.world_size,))
with torch.device("meta"):
model = Transformer(model_args)
for module in model.modules():
@ -763,7 +786,7 @@ class TestFullyShardMetaDeviceInit(FSDPTestMultiThread):
for (param_name, full_param), sharded_meta_param in zip(
full_sd.items(), meta_sharded_sd.values()
):
full_param = full_param.detach().cuda()
full_param = full_param.detach().to(device_type)
mesh = sharded_meta_param.device_mesh
dist.broadcast(full_param, src=0, group=mesh.get_group(0))
sharded_tensor = distribute_tensor(
@ -774,7 +797,7 @@ class TestFullyShardMetaDeviceInit(FSDPTestMultiThread):
for param_name, sharded_meta_param in meta_sharded_sd.items():
full_tensor = torch.empty(
sharded_meta_param.size(),
device="cuda",
device=device_type.type,
dtype=sharded_meta_param.dtype,
)
mesh = sharded_meta_param.device_mesh
@ -787,7 +810,7 @@ class TestFullyShardMetaDeviceInit(FSDPTestMultiThread):
model.load_state_dict(sharded_sd, assign=True)
for param in model.parameters():
self.assertIsInstance(param, DTensor)
self.assertEqual(param.device.type, "cuda")
self.assertEqual(param.device.type, device_type.type)

# Construct the reference model on nonzero ranks by broadcasting the
# unsharded model from rank 0 and sharding on all ranks
@ -807,7 +830,7 @@ class TestFullyShardMetaDeviceInit(FSDPTestMultiThread):
self.assertEqual(param, ref_param)

# Check one forward/backward for parity
inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda")
inp = torch.randint(0, model_args.vocab_size, (2, 16), device=device_type.type)
loss = model(inp).sum()
loss.backward()
ref_loss = ref_model(inp).sum()
@ -822,20 +845,22 @@ class TestFullyShardProcessGroupInit(FSDPTestMultiThread):
def world_size(self) -> int:
return 4

@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_1d_process_group_init(self):
assert self.world_size == 4, f"{self.world_size}"
# For convenience, use device mesh's infra to construct the DP PG
# (in practice, the trainer would do it manually via `new_group()`)
dp_size = 2
global_mesh = init_device_mesh(
"cuda", (dp_size, self.world_size // dp_size), mesh_dim_names=("dp", "tp")
device_type.type,
(dp_size, self.world_size // dp_size),
mesh_dim_names=("dp", "tp"),
)
ref_dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]
dp_pg = ref_dp_mesh.get_group(0)

# Check the `from_group()` API for correctness
dp_mesh = DeviceMesh.from_group(dp_pg, "cuda", mesh_dim_names=("dp",))
dp_mesh = DeviceMesh.from_group(dp_pg, device_type.type, mesh_dim_names=("dp",))
# Only compare the mesh tensors, not `DeviceMesh` objects themselves,
# since the ref has a parent mesh, while the `from_group` one does not
self.assertEqual(dp_mesh.mesh, ref_dp_mesh.mesh)
@ -860,7 +885,7 @@ class TestFullyShardProcessGroupInit(FSDPTestMultiThread):
fully_shard(module, mesh=dp_mesh)

# Ensure that TP ranks have the same input
inp = torch.randn((4, mlp_dim), device="cuda")
inp = torch.randn((4, mlp_dim), device=device_type.type)
if self.rank in (0, 1):
dist.broadcast(inp, src=0, group=tp_mesh.get_group(0))
elif self.rank in (2, 3):
@ -882,7 +907,7 @@ class TestFullyShardProcessGroupInit(FSDPTestMultiThread):
param.grad.device_mesh.mesh, ref_param.grad.device_mesh.mesh
)

@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_2d_process_group_init(self):
shard_mesh_dim_size = 2
assert (
@ -891,7 +916,7 @@ class TestFullyShardProcessGroupInit(FSDPTestMultiThread):
replicate_mesh_dim_size = self.world_size // shard_mesh_dim_size
mesh_dim_names = ("replicate", "shard")
ref_mesh = init_device_mesh(
"cuda",
device_type.type,
(replicate_mesh_dim_size, shard_mesh_dim_size),
mesh_dim_names=mesh_dim_names,
)
@ -910,7 +935,7 @@ class TestFullyShardProcessGroupInit(FSDPTestMultiThread):
# Check the `from_group()` API for correctness
mesh = DeviceMesh.from_group(
[dp_replicate_group, dp_shard_group],
"cuda",
device_type.type,
mesh_dim_names=mesh_dim_names,
mesh=mesh_tensor,
)
@ -943,7 +968,7 @@ class TestFullyShardProcessGroupInit(FSDPTestMultiThread):
for module in (model.in_proj, model.out_proj, model):
fully_shard(module, mesh=mesh)

inp = torch.randn((4, mlp_dim), device="cuda")
inp = torch.randn((4, mlp_dim), device=device_type.type)
ref_loss = ref_model(inp).sum()
ref_loss.backward()
loss = model(inp).sum()
@ -959,11 +984,13 @@ class TestFullyShardHSDPBroadcast(FSDPTestMultiThread):
def world_size(self) -> int:
return 4

@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_hsdp_broadcast_across_replicas(self):
shard_size, replicate_size = 2, 2
mesh = init_device_mesh(
"cuda", (replicate_size, shard_size), mesh_dim_names=("replicate", "shard")
device_type.type,
(replicate_size, shard_size),
mesh_dim_names=("replicate", "shard"),
)
model_args = ModelArgs()
model = Transformer(model_args)
@ -1017,7 +1044,7 @@ class TestFullyShardHSDPBroadcast(FSDPTestMultiThread):
self.assertEqual(other_local_tensor, local_tensor_list[0])

# Check that we can run an iteration without erroring
inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda")
inp = torch.randint(0, model_args.vocab_size, (2, 16), device=device_type.type)
model(inp).sum().backward()


@ -1028,17 +1055,17 @@ class TestHSDPWithCustomHook(FSDPTestMultiThread):

def perThreadSetUp(self) -> None:
super().perThreadSetUp()
torch.set_default_device("cuda")
torch.set_default_device(device_type)

@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_custom_hook_custom_stream(self):
hsdp_mesh = init_device_mesh(
"cuda", (2, 2), mesh_dim_names=("replicate", "shard")
device_type.type, (2, 2), mesh_dim_names=("replicate", "shard")
)
model = MLP(10, bias=False)
fully_shard(model, mesh=hsdp_mesh)
model = cast(FSDPModule, model)
custom_stream = torch.cuda.Stream()
custom_stream = torch.get_device_module(device_type).Stream()

# native HSDP should reject
with self.assertRaises(ValueError) as cm:
@ -1051,7 +1078,7 @@ class TestHSDPWithCustomHook(FSDPTestMultiThread):
intra_pg = _init_intra_node_process_group(2)
fsdp_mesh = DeviceMesh.from_group(
intra_pg,
"cuda",
device_type.type,
dist.get_process_group_ranks(intra_pg),
mesh_dim_names=("shard",),
)
@ -1059,7 +1086,7 @@ class TestHSDPWithCustomHook(FSDPTestMultiThread):

def _hook(_output: torch.Tensor) -> None:
nonlocal hook_used_stream
hook_used_stream = torch.cuda.current_stream()
hook_used_stream = torch.get_device_module(device_type).current_stream()

model = MLP(10, bias=False)
fully_shard(model, mesh=fsdp_mesh)
@ -1069,17 +1096,17 @@ class TestHSDPWithCustomHook(FSDPTestMultiThread):
inp = torch.arange(10, dtype=torch.float32, requires_grad=True).view(1, 10)
out = model(inp)
out.sum().backward()
torch.cuda.synchronize()
torch.get_device_module(device_type).synchronize()
self.assertEqual(hook_used_stream, custom_stream)

@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_custom_hsdp_all_reduce_hook(self):
world_pg = dist.distributed_c10d._get_default_group()
intra_pg = _init_intra_node_process_group(2)
inter_pg = _init_inter_node_process_group(world_pg, 2)
mesh = DeviceMesh.from_group(
intra_pg,
"cuda",
device_type.type,
dist.get_process_group_ranks(intra_pg),
mesh_dim_names=("shard",),
)
@ -1106,7 +1133,7 @@ class TestHSDPWithCustomHook(FSDPTestMultiThread):
inp = torch.arange(10, dtype=torch.float32, requires_grad=True).view(1, 10)
out = model(inp)
out.sum().backward()
torch.cuda.synchronize()
torch.get_device_module(device_type).synchronize()
# custom hook was fired
self.assertTrue(hook_called)
# within each replica, FSDP shards the weights at dim 0
@ -1140,7 +1167,7 @@ class TestFullyShardShardPlacementFn(FSDPTestMultiThread):
ref_model = copy.deepcopy(model)
return model, ref_model

@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_init_1d_transformer_shard_largest_dim(self):
model, ref_model = self._init_models()

@ -1168,7 +1195,7 @@ class TestFullyShardShardPlacementFn(FSDPTestMultiThread):
full_param = param.full_tensor()
self.assertEqual(full_param, ref_param)

@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_init_1d_transformer_shard_dim_neg1(self):
model, ref_model = self._init_models()

@ -1184,13 +1211,13 @@ class TestFullyShardShardPlacementFn(FSDPTestMultiThread):
full_param = param.full_tensor()
self.assertEqual(full_param, ref_param)

@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_init_2d_transformer_shard_diff_dim(self):
model, ref_model = self._init_models()

dp_size, tp_size = self.world_size // 2, 2
global_mesh = init_device_mesh(
"cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp")
device_type.type, (dp_size, tp_size), mesh_dim_names=("dp", "tp")
)
model = Transformer.parallelize(model, global_mesh["tp"], use_seq_parallel=True)

@ -1234,7 +1261,7 @@ class TestFullyShardShardPlacementFn(FSDPTestMultiThread):
full_param = param.full_tensor()
self.assertEqual(full_param, ref_param)

@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_init_1d_uneven_shard_largest_dim(self):
torch.manual_seed(42)
model = nn.Sequential(nn.Linear(16, 17), nn.Linear(17, 8))
@ -1255,7 +1282,7 @@ class TestFullyShardShardPlacementFn(FSDPTestMultiThread):
):
fully_shard(model, shard_placement_fn=shard_placement_fn)

@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_invalid_shard_dim(self):
model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 8))

@ -1276,7 +1303,7 @@ class TestFullyShardOldImport(FSDPTestMultiThread):
def world_size(self) -> int:
return 2

@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_old_import_training(self):
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed._composable.fsdp.fully_shard import FSDPModule
@ -1291,7 +1318,7 @@ class TestFullyShardOldImport(FSDPTestMultiThread):
self.assertIsInstance(model[1], FSDPModule)
self.assertIsInstance(model, FSDPModule)

inp = torch.randn((8, 16), device="cuda")
inp = torch.randn((8, 16), device=device_type)
model(inp).sum().backward()



@ -15,6 +15,12 @@ requires_distributed = functools.partial(
unittest.skipIf, not dist.is_available(), "requires distributed"
)

import torch
from torch.testing._internal.common_fsdp import get_devtype


device_type = torch.device(get_devtype())


@skip_if_lt_x_gpu(2)
class LoggingTests(LoggingTestCase):
@ -27,7 +33,7 @@ class LoggingTests(LoggingTestCase):
env["MASTER_PORT"] = "34715"
env["MASTER_ADDR"] = "localhost"
_, stderr = self.run_process_no_exception(
"""\
f"""\
import logging
import torch
import torch.distributed as dist
@ -35,7 +41,7 @@ import torch.nn as nn
from torch.distributed.fsdp import fully_shard
logger = logging.getLogger("torch.distributed._composable.fsdp")
logger.setLevel(logging.DEBUG)
device = "cuda"
device = {device_type.type}
torch.manual_seed(0)
model = nn.Sequential(*[nn.Linear(4, 4, device=device, bias=False) for _ in range(2)])
for layer in model:

@ -2,12 +2,13 @@

import functools
import gc
import unittest

import torch
from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, OffloadPolicy
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.common_fsdp import FSDPTest, get_devtype
from torch.testing._internal.common_utils import run_tests, TEST_CUDA, TEST_HPU
from torch.testing._internal.distributed._tensor.common_dtensor import (
ModelArgs,
Transformer,
@ -15,12 +16,16 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
)


device_type = torch.device(get_devtype())


class TestFullyShardMemory(FSDPTest):
@property
def world_size(self) -> int:
return min(2, torch.cuda.device_count())
return min(2, torch.get_device_module(device_type).device_count())

@skip_if_lt_x_gpu(2)
@unittest.skipIf(TEST_HPU, " 'empty_cache' is not supported on hpu")
def test_fully_shard_training_memory(self):
self.run_subtests(
{
@ -56,10 +61,10 @@ class TestFullyShardMemory(FSDPTest):
# Pre-run a linear forward (gemm and bias) and backward (gemm) to
# allocate the cuBLAS workspaces before measuring the memory usage
# since the workspace size can differ between hardwares
lin = torch.nn.Linear(768, 768, device="cuda")
inp = torch.randn(1, 768, device="cuda")
lin = torch.nn.Linear(768, 768, device=device_type)
inp = torch.randn(1, 768, device=device_type)
lin(inp).sum().backward()
torch.cuda.empty_cache()
torch.get_device_module(device_type).empty_cache()
base_mem_mb = self._get_peak_active_memory_mb()
vocab_size = 32
model_args = ModelArgs(
@ -108,7 +113,7 @@ class TestFullyShardMemory(FSDPTest):
self.assertLessEqual(curr_mem_mb - base_mem_mb, init_mem_mb)

# Use a small input to minimize activation memory usage
inp = torch.randint(0, vocab_size, (1, 4), device="cuda")
inp = torch.randint(0, vocab_size, (1, 4), device=device_type.type)

# Forward:
loss = model(inp)
@ -169,7 +174,7 @@ class TestFullyShardMemory(FSDPTest):
) * 4 / 1e6 + buffer_mb
self.assertLessEqual(mem_mb - base_mem_mb, expected_mem_mb)
del loss
torch.cuda.reset_peak_memory_stats()
torch.get_device_module(device_type).reset_peak_memory_stats()

# Optimizer step: unsharded parameters/gradients freed
if not run_optim_in_backward:
@ -187,7 +192,9 @@ class TestFullyShardMemory(FSDPTest):
# Zero grad: sharded gradients freed
if not run_optim_in_backward:
optim.zero_grad()
torch.cuda.reset_peak_memory_stats()  # reset after freeing
torch.get_device_module(
device_type
).reset_peak_memory_stats()  # reset after freeing
mem_mb = self._get_peak_active_memory_mb()
expected_mem_mb = 0
if not use_cpu_offload:
@ -228,12 +235,18 @@ class TestFullyShardMemory(FSDPTest):
self.assertEqual(mem_mb, base_mem_mb)

def _get_peak_active_memory_mb(self) -> int:
mem_stats = torch.cuda.memory_stats()
return round(mem_stats["active_bytes.all.peak"] / 1e6)
mem_stats = torch.get_device_module(device_type).memory_stats()
if TEST_CUDA:
return round(mem_stats["active_bytes.all.peak"] / 1e6)
if TEST_HPU:
return round(mem_stats["MaxInUse"] / 1e6)

def _get_curr_active_memory_mb(self) -> int:
mem_stats = torch.cuda.memory_stats()
return round(mem_stats["active_bytes.all.current"] / 1e6)
mem_stats = torch.get_device_module(device_type).memory_stats()
if TEST_CUDA:
return round(mem_stats["active_bytes.all.current"] / 1e6)
if TEST_HPU:
return round(mem_stats["InUse"] / 1e6)

def _register_optim_in_backward(
self, model: torch.nn.Module, **optim_kwargs

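The memory helpers rewritten above must branch on the backend because allocator statistics are keyed differently per device; the key names below come straight from the hunk ("active_bytes.all.peak" on CUDA, "MaxInUse" on HPU). A condensed sketch of the same dispatch, assuming only these two backends:

    import torch

    def peak_active_mb(device_type: torch.device, is_cuda: bool) -> int:
        # Allocator stats dict; keys differ between torch.cuda and torch.hpu.
        mem_stats = torch.get_device_module(device_type).memory_stats()
        key = "active_bytes.all.peak" if is_cuda else "MaxInUse"
        return round(mem_stats[key] / 1e6)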
@ -22,17 +22,21 @@ from torch.testing._internal.common_fsdp import (
check_sharded_parity,
FSDPTest,
FSDPTestMultiThread,
get_devtype,
MLP,
patch_reduce_scatter,
reduce_scatter_with_assert,
)
from torch.testing._internal.common_utils import run_tests, skipIfRocm
from torch.testing._internal.common_utils import run_tests, skipIfRocm, TEST_HPU


device_type = torch.device(get_devtype())


class TestFullyShardMixedPrecisionTraining(FSDPTest):
@property
def world_size(self) -> int:
return min(4, torch.cuda.device_count())
return min(4, torch.get_device_module(device_type).device_count())

def _init_models_and_optims(
self,
@ -43,7 +47,7 @@ class TestFullyShardMixedPrecisionTraining(FSDPTest):
):
torch.manual_seed(42)
model = nn.Sequential(*[MLP(16, torch.device("cpu")) for _ in range(3)])
ref_model = copy.deepcopy(model).cuda()
ref_model = copy.deepcopy(model).to(device_type)
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)

def _shard_placement_fn(param: nn.Parameter) -> Optional[Shard]:
@ -123,7 +127,7 @@ class TestFullyShardMixedPrecisionTraining(FSDPTest):
)

torch.manual_seed(42 + self.rank + 1)
inp = torch.randn((4, 16), device="cuda", dtype=param_dtype)
inp = torch.randn((4, 16), device=device_type.type, dtype=param_dtype)
for iter_idx in range(10):
optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
fsdp_loss = model(inp).sum()
@ -209,7 +213,7 @@ class TestFullyShardMixedPrecisionTraining(FSDPTest):
reduce_scatter_with_assert, self, orig_reduce_scatter, assert_fn
)
torch.manual_seed(42 + self.rank + 1)
inp = torch.randn((4, 16), device="cuda", dtype=param_dtype)
inp = torch.randn((4, 16), device=device_type.type, dtype=param_dtype)
for iter_idx in range(10):
optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
fsdp_loss = model(inp).sum()
@ -258,7 +262,7 @@ class TestFullyShardMixedPrecisionTraining(FSDPTest):
reduce_scatter_with_assert, self, orig_reduce_scatter, assert_fn
)
torch.manual_seed(42 + self.rank + 1)
inp = torch.randn((4, 16), device="cuda", dtype=param_dtype)
inp = torch.randn((4, 16), device=device_type.type, dtype=param_dtype)
for iter_idx in range(10):
optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
fsdp_loss = model(inp).sum()
@ -309,7 +313,7 @@ class TestFullyShardMixedPrecisionTraining(FSDPTest):
# To emulate the mixed precision implementation where forward/backward
# compute use bf16 and optimizer uses fp32, we maintain both an fp32
# and a bf16 copy of the reference model
ref_model = copy.deepcopy(model).cuda()
ref_model = copy.deepcopy(model).to(device_type)
ref_model_compute = copy.deepcopy(ref_model).to(param_dtype)
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
for mlp in model:
@ -329,7 +333,7 @@ class TestFullyShardMixedPrecisionTraining(FSDPTest):
reduce_scatter_with_assert, self, orig_reduce_scatter, assert_fn
)
torch.manual_seed(42 + self.rank + 1)
device = torch.device("cuda")
device = device_type
# Train on the same input to avoid loss explosion
num_microbatches = 4
inp = torch.randn((2 * num_microbatches, 16), device=device, dtype=param_dtype)
@ -389,7 +393,7 @@ class TestFullyShardMixedPrecisionCasts(FSDPTestMultiThread):

@skip_if_lt_x_gpu(1)
def test_float16_on_one_submodule(self):
x = torch.zeros(2, 100, device="cuda")
x = torch.zeros(2, 100, device=device_type)

# Subtest 1: use fp16 on the second child submodule -- does not require
# any additional casting logic
@ -397,7 +401,7 @@ class TestFullyShardMixedPrecisionCasts(FSDPTestMultiThread):
model = SaveForwardInputsModel(
forward_inputs,
cast_forward_inputs=False,
).cuda()
).to(device_type)
fully_shard(model.c2, mp_policy=MixedPrecisionPolicy(param_dtype=torch.float16))
fully_shard(model)
model(x).sum().backward()
@ -410,7 +414,7 @@ class TestFullyShardMixedPrecisionCasts(FSDPTestMultiThread):
forward_inputs: dict[nn.Module, torch.Tensor] = {}
model = SaveForwardInputsModel(
forward_inputs=forward_inputs, cast_forward_inputs=True
).cuda()
).to(device_type)
fully_shard(
model.c2,
mp_policy=MixedPrecisionPolicy(
@ -428,7 +432,7 @@ class TestFullyShardMixedPrecisionCasts(FSDPTestMultiThread):
forward_inputs: dict[nn.Module, torch.Tensor] = {}
model = SaveForwardInputsModel(
forward_inputs=forward_inputs, cast_forward_inputs=False
).cuda()
).to(device_type)
fully_shard(
model.c1,
mp_policy=MixedPrecisionPolicy(
@ -470,13 +474,13 @@ class TestFullyShardMixedPrecisionCasts(FSDPTestMultiThread):
def forward(self, x: torch.Tensor) -> torch.Tensor:
self.forward_inputs["model_input_x"] = x
y = torch.ones(
2, 100, device="cuda", dtype=torch.float32
2, 100, device=device_type.type, dtype=torch.float32
)  # external input
return self.l2(self.l1(x), y)

forward_inputs: dict[str, torch.Tensor] = {}
model = ToyModel(forward_inputs).cuda()
x = torch.zeros(2, 100, device="cuda", dtype=torch.float32)
model = ToyModel(forward_inputs).to(device_type)
x = torch.zeros(2, 100, device=device_type.type, dtype=torch.float32)
fully_shard(
model.l2,
mp_policy=MixedPrecisionPolicy(
@ -529,9 +533,15 @@ class TestFullyShardMixedPrecisionCasts(FSDPTestMultiThread):
model = nn.Sequential(nn.Conv2d(1, 5, 3), nn.BatchNorm2d(5), nn.Conv2d(5, 4, 3))
for module in (model[0], model[1], model[2], model):
fully_shard(module, mp_policy=mp_policy)
with self.assertRaisesRegex(RuntimeError, "Expected running_mean to have type"):
# Errors in batch norm 2D backward
if TEST_HPU:
inner(model, torch.randn((3, 1, 9, 9)))
else:
with self.assertRaisesRegex(
RuntimeError,
"Expected running_mean to have type",  # Error not seen on HPUs and hence it can be skipped
):
# Errors in batch norm 2D backward
inner(model, torch.randn((3, 1, 9, 9)))

# Batch norm 2D: cast buffers down to lower precision
model = nn.Sequential(nn.Conv2d(1, 5, 3), nn.BatchNorm2d(5), nn.Conv2d(5, 4, 3))
@ -557,7 +567,7 @@ class TestFullyShardMixedPrecisionCasts(FSDPTestMultiThread):
model = nn.Sequential(
nn.Linear(32, 32, dtype=init_dtype),
nn.Linear(32, 32, dtype=init_dtype),
)
).to(device_type.type)
mp_policy = MixedPrecisionPolicy(
param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16
)
@ -579,7 +589,7 @@ class TestFullyShardMixedPrecisionCasts(FSDPTestMultiThread):
reduce_scatter_with_assert, self, orig_reduce_scatter, assert_fn
)
with patch_reduce_scatter(reduce_scatter):
inp = torch.randn((4, 32), device="cuda")
inp = torch.randn((4, 32), device=device_type.type)
loss = model(inp).sum()
loss.backward()


@ -2,6 +2,7 @@
|
||||
|
||||
import copy
|
||||
import functools
|
||||
import unittest
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
@ -12,10 +13,15 @@ from torch.distributed.tensor.experimental import implicit_replication
|
||||
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
|
||||
from torch.testing._internal.common_fsdp import (
|
||||
FSDPTest,
|
||||
get_devtype,
|
||||
patch_all_gather,
|
||||
patch_reduce_scatter,
|
||||
)
|
||||
from torch.testing._internal.common_utils import get_cycles_per_ms, run_tests
|
||||
from torch.testing._internal.common_utils import get_cycles_per_ms, run_tests, TEST_HPU
|
||||
|
||||
|
||||
device_type = torch.device(get_devtype())
|
||||
device_module = torch.get_device_module(device_type)
|
||||
|
||||
|
||||
class TestFullyShardOverlap(FSDPTest):
|
||||
@ -35,9 +41,10 @@ class TestFullyShardOverlap(FSDPTest):
|
||||
|
||||
@property
|
||||
def world_size(self) -> int:
|
||||
return min(2, torch.cuda.device_count())
|
||||
return min(2, torch.get_device_module(device_type).device_count())
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@unittest.skipIf(TEST_HPU, "Sleep is not supported on HPU")
|
||||
def test_fully_shard_training_overlap(self):
|
||||
torch.manual_seed(42)
|
||||
|
||||
@ -46,7 +53,7 @@ class TestFullyShardOverlap(FSDPTest):
|
||||
model = nn.Sequential(
|
||||
*[LinearWithSleep(dim, compute_sleep_ms) for _ in range(num_linears)]
|
||||
)
|
||||
ref_model = copy.deepcopy(model).cuda()
|
||||
ref_model = copy.deepcopy(model).to(device_type)
|
||||
         for lin in model:
             assert len(list(lin.parameters())) == 1, "Expects only one weight"
             fully_shard(lin, reshard_after_forward=True)
@@ -54,15 +61,21 @@ class TestFullyShardOverlap(FSDPTest):

         orig_all_gather_into_tensor = dist.all_gather_into_tensor
         orig_reduce_scatter_tensor = dist.reduce_scatter_tensor
-        comm_stream = torch.cuda.Stream()
+        comm_stream = torch.get_device_module(device_type).Stream()

         def delay_collective():
             # Share a stream so that all-gather and reduce-scatter block each
             # other like in `ProcessGroupNCCL`
-            comm_stream.wait_stream(torch.cuda.current_stream())
-            with torch.cuda.stream(comm_stream):
-                torch.cuda._sleep(int(comm_sleep_ms * get_cycles_per_ms()))
-            torch.cuda.current_stream().wait_stream(comm_stream)
+            comm_stream.wait_stream(
+                torch.get_device_module(device_type).current_stream()
+            )
+            with torch.get_device_module(device_type).stream(comm_stream):
+                torch.get_device_module(device_type)._sleep(
+                    int(comm_sleep_ms * get_cycles_per_ms())
+                )
+            torch.get_device_module(device_type).current_stream().wait_stream(
+                comm_stream
+            )

         def delayed_all_gather(*args, **kwargs):
             delay_collective()
@@ -72,7 +85,7 @@ class TestFullyShardOverlap(FSDPTest):
             delay_collective()
             return orig_reduce_scatter_tensor(*args, **kwargs)

-        inp = torch.randn((2, dim), device="cuda")
+        inp = torch.randn((2, dim), device=device_type.type)
         loss = model(inp).sum()  # warmup CUDA and allocator
         loss.backward()

@@ -144,6 +157,7 @@ class TestFullyShardOverlap(FSDPTest):
         self.assertLessEqual(fwd_bwd_time, ref_fwd_bwd_time)

     @skip_if_lt_x_gpu(2)
+    @unittest.skipIf(TEST_HPU, "Sleep is not supported on HPU")
     def test_fully_shard_post_optim_event_overlap(self):
         torch.manual_seed(42)

@@ -153,17 +167,19 @@ class TestFullyShardOverlap(FSDPTest):
         # low-compute linear, where only the low-compute linear uses FSDP
         model = nn.Sequential(
             LinearWithSleep(dim, compute_sleep_ms), nn.Linear(dim, dim)
-        ).cuda()
+        ).to(device_type)
         fully_shard(model[1], reshard_after_forward=False)
         optim = torch.optim.AdamW(model.parameters(), lr=1e-2)

         orig_all_gather_into_tensor = dist.all_gather_into_tensor

         def delayed_all_gather(*args, **kwargs):
-            torch.cuda._sleep(int(comm_sleep_ms * get_cycles_per_ms()))
+            torch.get_device_module(device_type)._sleep(
+                int(comm_sleep_ms * get_cycles_per_ms())
+            )
             return orig_all_gather_into_tensor(*args, **kwargs)

-        inp = torch.randn((2, dim), device="cuda")
+        inp = torch.randn((2, dim), device=device_type)

         def run_train_steps(num_iters: int, use_post_optim_event: bool):
             for _ in range(num_iters):
@@ -174,7 +190,11 @@ class TestFullyShardOverlap(FSDPTest):
                 with implicit_replication():
                     optim.step()
                 if use_post_optim_event:
-                    post_optim_event = torch.cuda.current_stream().record_event()
+                    post_optim_event = (
+                        torch.get_device_module(device_type)
+                        .current_stream()
+                        .record_event()
+                    )
                     model[1].set_post_optim_event(post_optim_event)

         run_train_steps(1, False)  # warmup CUDA and allocator
@@ -205,14 +225,14 @@ class TestFullyShardOverlap(FSDPTest):
         self.assertGreater(baseline_time, test_time)

     def _time_fn(self, fn: Callable):
-        start_event = torch.cuda.Event(enable_timing=True)
-        end_event = torch.cuda.Event(enable_timing=True)
+        start_event = device_module.Event(enable_timing=True)
+        end_event = device_module.Event(enable_timing=True)
         dist.barrier()
-        torch.cuda.synchronize()
+        device_module.synchronize()
         start_event.record()
         fn()
         end_event.record()
-        torch.cuda.synchronize()
+        device_module.synchronize()
         elapsed_time = start_event.elapsed_time(end_event)
         return elapsed_time

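The `_time_fn` helper above is the standard event-based device timing pattern. A minimal hedged sketch, assuming a CUDA-like backend; `work` is a hypothetical stand-in for the function being timed:

```python
import torch


def time_on_device(work, device_module=torch.cuda):
    start = device_module.Event(enable_timing=True)
    end = device_module.Event(enable_timing=True)
    device_module.synchronize()  # drain pending kernels so timing starts clean
    start.record()
    work()
    end.record()
    device_module.synchronize()  # wait for `work` before reading the events
    return start.elapsed_time(end)  # elapsed milliseconds between the events
```

Events record on the current stream, so this measures device time rather than Python wall time, which is why the test also calls `dist.barrier()` first to align ranks.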
@@ -223,13 +243,15 @@ class Matmul(torch.autograd.Function):
     def forward(ctx, input: torch.Tensor, weight: torch.Tensor, sleep_ms: int):
         ctx.save_for_backward(input, weight)
         ctx.sleep_ms = sleep_ms
-        torch.cuda._sleep(int(sleep_ms * get_cycles_per_ms()))
+        torch.get_device_module(device_type)._sleep(int(sleep_ms * get_cycles_per_ms()))
         return input @ weight

     @staticmethod
     def backward(ctx, grad_output: torch.Tensor):
         (input, weight) = ctx.saved_tensors
-        torch.cuda._sleep(int(2 * ctx.sleep_ms * get_cycles_per_ms()))
+        torch.get_device_module(device_type)._sleep(
+            int(2 * ctx.sleep_ms * get_cycles_per_ms())
+        )
         grad_input = grad_output @ weight.T
         grad_weight = input.T @ grad_output
         return grad_input, grad_weight, None

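Several tests in this file delay a collective by wrapping the real `dist.all_gather_into_tensor` and busy-waiting on the device before calling through. A hedged sketch of that idiom, with an illustrative `delay_ms` and patching shown only in a comment:

```python
import torch
import torch.distributed as dist
from torch.testing._internal.common_utils import get_cycles_per_ms


def make_delayed_all_gather(delay_ms: int):
    orig_all_gather_into_tensor = dist.all_gather_into_tensor

    def delayed_all_gather(*args, **kwargs):
        # Busy-wait on the device so the collective kernel launches late,
        # without blocking the CPU the way time.sleep() would
        torch.cuda._sleep(int(delay_ms * get_cycles_per_ms()))
        return orig_all_gather_into_tensor(*args, **kwargs)

    return delayed_all_gather


# Patch for one forward/backward, e.g. via
# unittest.mock.patch.object(dist, "all_gather_into_tensor",
#                            make_delayed_all_gather(100))
```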
@@ -1,11 +1,10 @@
 # Owner(s): ["oncall: distributed"]

 import copy
 import unittest

 import torch.nn as nn
 from torch.distributed.fsdp import FSDPModule, fully_shard
-from torch.testing._internal.common_cuda import TEST_CUDA
+from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import FSDPTestMultiThread, MLP
 from torch.testing._internal.common_utils import run_tests

@@ -15,7 +14,7 @@ class TestFullyShardState(FSDPTestMultiThread):
     def world_size(self) -> int:
         return 1

-    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    @skip_if_lt_x_gpu(1)
     def test_fully_shard_state(self):
         """
         Tests the ability to get the state object from a fully sharded module.
@@ -31,7 +30,7 @@ class TestFullyShardState(FSDPTestMultiThread):
         # Check that each `fully_shard` call constructs a distinct state object
         self.assertEqual(len(set(all_states)), num_mlps + 1)

-    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    @skip_if_lt_x_gpu(1)
     def test_fully_shard_reapply(self):
         model = MLP(8)
         fully_shard(model)
@@ -41,7 +40,7 @@ class TestFullyShardState(FSDPTestMultiThread):
         ):
             fully_shard(model)

-    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    @skip_if_lt_x_gpu(1)
     def test_fully_shard_cls(self):
         # Check that we only swap class for the module passed to `fully_shard`
         model = MLP(8)
@@ -64,7 +63,7 @@ class TestFullyShardState(FSDPTestMultiThread):
         self.assertTrue(isinstance(sliced_model, nn.Sequential))
         self.assertFalse(isinstance(sliced_model, FSDPModule))

-    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    @skip_if_lt_x_gpu(1)
     def test_fully_shard_unsupported_module_cls(self):
         regex = (
             r"fully\_shard does not support containers that do not implement forward"
@@ -76,7 +75,7 @@ class TestFullyShardState(FSDPTestMultiThread):
         with self.assertRaisesRegex(ValueError, regex):
             fully_shard(model)

-    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    @skip_if_lt_x_gpu(1)
     def test_fully_shard_deepcopy(self):
         model = MLP(8)
         fully_shard(model)

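The recurring decorator swap in this file replaces the CUDA-only `TEST_CUDA` gate with an accelerator-count gate. A brief hedged sketch, assuming the internal distributed test harness; the class name and test body below are illustrative, not from the diff:

```python
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTestMultiThread
from torch.testing._internal.common_utils import run_tests


class ExampleStateTest(FSDPTestMultiThread):  # hypothetical test class
    @property
    def world_size(self) -> int:
        return 1

    @skip_if_lt_x_gpu(1)  # gates on accelerator count, not a CUDA-only flag
    def test_runs_only_with_an_accelerator(self):
        self.assertTrue(True)  # body elided; the gating decorator is the point


if __name__ == "__main__":
    run_tests()
```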
@@ -2,7 +2,6 @@

 import copy
 import functools
 import unittest
 from contextlib import nullcontext
 from typing import Optional

@@ -16,9 +15,13 @@ from torch.distributed.tensor.parallel import (
     parallelize_module,
     RowwiseParallel,
 )
-from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
-from torch.testing._internal.common_fsdp import FSDPTest, FSDPTestMultiThread, MLP
+from torch.testing._internal.common_fsdp import (
+    FSDPTest,
+    FSDPTestMultiThread,
+    get_devtype,
+    MLP,
+)
 from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     ModelArgs,
@@ -27,14 +30,17 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
 )


+device_type = torch.device(get_devtype())


 class TestFullyShardStateDictMultiProcess(FSDPTest):
     @property
     def world_size(self) -> int:
-        return min(8, torch.cuda.device_count())
+        return min(8, torch.get_device_module(device_type).device_count())

     @skip_if_lt_x_gpu(2)
     def test_dp_state_dict_save_load(self):
-        fsdp_mesh = init_device_mesh("cuda", (self.world_size,))
+        fsdp_mesh = init_device_mesh(device_type.type, (self.world_size,))
         self.run_subtests(
             {"mlp_dim": [2, 3, 4, 5], "mesh": [fsdp_mesh]},
             self._test_dp_state_dict_save_load,
@@ -53,7 +59,7 @@ class TestFullyShardStateDictMultiProcess(FSDPTest):
         if self.world_size % 2 != 0:
             return
         hsdp_mesh = init_device_mesh(
-            "cuda",
+            device_type.type,
             (self.world_size // 2, 2),
             mesh_dim_names=("dp_replicate", "dp_shard"),
         )
@@ -103,7 +109,7 @@ class TestFullyShardStateDictMultiProcess(FSDPTest):
         fully_shard_fn(model2, reshard_after_forward=False)
         self._test_state_dict_save_load(model2)
         ref_sharded_sd = model2.state_dict()
-        inp = torch.randn((2, mlp_dim), device="cuda")
+        inp = torch.randn((2, mlp_dim), device=device_type.type)
         model2(inp)  # parameters are not resharded after this forward
         # Check that state dict hooks reshard
         sharded_sd = model2.state_dict()
@@ -155,12 +161,12 @@ class TestFullyShardStateDictMultiProcess(FSDPTest):
         model.load_state_dict(sd, assign=True, strict=False)

         # lazy init without error
-        inp = torch.rand((mlp_dim, mlp_dim), device="cuda")
+        inp = torch.rand((mlp_dim, mlp_dim), device=device_type.type)

         context = (
             self.assertRaisesRegex(
                 RuntimeError,
-                r"Found following parameters on non-CPU device: \[\('0.weight', device\(type='cuda'",
+                rf"Found following parameters on non-CPU device: \[\('0.weight', device\(type='{device_type.type}'",
             )
             if not cpu_state_dict
             else nullcontext()
@@ -171,10 +177,13 @@ class TestFullyShardStateDictMultiProcess(FSDPTest):
         for name, dtensor in state_dict.items():
             self.assertEqual(dtensor.device.type, "cpu")

     @skip_if_lt_x_gpu(2)
     def test_2d_state_dict_correctness(self):
         dp_size = 2
         global_mesh = init_device_mesh(
-            "cuda", (dp_size, self.world_size // dp_size), mesh_dim_names=("dp", "tp")
+            device_type.type,
+            (dp_size, self.world_size // dp_size),
+            mesh_dim_names=("dp", "tp"),
         )
         dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]
         torch.manual_seed(42)
@@ -214,7 +223,9 @@ class TestFullyShardStateDictMultiProcess(FSDPTest):
     def test_dp_tp_state_dict_save_load(self):
         dp_size = 2
         global_mesh = init_device_mesh(
-            "cuda", (dp_size, self.world_size // dp_size), mesh_dim_names=("dp", "tp")
+            device_type.type,
+            (dp_size, self.world_size // dp_size),
+            mesh_dim_names=("dp", "tp"),
         )
         self.run_subtests(
             {"mlp_dim": [4, 6, 8, 10]},
@@ -245,7 +256,7 @@ class TestFullyShardStateDictMultiProcess(FSDPTest):
     @skip_if_lt_x_gpu(4)
     def test_hsdp_tp_state_dict_save_load(self):
         global_mesh = init_device_mesh(
-            "cuda",
+            device_type.type,
             (2, 2, self.world_size // 4),
             mesh_dim_names=("dp_replicate", "dp_shard", "tp"),
         )
@@ -345,12 +356,12 @@ class TestFullyShardStateDictMultiThread(FSDPTestMultiThread):
     def world_size(self):
         return 2

-    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    @skip_if_lt_x_gpu(1)
     def test_rank0_offload_full_state_dict(self):
         # Construct a reference unsharded model on all ranks
         model_args = ModelArgs(dropout_p=0.0)
         torch.manual_seed(42)
-        ref_model = Transformer(model_args).cuda()
+        ref_model = Transformer(model_args).to(device_type)
         for param in ref_model.parameters():
             torch.distributed.broadcast(param.detach(), src=0)

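For context on the `init_device_mesh` rewrites above: keying the mesh by a device-type string is what lets one test drive CUDA, XPU, or HPU. A minimal hedged sketch; it assumes an initialized process group of world size 4, and `device_type` here stands in for `torch.device(get_devtype())`:

```python
import torch
from torch.distributed.device_mesh import init_device_mesh


def build_hsdp_mesh(device_type: torch.device):
    # Keyed by the device-type string ("cuda", "xpu", "hpu", ...), so the
    # same call works across backends; assumes dist is initialized with
    # world size 4 (2-way replicate x 2-way shard)
    return init_device_mesh(
        device_type.type,
        (2, 2),
        mesh_dim_names=("dp_replicate", "dp_shard"),
    )
```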
@@ -27,7 +27,6 @@ from torch.distributed.fsdp import (
 )
 from torch.distributed.tensor import DTensor, init_device_mesh, Shard
 from torch.distributed.tensor.debug import CommDebugMode
-from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import (
     check_sharded_parity,
@@ -42,6 +41,7 @@ from torch.testing._internal.common_fsdp import (
 from torch.testing._internal.common_utils import (
     get_cycles_per_ms,
     run_tests,
+    TEST_HPU,
     wrapSwapTensorsTest,
 )
 from torch.testing._internal.distributed._tensor.common_dtensor import (
@@ -54,15 +54,20 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
 c10d_ops = torch.ops.c10d
 funcol = torch.ops.c10d_functional

+from torch.testing._internal.common_fsdp import get_devtype
+
+
+device_type = torch.device(get_devtype())
+
+
 class TestFullyShardForwardInputs(FSDPTestMultiThread):
     @property
     def world_size(self) -> int:
         return 2

-    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    @skip_if_lt_x_gpu(1)
     def test_root_move_forward_input_to_device(self):
-        device = torch.device("cuda", 0)
+        device = torch.device(device_type.type, 0)

         class ParamlessModule(nn.Module):
             def forward(self, x: torch.Tensor, ys: tuple[torch.Tensor, ...]):
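Since nearly every hunk below leans on it, a brief hedged sketch of the `torch.get_device_module` indirection: it returns the `torch.cuda`-like backend module for whatever device `device_type` names, which is what keeps the Stream/Event/`_sleep` call sites portable. The `"cuda"` literal below is an assumed stand-in for `get_devtype()`:

```python
import torch

device_type = torch.device("cuda")  # stand-in for torch.device(get_devtype())
device_module = torch.get_device_module(device_type)

stream = device_module.Stream()  # e.g. torch.cuda.Stream on CUDA builds
with device_module.stream(stream):
    x = torch.ones(4, device=device_type)
device_module.current_stream().wait_stream(stream)  # order default stream after it
device_module.synchronize()
```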
@@ -78,8 +83,8 @@ class TestFullyShardForwardInputs(FSDPTestMultiThread):
             y = ys[0] + ys[1]
             return x + y + 1

-        model = ParamlessModule()
-        fully_shard(model)
+        model = ParamlessModule().to(device)
+        fully_shard(model).to(device)
         x = torch.randn((3,))
         ys = (torch.randn((3,)), torch.randn((3,)))
         self.assertEqual(x.device, torch.device("cpu"))
@@ -93,10 +98,10 @@ class TestFullyShardRegisteredParams(FSDPTestMultiThread):
     def world_size(self) -> int:
         return 4

-    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    @skip_if_lt_x_gpu(1)
     def test_param_registration_after_forward(self):
         """Tests the parameter registration after forward."""
-        device = torch.device("cuda", 0)
+        device = torch.device(device_type.type, 0)
         # Single FSDP group
         for reshard_after_forward in (True, False, 2):
             torch.manual_seed(42)
@@ -107,7 +112,7 @@ class TestFullyShardRegisteredParams(FSDPTestMultiThread):
                 dist.broadcast(param, src=0)
             ref_model = copy.deepcopy(model)
             fully_shard(model, reshard_after_forward=reshard_after_forward)  # root only
-            inp = torch.randn((2, 3), device="cuda")
+            inp = torch.randn((2, 3), device=device_type.type)
             self._assert_dtensor_params(model.parameters())
             self._assert_same_params(model.parameters(), ref_model.parameters())
             model(inp)  # root does not reshard after forward
@@ -147,15 +152,15 @@ class TestFullyShardRegisteredParams(FSDPTestMultiThread):
             self._assert_dtensor_params(model.parameters())
             self._assert_same_params(model.parameters(), ref_model.parameters())

-    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    @skip_if_lt_x_gpu(1)
     def test_param_registration_after_backward(self):
         """Tests the parameter registration after backward."""
-        device = torch.device("cuda", 0)
+        device = torch.device(device_type.type, 0)
         # Single FSDP group
         for reshard_after_forward in (True, False, 2):
             model = MLP(8, device)
             fully_shard(model, reshard_after_forward=reshard_after_forward)  # root only
-            inp = torch.randn((2, 8), device="cuda")
+            inp = torch.randn((2, 8), device=device_type.type)
             self._assert_dtensor_params(model.parameters())
             model(inp).sum().backward()
             self._assert_dtensor_params(model.parameters())
@@ -198,14 +203,14 @@ class TestFullyShardCastAfterInit(FSDPTestMultiThread):
     def world_size(self) -> int:
         return 2

-    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    @skip_if_lt_x_gpu(1)
     @wrapSwapTensorsTest(True)
     def test_to_float64_after_init(self):
         """Tests that the user can cast the module to float64 after init."""
         # NOTE: Test fp64 instead of a lower precision dtype like bf16 for
         # better numerics. The important part is changing the dtype.
         torch.manual_seed(42)
-        mlp_dim, device, dtype = 4, torch.device("cuda"), torch.float64
+        mlp_dim, device, dtype = 4, device_type, torch.float64
         model = MLP(mlp_dim, device=device)
         for param in model.parameters():
             dist.broadcast(param, src=0)
@@ -222,7 +227,7 @@ class TestFullyShardCastAfterInit(FSDPTestMultiThread):
         optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=True)
         check_sharded_parity(self, ref_model, model)
         torch.manual_seed(42 + self.rank + 1)
-        inp = torch.randn((2, mlp_dim), device="cuda", dtype=dtype)
+        inp = torch.randn((2, mlp_dim), device=device_type.type, dtype=dtype)
         for iter_idx in range(10):
             losses: list[torch.Tensor] = []
             for _model in (ref_model, model):
@@ -245,7 +250,7 @@ class TestFullyShardCastAfterInit(FSDPTestMultiThread):
 class TestFullyShard1DTrainingCore(FSDPTest):
     @property
     def world_size(self) -> int:
-        return min(8, torch.cuda.device_count())
+        return min(8, torch.get_device_module(device_type).device_count())

     @skip_if_lt_x_gpu(2)
     def test_train_parity_single_group_shard_dim0(self):
@@ -287,7 +292,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
         model = nn.Sequential(
             nn.Linear(*lin_shapes[0]), nn.ReLU(), nn.Linear(*lin_shapes[1])
         )
-        ref_model = copy.deepcopy(model).cuda()
+        ref_model = copy.deepcopy(model).to(device_type)
         replicate(ref_model, device_ids=[self.rank])
         ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)

@@ -298,7 +303,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
         fully_shard(model, shard_placement_fn=shard_placement_fn)
         optim = torch.optim.Adam(model.parameters(), lr=1e-2)
         torch.manual_seed(42 + self.rank + 1)
-        inp = (torch.randn((4, lin_shapes[0][0]), device="cuda"),)
+        inp = (torch.randn((4, lin_shapes[0][0]), device=device_type.type),)
         for iter_idx in range(10):
             losses: list[torch.Tensor] = []
             for _model, _optim in ((ref_model, ref_optim), (model, optim)):
@@ -309,6 +314,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
             self.assertEqual(losses[0], losses[1])

     @skip_if_lt_x_gpu(2)
+    @unittest.skipIf(TEST_HPU, "Sleep kernel not supported for HPU")
     @compiled_fsdp_test(compile_compute_on_module=Transformer)
     def test_train_parity_multi_group(self):
         """
@@ -319,7 +325,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
         self.run_subtests(
             {
                 "reshard_after_forward": [True, False, 2],
-                "device_type": ["cuda"],
+                "device_type": [device_type.type],
                 "offload_policy": [OffloadPolicy()],
                 "delay_after_forward": [False, True],
                 "delay_before_all_gather": [False, True],
@@ -331,6 +337,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
         )

     @skip_if_lt_x_gpu(2)
+    @unittest.skipIf(TEST_HPU, "sleep kernel not supported on HPU")
     def test_train_parity_multi_group_cpu_offload_eager(self):
         """
         Tests train parity against DDP when using multiple parameter groups for
@@ -343,7 +350,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
                     CPUOffloadPolicy(pin_memory=True),
                     CPUOffloadPolicy(pin_memory=False),
                 ],
-                "device_type": ["cuda"],
+                "device_type": [device_type.type],
                 "delay_after_forward": [False, True],
                 "delay_before_all_gather": [False, True],
                 "delay_before_reduce_scatter": [False, True],
@@ -354,6 +361,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
         )

     @skip_if_lt_x_gpu(2)
+    @unittest.skipIf(TEST_HPU, "sleep kernel not supported on HPU")
     @compiled_fsdp_test(compile_compute_on_module=Transformer)
     def test_train_parity_multi_group_unshard_async_op(self):
         """
@@ -363,7 +371,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
         self.run_subtests(
             {
                 "reshard_after_forward": [True],
-                "device_type": ["cuda"],
+                "device_type": [device_type.type],
                 "offload_policy": [OffloadPolicy()],
                 "delay_after_forward": [False, True],
                 "delay_before_all_gather": [False, True],
@@ -394,7 +402,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
             in (2, 3)
         ):
             return
-        assert device_type in ("cuda", "cpu"), f"{device_type}"
+        assert device_type in ("cuda", "hpu", "xpu", "cpu"), f"{device_type}"
         torch.manual_seed(42)
         vocab_size = 1024
         model_args = ModelArgs(
@@ -406,8 +414,11 @@ class TestFullyShard1DTrainingCore(FSDPTest):
         )
         model = Transformer(model_args)
         ref_model = copy.deepcopy(model)
-        if device_type == "cuda":
-            replicate(ref_model.cuda(), device_ids=[self.rank])
+        if device_type == device_type:
+            replicate(
+                ref_model.to(device_type),
+                device_ids=[self.rank],
+            )
         else:
             gloo_pg = dist.new_group(backend="gloo")
             replicate(ref_model, process_group=gloo_pg)
@@ -432,11 +443,15 @@ class TestFullyShard1DTrainingCore(FSDPTest):
         orig_reduce_scatter = dist.reduce_scatter_tensor

         def delayed_all_gather(*args, **kwargs):
-            torch.cuda._sleep(int(delay_in_ms * get_cycles_per_ms()))
+            torch.get_device_module(device_type)._sleep(
+                int(delay_in_ms * get_cycles_per_ms())
+            )
             return orig_all_gather(*args, **kwargs)

         def delayed_reduce_scatter(*args, **kwargs):
-            torch.cuda._sleep(int(delay_in_ms * get_cycles_per_ms()))
+            torch.get_device_module(device_type)._sleep(
+                int(delay_in_ms * get_cycles_per_ms())
+            )
             return orig_reduce_scatter(*args, **kwargs)

         torch.manual_seed(42 + self.rank + 1)
@@ -458,10 +473,14 @@ class TestFullyShard1DTrainingCore(FSDPTest):
                 _optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
                 losses.append(_model(inp).sum())
                 if _model is model and delay_after_forward:
-                    torch.cuda._sleep(int(delay_in_ms * get_cycles_per_ms()))
+                    torch.get_device_module(device_type)._sleep(
+                        int(delay_in_ms * get_cycles_per_ms())
+                    )
                 losses[-1].backward()
                 if _model is model and delay_before_optim:
-                    torch.cuda._sleep(int(delay_in_ms * get_cycles_per_ms()))
+                    torch.get_device_module(device_type)._sleep(
+                        int(delay_in_ms * get_cycles_per_ms())
+                    )
                 _optim.step()
             self.assertEqual(losses[0], losses[1])

@@ -474,14 +493,14 @@ class TestFullyShard1DTrainingCore(FSDPTest):
         torch.manual_seed(42)
         lin_dim = 32
         model = nn.Sequential(*[MLP(lin_dim, torch.device("cpu")) for _ in range(3)])
-        ref_model = copy.deepcopy(model).cuda()
+        ref_model = copy.deepcopy(model).to(device_type)
         ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
         for mlp in model:
             fully_shard(mlp)
         fully_shard(model)
         optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=True)
         torch.manual_seed(42 + self.rank)
-        inp = torch.randn((8, lin_dim), device=torch.device("cuda"))
+        inp = torch.randn((8, lin_dim), device=device_type)

         ref_root_loss = ref_model(inp).sum()
         ref_root_loss.backward()
@@ -500,7 +519,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):

         root_loss = model(inp).sum()
         root_loss.backward()
-        torch.cuda._sleep(int(100 * get_cycles_per_ms()))
+        torch.get_device_module(device_type)._sleep(int(100 * get_cycles_per_ms()))
         optim.step()
         optim.zero_grad()
         nonroot_loss = model[0](inp).sum()
@@ -535,16 +554,19 @@ class TestFullyShard1DTrainingCore(FSDPTest):
                 return self.outer(i + j)

         torch.manual_seed(42)
-        model = MultiForwardModule(device="cuda")
+        model = MultiForwardModule(device=device_type.type)
         ref_model = copy.deepcopy(model)
-        replicate(ref_model, device_ids=[self.rank])
+        replicate(
+            ref_model,
+            device_ids=[self.rank],
+        )
         ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
         fully_shard(model.inner)
         fully_shard(model)
         optim = torch.optim.Adam(model.parameters(), lr=1e-2)

         torch.manual_seed(42 + self.rank)
-        inp = torch.randn((32, 4), device="cuda")
+        inp = torch.randn((32, 4), device=device_type.type)
         for iter_idx in range(10):
             losses: list[torch.Tensor] = []
             for _model, _optim in ((ref_model, ref_optim), (model, optim)):
@@ -559,7 +581,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
         torch.manual_seed(42)
         model_args = ModelArgs(n_layers=8, dropout_p=0.0)
         model = Transformer(model_args)
-        ref_model = replicate(copy.deepcopy(model).cuda())
+        ref_model = replicate(copy.deepcopy(model).to(device_type))
         ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2)
         for layer in itertools.chain(model.layers, [model]):
             fully_shard(layer)
@@ -582,7 +604,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
             layer.set_modules_to_backward_prefetch(layers_to_prefetch)

         torch.manual_seed(42 + self.rank)
-        inp = torch.randint(0, model_args.vocab_size, (2, 8), device="cuda")
+        inp = torch.randint(0, model_args.vocab_size, (2, 8), device=device_type.type)
         for _ in range(10):
             losses: list[torch.Tensor] = []
             for _model, _optim in ((ref_model, ref_optim), (model, optim)):
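The `set_modules_to_backward_prefetch` call above drives explicit backward prefetching. A hedged sketch of the pattern; the helper name and the window size of 2 are illustrative, and `model.layers` are assumed to be already wrapped via `fully_shard(layer)`:

```python
def set_backward_prefetch(model, num_to_prefetch: int = 2) -> None:
    for i, layer in enumerate(model.layers):
        # Tell this layer to issue the all-gathers of up to `num_to_prefetch`
        # layers that run next in backward, i.e. earlier layers in forward
        # order, reversed so the nearest one is prefetched first
        layers_to_prefetch = list(model.layers[max(i - num_to_prefetch, 0) : i])[::-1]
        layer.set_modules_to_backward_prefetch(layers_to_prefetch)
```

Prefetching more layers trades memory (more parameters unsharded at once) for communication overlap, which is what the parity test above sweeps over.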
@@ -593,11 +615,12 @@ class TestFullyShard1DTrainingCore(FSDPTest):
             self.assertEqual(losses[0], losses[1])

     @skip_if_lt_x_gpu(2)
+    @unittest.skipIf(TEST_HPU, "Sleep is not supported on HPU")
     def test_post_optim_event(self):
         torch.manual_seed(42)
         model_args = ModelArgs(dropout_p=0.0)
         model = Transformer(model_args)
-        ref_model = replicate(copy.deepcopy(model).cuda())
+        ref_model = replicate(copy.deepcopy(model).to(device_type.type))
         ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2)
         for layer in itertools.chain(model.layers, [model]):
             fully_shard(layer)
@@ -606,13 +629,15 @@ class TestFullyShard1DTrainingCore(FSDPTest):
         def step_post_hook(
             fsdp_module: FSDPModule, opt: torch.optim.Optimizer, args, kwargs
         ) -> None:
-            post_optim_event = torch.cuda.current_stream().record_event()
+            post_optim_event = (
+                torch.get_device_module(device_type).current_stream().record_event()
+            )
             fsdp_module.set_post_optim_event(post_optim_event)

         optim.register_step_post_hook(functools.partial(step_post_hook, model))

         torch.manual_seed(42 + self.rank)
-        inp = torch.randint(0, model_args.vocab_size, (2, 8), device="cuda")
+        inp = torch.randint(0, model_args.vocab_size, (2, 8), device=device_type.type)
         # Track all losses and check for equality at the end to avoid a CPU
         # sync point after each iteration
         ref_losses: list[torch.Tensor] = []
@@ -629,7 +654,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
             optim.step()
             # Sleep after the optimizer step to allow CPU to run ahead into the
             # next iteration's forward, exercising the post-optim stream sync
-            torch.cuda._sleep(int(25 * get_cycles_per_ms()))
+            torch.get_device_module(device_type)._sleep(int(25 * get_cycles_per_ms()))
         for ref_loss, loss in zip(ref_losses, losses):
             self.assertEqual(ref_loss, loss)

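Condensing the pattern the post-optim test exercises: record an event on the current stream right after `optim.step()` and hand it to the root FSDP module, so the next iteration's all-gathers wait on the optimizer rather than on a full stream sync. A hedged sketch, using `torch.cuda` as a stand-in for the device module:

```python
import functools

import torch
from torch.distributed.fsdp import FSDPModule


def step_post_hook(fsdp_module: FSDPModule, opt, args, kwargs) -> None:
    # Marks "optimizer done" on the current stream without a CPU sync
    post_optim_event = torch.cuda.current_stream().record_event()
    fsdp_module.set_post_optim_event(post_optim_event)


# Assuming `model` is the root FSDPModule and `optim` its optimizer:
# optim.register_step_post_hook(functools.partial(step_post_hook, model))
```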
@@ -639,7 +664,7 @@ class TestFullyShard1DTrainingCompose(FSDPTest):
     def world_size(self) -> int:
         # Since these tests run with a larger transformer model, they may see
         # some numeric drift with >2 GPUs
-        return min(torch.cuda.device_count(), 2)
+        return min(torch.get_device_module(device_type).device_count(), 2)

     @skip_if_lt_x_gpu(2)
     @compiled_fsdp_test(compile_compute_on_module=Transformer)
@@ -669,7 +694,7 @@ class TestFullyShard1DTrainingCompose(FSDPTest):
             return
         torch.manual_seed(42)
         vocab_size = 1024
-        with torch.device(torch.device("cuda")):
+        with torch.device(device_type):
             model_args = ModelArgs(
                 n_layers=3,
                 n_heads=4,
@@ -683,7 +708,10 @@ class TestFullyShard1DTrainingCompose(FSDPTest):
                 weight_tying=module_grouping != "mem_eff",
             )
             model = Transformer(model_args)
-        ref_model = replicate(copy.deepcopy(model), device_ids=[self.rank])
+        ref_model = replicate(
+            copy.deepcopy(model),
+            device_ids=[self.rank],
+        )
         ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)

         # Apply activation checkpointing
@@ -723,7 +751,7 @@ class TestFullyShard1DTrainingCompose(FSDPTest):
         torch.manual_seed(42 + self.rank)
         # Reuse the same input across iterations to avoid loss explosion from
         # trying to learn from random inputs
-        inp = torch.randint(0, vocab_size, (3, 64), device="cuda")
+        inp = torch.randint(0, vocab_size, (3, 64), device=device_type.type)
         check_sharded_parity(
             self, ref_model, model, prefixes_to_ignore=prefixes_to_ignore
         )
@@ -750,14 +778,14 @@ class TestFullyShard1DTrainingCompose(FSDPTest):
 class TestFullyShardShardPlacementFnMultiProcess(FSDPTest):
     @property
     def world_size(self) -> int:
-        return min(8, torch.cuda.device_count())
+        return min(8, torch.get_device_module(device_type).device_count())

     @skip_if_lt_x_gpu(2)
     def test_train_parity_shard_placement_fn_shard_largest_dim(self):
         torch.manual_seed(42)
         model_args = ModelArgs(n_layers=3, dropout_p=0.0)
         model = Transformer(model_args)
-        ref_model = copy.deepcopy(model).cuda()
+        ref_model = copy.deepcopy(model).to(device_type)
         ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2)

         def shard_placement_fn(param: nn.Parameter) -> Optional[Shard]:
@@ -773,7 +801,7 @@ class TestFullyShardShardPlacementFnMultiProcess(FSDPTest):
             self.assertEqual(full_param, ref_param)

         torch.manual_seed(42 + self.rank)
-        inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda")
+        inp = torch.randint(0, model_args.vocab_size, (2, 16), device=device_type.type)
         for iter_idx in range(5):
             ref_loss = ref_model(inp).sum()
             loss = model(inp).sum()
@@ -800,7 +828,7 @@ class TestFullyShardShardPlacementFnMultiThread(FSDPTestMultiThread):
     def world_size(self) -> int:
         return 4

-    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    @skip_if_lt_x_gpu(1)
     def test_shard_placement_fn_contiguous_params_grads(self):
         dim = 4
         model = MLP(dim=dim)
@@ -825,7 +853,7 @@ class TestFullyShardShardPlacementFnMultiThread(FSDPTestMultiThread):
             self.assertTrue(param.is_contiguous())
             self.assertTrue(param.to_local().is_contiguous())

-        inp = torch.randn((2, dim), device="cuda")
+        inp = torch.randn((2, dim), device=device_type.type)
         model(inp).sum().backward()

         for param in model.parameters():
@@ -838,7 +866,7 @@ class TestFullyShardSharedParams(FSDPTest):
 class TestFullyShardSharedParams(FSDPTest):
     @property
     def world_size(self) -> int:
-        return min(4, torch.cuda.device_count())
+        return min(4, torch.get_device_module(device_type).device_count())

     @skip_if_lt_x_gpu(2)
     def test_train_parity_with_shared_params(self):
@@ -858,8 +886,11 @@ class TestFullyShardSharedParams(FSDPTest):
         torch.manual_seed(42)
         model_args = ModelArgs(n_layers=3, dropout_p=0.0, weight_tying=True)
         model = Transformer(model_args)
-        ref_model = copy.deepcopy(model).cuda()
-        replicate(ref_model, device_ids=[self.rank])
+        ref_model = copy.deepcopy(model).to(device_type)
+        replicate(
+            ref_model,
+            device_ids=[self.rank],
+        )
         ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
         for module in model.modules():
             if isinstance(module, TransformerBlock):
@@ -871,7 +902,9 @@ class TestFullyShardSharedParams(FSDPTest):

         torch.manual_seed(42 + self.rank + 1)
         for iter_idx in range(10):
-            inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda")
+            inp = torch.randint(
+                0, model_args.vocab_size, (2, 16), device=device_type.type
+            )
             losses: list[torch.Tensor] = []
             for _model, _optim in ((ref_model, ref_optim), (model, optim)):
                 _optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
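A hedged sketch of the sharding layout the shared-params test sets up: each `TransformerBlock` gets its own FSDP group, while anything not claimed (including tied weights such as an embedding reused by the output projection, per the `weight_tying=True` setup above) falls to the root group. The name-based check below is illustrative:

```python
from torch.distributed.fsdp import fully_shard


def apply_fsdp_with_tied_weights(model) -> None:
    for module in model.modules():
        if type(module).__name__ == "TransformerBlock":  # illustrative check
            fully_shard(module)
    fully_shard(model)  # root group owns parameters not claimed above
```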
@@ -884,7 +917,7 @@ class TestFullyShardGradientAccumulation(FSDPTest):
 class TestFullyShardGradientAccumulation(FSDPTest):
     @property
     def world_size(self) -> int:
-        return min(4, torch.cuda.device_count())
+        return min(4, torch.get_device_module(device_type).device_count())

     @skip_if_lt_x_gpu(2)
     def test_gradient_accumulation(self):
@@ -892,12 +925,14 @@ class TestFullyShardGradientAccumulation(FSDPTest):
         Tests gradient accumulation with/without gradient reduction and
         with/without resharding after backward.
         """
-        meshes = [init_device_mesh("cuda", (self.world_size,))]  # always test FSDP
+        meshes = [
+            init_device_mesh(device_type.type, (self.world_size,))
+        ]  # always test FSDP
         if self.world_size == 4:  # test HSDP too if enough GPUs
             shard_size, replicate_size = 2, 2
             meshes.append(
                 init_device_mesh(
-                    "cuda",
+                    device_type.type,
                     (replicate_size, shard_size),
                     mesh_dim_names=("dp_replicate", "dp_shard"),
                 )
@@ -951,7 +986,7 @@ class TestFullyShardGradientAccumulation(FSDPTest):
         modules = [nn.Linear(lin_dim, lin_dim)]
         modules.extend(MLP(lin_dim) for _ in range(num_mlps))
         model = nn.Sequential(*modules)
-        ref_model = copy.deepcopy(model).cuda()
+        ref_model = copy.deepcopy(model).to(device_type)
         fully_shard_fn = functools.partial(
             fully_shard,
             mesh=mesh,
@@ -994,7 +1029,7 @@ class TestFullyShardGradientAccumulation(FSDPTest):
             for microbatch_idx in range(num_microbatches):
                 is_last_microbatch = microbatch_idx == num_microbatches - 1
                 set_backward_flags(model, is_last_microbatch)
-                inp = torch.randn(batch_size, lin_dim, device="cuda")
+                inp = torch.randn(batch_size, lin_dim, device=device_type.type)
                 losses: list[torch.Tensor] = []
                 for _model in (ref_model, model):
                     with CommDebugMode() as comm_mode:
@@ -1083,7 +1118,7 @@ class TestFullyShardGradientAccumulation(FSDPTest):
         torch.manual_seed(42)
         model_args = ModelArgs(dropout_p=0.0)
         model = Transformer(model_args)
-        ref_model = copy.deepcopy(model).cuda()
+        ref_model = copy.deepcopy(model).to(device_type)
         ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2)
         for module in model.modules():
             if isinstance(module, TransformerBlock):
@@ -1096,7 +1131,10 @@ class TestFullyShardGradientAccumulation(FSDPTest):
         torch.manual_seed(42 + self.rank + 1)
         inps = [
             torch.randint(
-                0, model_args.vocab_size, (local_batch_size, 16), device="cuda"
+                0,
+                model_args.vocab_size,
+                (local_batch_size, 16),
+                device=device_type.type,
             )
             for _ in range(num_microbatches)
         ]
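The gradient-accumulation tests above toggle per-microbatch reduction flags. A hedged sketch of the user-facing knob: FSDP2's `set_requires_gradient_sync` disables reduce-scatter for all but the last microbatch, so gradients accumulate locally in the unsharded buffers. `model` is assumed to be `fully_shard`-wrapped and `batches` illustrative:

```python
def run_accumulation_step(model, optim, batches) -> None:
    num_microbatches = len(batches)
    for i, inp in enumerate(batches):
        is_last = i == num_microbatches - 1
        # Skip gradient reduce-scatter on every microbatch except the last
        model.set_requires_gradient_sync(is_last)
        model(inp).sum().backward()
    optim.step()
    optim.zero_grad()
```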
@@ -1136,14 +1174,14 @@ class TestFullyShardNDTraining(FSDPTest):
 class TestFullyShardNDTraining(FSDPTest):
     @property
     def world_size(self) -> int:
-        return min(8, torch.cuda.device_count())
+        return min(8, torch.get_device_module(device_type).device_count())

     def init_global_mesh(self) -> DeviceMesh:
         # Prefer to test with >=8 GPUs, but for 2 GPUs, use 2-way TP
         dp_size = 2 if self.world_size > 2 else 1
         pp_size = 2 if self.world_size > 4 else 1
         return init_device_mesh(
-            "cuda",
+            device_type.type,
             (pp_size, dp_size, self.world_size // (dp_size * pp_size)),
             mesh_dim_names=("pp", "dp", "tp"),
         )
@@ -1179,8 +1217,12 @@ class TestFullyShardNDTraining(FSDPTest):

         torch.manual_seed(42)
         model = MLPStack(mlp_dim)
-        ref_model = copy.deepcopy(model).cuda()
-        replicate(ref_model, device_ids=[self.rank], process_group=dp_pg)
+        ref_model = copy.deepcopy(model).to(device_type)
+        replicate(
+            ref_model,
+            device_ids=[self.rank],
+            process_group=dp_pg,
+        )
         ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=foreach)
         model.parallelize(
             tp_mesh,
@@ -1191,7 +1233,7 @@ class TestFullyShardNDTraining(FSDPTest):
         optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=foreach)

         torch.manual_seed(42 + dp_pg.rank() + 1)
-        device = torch.device("cuda")
+        device = device_type
         for iter_idx in range(10):
             inp = torch.randn((8, mlp_dim), device=device)
             losses: list[torch.Tensor] = []
@@ -1212,11 +1254,11 @@ class TestFullyShardHSDP3DTraining(FSDPTest):
 class TestFullyShardHSDP3DTraining(FSDPTest):
     @property
     def world_size(self) -> int:
-        return min(8, torch.cuda.device_count())
+        return min(8, torch.get_device_module(device_type).device_count())

     def init_global_mesh(self) -> DeviceMesh:
         return init_device_mesh(
-            "cuda",
+            device_type.type,
             (2, 2, 2),
             mesh_dim_names=("dp_replicate", "dp_shard", "tp"),
         )
@@ -1248,8 +1290,12 @@ class TestFullyShardHSDP3DTraining(FSDPTest):

         torch.manual_seed(42)
         model = MLPStack(mlp_dim)
-        ref_model = copy.deepcopy(model).cuda()
-        replicate(ref_model, device_ids=[self.rank], process_group=dp_pg)
+        ref_model = copy.deepcopy(model).to(device_type)
+        replicate(
+            ref_model,
+            device_ids=[self.rank],
+            process_group=dp_pg,
+        )
         ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=foreach)
         model.parallelize(
             tp_mesh,
@@ -1266,7 +1312,7 @@ class TestFullyShardHSDP3DTraining(FSDPTest):
         optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=foreach)

         torch.manual_seed(42 + dp_pg.rank() + 1)
-        device = torch.device("cuda")
+        device = device_type
         for iter_idx in range(10):
             inp = torch.randn((8, mlp_dim), device=device)
             losses: list[torch.Tensor] = []
@@ -1289,14 +1335,14 @@ class TestFullyShardHSDPTraining(FSDPTest):
 class TestFullyShardHSDPTraining(FSDPTest):
     @property
     def world_size(self) -> int:
-        return min(4, torch.cuda.device_count())
+        return min(4, torch.get_device_module(device_type).device_count())

     @skip_if_lt_x_gpu(2)
     def test_train_parity_hsdp(self):
         shard_size = 2 if self.world_size > 2 else 1
         replicate_size = self.world_size // shard_size
         global_mesh = init_device_mesh(
-            "cuda",
+            device_type.type,
             (replicate_size, shard_size),
             mesh_dim_names=("dp_replicate", "dp_shard"),
         )
@@ -1325,8 +1371,11 @@ class TestFullyShardHSDPTraining(FSDPTest):
             MLP(mlp_dim),
             MLP(mlp_dim, dim_multiplier=3),
         )
-        ref_model = copy.deepcopy(model).cuda()
-        replicate(ref_model, device_ids=[self.rank])
+        ref_model = copy.deepcopy(model).to(device_type)
+        replicate(
+            ref_model,
+            device_ids=[self.rank],
+        )
         ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
         for mlp in model:
             if use_activation_checkpointing:
@@ -1340,7 +1389,7 @@ class TestFullyShardHSDPTraining(FSDPTest):
         optim = torch.optim.Adam(model.parameters(), lr=1e-2)
         check_sharded_parity(self, ref_model, model)
         torch.manual_seed(42 + self.rank + 1)
-        device = torch.device("cuda")
+        device = device_type
         num_microbatches = 3
         for iter_idx in range(5):
             for microbatch_idx in range(num_microbatches):
@@ -1363,7 +1412,7 @@ class TestFullyShardCustomForwardMethod(FSDPTest):
 class TestFullyShardCustomForwardMethod(FSDPTest):
     @property
     def world_size(self) -> int:
-        return min(torch.cuda.device_count(), 2)
+        return min(torch.get_device_module(device_type).device_count(), 2)

     @skip_if_lt_x_gpu(2)
     def test_register_fsdp_forward_method(self):
@@ -1392,14 +1441,14 @@ class TestFullyShardCustomForwardMethod(FSDPTest):

         torch.manual_seed(42)
         model = Model()
-        ref_model = copy.deepcopy(model).cuda()
+        ref_model = copy.deepcopy(model).to(device_type)
         fully_shard(model.vit)
         fully_shard(model.projector)
         fully_shard(model)
         register_fsdp_forward_method(model.vit, "forward_features")

         torch.manual_seed(42 + self.rank + 1)
-        inp = torch.randn(4, 3, 224, 224, device="cuda")
+        inp = torch.randn(4, 3, 224, 224, device=device_type.type)
         ref_loss = ref_model(inp).sum()
         loss = model(inp).sum()
         self.assertEqual(ref_loss, loss)

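For the custom-forward test above: `register_fsdp_forward_method` makes a non-`forward` entry point trigger FSDP's unshard/reshard hooks, which plain method calls would otherwise bypass. A hedged sketch; the tiny backbone class and `forward_features` name are assumptions mirroring the test:

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import fully_shard, register_fsdp_forward_method


class TinyViT(nn.Module):  # hypothetical backbone with a non-forward API
    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(16, 16)

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


def shard_and_register(vit: TinyViT) -> None:
    fully_shard(vit)  # requires an initialized process group
    # After this, vit.forward_features(...) unshards and reshards like forward
    register_fsdp_forward_method(vit, "forward_features")
```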
@@ -375,7 +375,9 @@ class AOTAutogradCacheTests(InductorTestCase):
                 "Allow in graph produces an unserializable cache artifact"
             )

-        with inductor_config.patch("unsafe_marked_cacheable_functions", [fn_name]):
+        with inductor_config.patch(
+            "unsafe_marked_cacheable_functions", {fn_name: "key1"}
+        ):
             fn(*args)

             self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
@@ -390,6 +392,36 @@ class AOTAutogradCacheTests(InductorTestCase):
             self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
             self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 1)

+        self._clear_dynamo_and_codecache()
+        with inductor_config.patch(
+            "unsafe_marked_cacheable_functions", {fn_name: "key2"}
+        ):
+            fn(*args)
+
+            self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
+            self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
+            self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 1)
+
+        self._clear_dynamo_and_codecache()
+
+        fn(*args)
+
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 2)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 1)
+
+        # On second try with same key, it should hit once more
+        with inductor_config.patch(
+            "unsafe_marked_cacheable_functions", {fn_name: "key1"}
+        ):
+            self._clear_dynamo_and_codecache()
+
+            fn(*args)
+
+            self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
+            self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 3)
+            self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 1)
+
     @inductor_config.patch("fx_graph_remote_cache", False)
     @inductor_config.patch("fx_graph_cache", False)
     @functorch_config.patch({"enable_autograd_cache": True})
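The hunk above changes `unsafe_marked_cacheable_functions` from a list of names to a mapping from name to a user-chosen cache key, so bumping the key ("key1" to "key2") deliberately invalidates cached AOTAutograd entries for that function. A hedged sketch; the function name is a hypothetical placeholder:

```python
import torch._inductor.config as inductor_config

# Hypothetical qualified name of a torch._dynamo.allow_in_graph'd function
fn_name = "mymodule.my_allowed_fn"

with inductor_config.patch("unsafe_marked_cacheable_functions", {fn_name: "key1"}):
    pass  # compile/run here; cache entries for fn_name are tagged with "key1"
```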
@@ -1873,7 +1873,7 @@ def forward(self, x, y):
             return x + x

         def false_fn(x):
-            return x[:2]
+            return x[:2].clone()

         return cond(x.shape[0] <= 2, true_fn, false_fn, [x])

@@ -1883,7 +1883,7 @@ def forward(self, x, y):
             return x + x

         def false_fn(x):
-            return x[:2]
+            return x[:2].clone()

         return cond(x.shape[0] <= 2, true_fn, false_fn, (x,))

@@ -1924,7 +1924,8 @@ def forward(self, l_x_):
 def forward(self, l_x_):
     l_x__1 = l_x_
     getitem = l_x__1[slice(None, 2, None)]; l_x__1 = None
-    return (getitem,)""",
+    clone = getitem.clone(); getitem = None
+    return (clone,)""",
         )
         # We could successfully export branches that return different sizes
         torch._dynamo.export(mod)(torch.randn(3, 2))
@@ -3302,7 +3303,12 @@ def forward(self, x):

     def test_cond_raise_user_error_on_branch_return_multiple_tensors(self):
         def f_branch_return_multiple_tensors(pred, x, y):
-            return cond(pred, lambda x: (x, x), lambda x: (x, x), [y])
+            return cond(
+                pred,
+                lambda x: (x.clone(), x.clone()),
+                lambda x: (x.clone(), x.clone()),
+                [y],
+            )

         example_inputs = (torch.tensor(True), torch.randn(4), torch.randn(2))
         gm, _ = torch._dynamo.export(
@@ -3324,10 +3330,10 @@ def forward(self, x):

     def test_cond_raise_user_error_on_mismatch_return_length(self):
         def true_fn(x):
-            return x
+            return x.clone()

         def false_fn(x):
-            return (x, x)
+            return (x.clone(), x.clone())

         def f_mismatch_return_length(x):
             return cond(torch.tensor(100), true_fn, false_fn, [x])

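A note on why these branches gained `.clone()`: `cond` branch outputs may not alias the branch inputs, so returning `x` itself or a view like `x[:2]` is not allowed, and the exported graphs above now insert clones. A minimal hedged sketch of the resulting user-side pattern:

```python
import torch
from functorch.experimental.control_flow import cond


def true_fn(x):
    return x.clone()  # returning `x` itself would alias the operand


def false_fn(x):
    return x.sin()  # a freshly computed tensor is fine as-is


out = cond(torch.tensor(True), true_fn, false_fn, [torch.randn(4)])
```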
@@ -170,8 +170,8 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
                     in warning_message
                 ):
                     break
-        else:
-            self.assertTrue(False, "Expected warning about lru_cache not found")
+            else:
+                self.assertTrue(False, "Expected warning about lru_cache not found")

     @make_test
     def test_add(a, b):
@@ -1791,7 +1791,13 @@ def forward(self, child : torch.Tensor):

     def test_map_pytree_return(self):
         def _construct_pytree(a):
-            return (a, [[[a]]], a, (a, (a,), a), {"a": a})
+            return (
+                a.clone(),
+                [[[a.clone()]]],
+                a.clone(),
+                (a.clone(), (a.clone(),), a.clone()),
+                {"a": a.clone()},
+            )

         def f(x):
             def inner_f(xs):
@@ -1823,7 +1829,14 @@ def forward(self, L_x_ : torch.Tensor):
             body_graph,
             """\
 def forward(self, child : torch.Tensor):
-    return (child, child, child, child, child, child, child)""",
+    child_1 = child.clone()
+    child_2 = child.clone()
+    child_3 = child.clone()
+    child_4 = child.clone()
+    child_5 = child.clone()
+    child_6 = child.clone()
+    child_7 = child.clone(); child = None
+    return (child_1, child_2, child_3, child_4, child_5, child_6, child_7)""",
         )

     def test_map_kwargs(self):
@@ -6902,7 +6915,7 @@ class ActivationCheckpointingTests(torch._dynamo.test_case.TestCase):

         def test(pred, x):
             def true_fn(x):
-                return x
+                return x.clone()

             def false_fn(x):
                 return -x
@@ -6926,7 +6939,7 @@ class ActivationCheckpointingTests(torch._dynamo.test_case.TestCase):

         def test(pred, mode, x):
             def true_fn(x):
-                return x
+                return x.clone()

             def false_fn(x):
                 return -x

@@ -5931,7 +5931,7 @@ utils_device.CURRENT_DEVICE == None""".split(
         from functorch.experimental.control_flow import cond

         def true_fn(x):
-            return x
+            return x.clone()

         def false_fn(x):
             return x.sin()

@@ -1,5 +1,4 @@
 # Owner(s): ["module: dynamo"]
-from unittest import expectedFailure
 from unittest.mock import patch

 import torch
@@ -394,7 +393,6 @@ class RecompileTests(torch._dynamo.test_case.TestCase):

         self.assertEqual(counter.frame_count, 2)  # not three or four!

-    @expectedFailure  # TODO(laithsakka, pianpwk): handle guard_or_false before oblivious hint fallback
     @torch._dynamo.config.patch(automatic_dynamic_shapes_mark_as="oblivious")
     def test_automatic_dynamic_shapes_mark_as_oblivious(self):
         counter = torch._dynamo.testing.CompileCounter()

@@ -9,7 +9,7 @@ append_cxx_flag_if_supported("-Wno-unused-private-field" CMAKE_CXX_FLAGS)

 # Generate unboxing kernels
 set(GEN_COMMAND
-    "${Python_EXECUTABLE}" -m torchgen.gen_executorch
+    Python::Interpreter -m torchgen.gen_executorch
     --source-path=${TEST_ROOT}
     --install-dir=${OUTPUT_DIRECTORY}
     --tags-path=${TORCH_ROOT}/aten/src/ATen/native/tags.yaml
@@ -58,11 +58,7 @@ add_executable(test_edge_op_registration

 target_compile_definitions(test_edge_op_registration PRIVATE USE_GTEST)

-set(TEST_DEPENDENCIES gtest unbox_lib)
-
-target_link_libraries(test_edge_op_registration PRIVATE
-  ${TEST_DEPENDENCIES}
-)
+target_link_libraries(test_edge_op_registration PRIVATE gtest_main unbox_lib)
 if((CMAKE_CXX_COMPILER_ID MATCHES "AppleClang") OR (APPLE AND CMAKE_CXX_COMPILER_ID MATCHES "Clang"))
   target_link_options(test_edge_op_registration PRIVATE
     "-Wl,-force_load,$<TARGET_FILE:unbox_lib>"

@@ -323,7 +323,7 @@ class TestDraftExport(TestCase):
         self.assertEqual(
             report.failures[0].failure_type, FailureType.DATA_DEPENDENT_ERROR
         )
-        self.assertEqual(report.failures[0].data["expr"], "Eq(9380*u1, 0)")
+        self.assertEqual(report.failures[0].data["expr"], "Eq(Mod(10, 2*u1), 0)")

     def test_dedup_data_dependent_failure(self):
         class M(torch.nn.Module):

@@ -7604,20 +7604,25 @@ def forward(self, b_a_buffer, x):
         self.assertTrue(torch.allclose(ep.module()(xs), module_out))

     @requires_cuda
+    @testing.expectedFailureCppRuntime
     def test_export_associative_scan_lifted_buffers(self):
         device = torch.device("cuda")
         combine_mode = "pointwise"

+        class A(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.buffer = torch.nn.Buffer(torch.ones(3, 2, device=device))
+
+            def forward(self):
+                return self.buffer.cos()
+
         class M(torch.nn.Module):
             def __init__(self) -> None:
                 super().__init__()
-                self.register_buffer(
-                    "buf", torch.ones(3, 2, device=device), persistent=False
-                )
+                self.a = A()

             def combine_fn(self, x, y):
-                return x + y * self.buf
+                return (x + y) * self.a()

             def forward(self, x):
                 return associative_scan(
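For readers unfamiliar with the op under test: `associative_scan` computes an inclusive scan of a binary associative `combine_fn` along one dimension. A hedged plain-Python reference that avoids the experimental higher-order-op API entirely; the helper name is illustrative:

```python
import torch


def reference_associative_scan(combine_fn, xs: torch.Tensor, dim: int = 0) -> torch.Tensor:
    slices = list(xs.unbind(dim))
    acc = slices[0]
    out = [acc]
    for s in slices[1:]:
        # combine_fn must be associative for the real op to parallelize it
        acc = combine_fn(acc, s)
        out.append(acc)
    return torch.stack(out, dim)


x = torch.arange(6.0).reshape(3, 2)
print(reference_associative_scan(torch.add, x))  # running sum along dim 0
```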
@@ -4572,17 +4572,17 @@ class <lambda>(torch.nn.Module):

         body_graph_0 = self.body_graph_0
         map_impl = torch.ops.higher_order.map_impl(body_graph_0, [cos], [arg1_1]); body_graph_0 = None
-        getitem: "f32[2, 2]" = map_impl[0]; map_impl = None
+        getitem_2: "f32[2, 2]" = map_impl[0]; map_impl = None

-        sum_1: "f32[]" = torch.ops.aten.sum.default(getitem); getitem = None
+        sum_1: "f32[]" = torch.ops.aten.sum.default(getitem_2); getitem_2 = None

         add: "f32[2, 2]" = torch.ops.aten.add.Tensor(cos, sum_1); sum_1 = None

         body_graph_1 = self.body_graph_1
         map_impl_1 = torch.ops.higher_order.map_impl(body_graph_1, [cos], [arg1_1]); body_graph_1 = cos = arg1_1 = None
-        getitem_1: "f32[2, 2]" = map_impl_1[0]; map_impl_1 = None
+        getitem_5: "f32[2, 2]" = map_impl_1[0]; map_impl_1 = None

-        sum_2: "f32[]" = torch.ops.aten.sum.default(getitem_1); getitem_1 = None
+        sum_2: "f32[]" = torch.ops.aten.sum.default(getitem_5); getitem_5 = None

         add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, sum_2); add = sum_2 = None
         return (add_1,)
@@ -4635,9 +4635,9 @@ class <lambda>(torch.nn.Module):

         body_graph_0 = self.body_graph_0
         map_impl = torch.ops.higher_order.map_impl(body_graph_0, [cos], [arg1_1]); body_graph_0 = arg1_1 = None
-        getitem: "f32[2, 2]" = map_impl[0]; map_impl = None
+        getitem_2: "f32[2, 2]" = map_impl[0]; map_impl = None

-        sum_1: "f32[]" = torch.ops.aten.sum.default(getitem); getitem = None
+        sum_1: "f32[]" = torch.ops.aten.sum.default(getitem_2); getitem_2 = None
         add: "f32[2, 2]" = torch.ops.aten.add.Tensor(cos, sum_1); cos = sum_1 = None
         return (add,)
Some files were not shown because too many files have changed in this diff.