Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-28 02:04:53 +08:00

Compare commits: 105 commits (exec ... traceable_)
| SHA1 |
|---|
| d0e00d5448 | |||
| 25229787d6 | |||
| e84cf805d2 | |||
| 254487f288 | |||
| 73340f0909 | |||
| 8c2542623b | |||
| 734891ac22 | |||
| ddb95dbb0d | |||
| 832fc35211 | |||
| 65286883d4 | |||
| fc5b0ff2d7 | |||
| b2a9b8d485 | |||
| e0aa992d73 | |||
| 2bb8ee602b | |||
| 7178b4e987 | |||
| ea47d542ca | |||
| 54b0006cb2 | |||
| 799acd31b4 | |||
| 0d25f096c1 | |||
| 6d2b3c90f1 | |||
| ad2593cb86 | |||
| 19f3abcde4 | |||
| 609ffaf717 | |||
| d8db074988 | |||
| 859fa183fe | |||
| a2b1673dfb | |||
| 9d06e3783d | |||
| a6ac6447b5 | |||
| 571a0db132 | |||
| 277f2914a5 | |||
| fca408fa29 | |||
| 73f5d2b787 | |||
| df94d57c0a | |||
| b5d541609d | |||
| bafd68b4fc | |||
| 0707811286 | |||
| 0fc603ece4 | |||
| 1b92bdd0ea | |||
| 236fbcbdf4 | |||
| 7d33ff59ba | |||
| ffb50fb691 | |||
| 3397d5ef90 | |||
| 118f9ceb7c | |||
| e49525275d | |||
| 7fac03aee9 | |||
| 50567f7081 | |||
| d3e8b8bf47 | |||
| ba92f5277f | |||
| 3a185778ed | |||
| a584b2a389 | |||
| fcf2a1378b | |||
| 2f88597aad | |||
| 1f0a68b572 | |||
| acefc5c016 | |||
| eb9f4da11e | |||
| 8771e3429c | |||
| ed5b8432cd | |||
| df85f34a14 | |||
| 4bc90185fb | |||
| eda375a490 | |||
| 2458f79f83 | |||
| b0d2fe6299 | |||
| 5ffb032be6 | |||
| 35c78668b4 | |||
| 99f042d336 | |||
| 670b94c9c8 | |||
| c5e0b84484 | |||
| cb5e9183c6 | |||
| ac5f565fa7 | |||
| d9c294c672 | |||
| a0e1e20c41 | |||
| 3b798df853 | |||
| cec31050b4 | |||
| e47603a549 | |||
| 2227da4431 | |||
| 4cc3fb5ee2 | |||
| 5dc4f652bc | |||
| 44722c6b10 | |||
| 1babeddbbf | |||
| 5bc9835d64 | |||
| 9a7e2519d3 | |||
| fe8558b7aa | |||
| abde6cab4c | |||
| 04a5d3228e | |||
| 44483972bd | |||
| bdffd9f0c6 | |||
| 1a527915a6 | |||
| d77a1aaa86 | |||
| 1877b7896c | |||
| 77830d509f | |||
| 84c86e56bd | |||
| 4e03263224 | |||
| 26e374e3ca | |||
| 9818283da1 | |||
| ec616da518 | |||
| 108318ad10 | |||
| 4817180601 | |||
| 22d258427b | |||
| e6d4451ae8 | |||
| f2805a0408 | |||
| 3dd5f0ecbb | |||
| 304c934572 | |||
| 6e43897912 | |||
| 60baeee59f | |||
| e3a39d49a0 |
@@ -1 +1 @@
d4b3e5cc607e97afdba79dc90f8ef968142f347c
172574a6be5910a4609e4ed1bef2b6b8475ddb3d
@@ -37,6 +37,9 @@ install_conda_dependencies() {

install_pip_dependencies() {
  pushd executorch/.ci/docker
  # Install PyTorch CPU build beforehand to avoid installing the much bigger CUDA
  # binaries later, ExecuTorch only needs CPU
  pip_install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
  # Install all Python dependencies
  pip_install -r requirements-ci.txt
  popd

@@ -44,13 +47,14 @@ install_pip_dependencies() {

setup_executorch() {
  pushd executorch
  source .ci/scripts/utils.sh
  # Setup swiftshader and Vulkan SDK which are required to build the Vulkan delegate
  as_jenkins bash .ci/scripts/setup-vulkan-linux-deps.sh

  install_flatc_from_source
  pip_install .
  export PYTHON_EXECUTABLE=python
  export EXECUTORCH_BUILD_PYBIND=ON
  export CMAKE_ARGS="-DEXECUTORCH_BUILD_XNNPACK=ON -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON"

  # Make sure that all the newly generate files are owned by Jenkins
  chown -R jenkins .
  as_jenkins .ci/scripts/setup-linux.sh cmake
  popd
}
@@ -284,12 +284,26 @@ else
    # Which should be backward compatible with Numpy-1.X
    python -mpip install --pre numpy==2.0.0rc1
  fi
  WERROR=1 python setup.py bdist_wheel

  WERROR=1 python setup.py clean

  if [[ "$USE_SPLIT_BUILD" == "true" ]]; then
    BUILD_LIBTORCH_WHL=1 BUILD_PYTHON_ONLY=0 python setup.py bdist_wheel
    BUILD_LIBTORCH_WHL=0 BUILD_PYTHON_ONLY=1 python setup.py bdist_wheel --cmake
  else
    WERROR=1 python setup.py bdist_wheel
  fi
else
  python setup.py clean
  if [[ "$BUILD_ENVIRONMENT" == *xla* ]]; then
    source .ci/pytorch/install_cache_xla.sh
  fi
  python setup.py bdist_wheel
  if [[ "$USE_SPLIT_BUILD" == "true" ]]; then
    echo "USE_SPLIT_BUILD cannot be used with xla or rocm"
    exit 1
  else
    python setup.py bdist_wheel
  fi
fi
pip_install_whl "$(echo dist/*.whl)"
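For context, the split-build path above produces two wheels from the same source tree by toggling BUILD_LIBTORCH_WHL and BUILD_PYTHON_ONLY. A minimal Python sketch of that flow follows; the helper name `run_split_build` and the plain `python setup.py bdist_wheel` invocation are illustrative, not part of the CI script.

```python
# Sketch only: drive the two-phase split build from Python, assuming the same
# environment toggles as the bash branch above (BUILD_LIBTORCH_WHL / BUILD_PYTHON_ONLY).
import os
import subprocess

def run_split_build(source_dir: str = ".") -> None:
    base_env = os.environ.copy()

    # Phase 1: build the libtorch-only wheel.
    env = dict(base_env, BUILD_LIBTORCH_WHL="1", BUILD_PYTHON_ONLY="0")
    subprocess.check_call(["python", "setup.py", "bdist_wheel"], cwd=source_dir, env=env)

    # Phase 2: build the Python-only wheel on top of the libtorch wheel.
    env = dict(base_env, BUILD_LIBTORCH_WHL="0", BUILD_PYTHON_ONLY="1")
    subprocess.check_call(
        ["python", "setup.py", "bdist_wheel", "--cmake"], cwd=source_dir, env=env
    )

if __name__ == "__main__":
    run_split_build()
```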
@@ -328,9 +342,10 @@ else
  CUSTOM_OP_TEST="$PWD/test/custom_operator"
  python --version
  SITE_PACKAGES="$(python -c 'from distutils.sysconfig import get_python_lib; print(get_python_lib())')"

  mkdir -p "$CUSTOM_OP_BUILD"
  pushd "$CUSTOM_OP_BUILD"
  cmake "$CUSTOM_OP_TEST" -DCMAKE_PREFIX_PATH="$SITE_PACKAGES/torch" -DPython_EXECUTABLE="$(which python)" \
  cmake "$CUSTOM_OP_TEST" -DCMAKE_PREFIX_PATH="$SITE_PACKAGES/torch;$SITE_PACKAGES" -DPython_EXECUTABLE="$(which python)" \
    -DCMAKE_MODULE_PATH="$CUSTOM_TEST_MODULE_PATH" -DUSE_ROCM="$CUSTOM_TEST_USE_ROCM"
  make VERBOSE=1
  popd

@@ -343,7 +358,7 @@ else
  SITE_PACKAGES="$(python -c 'from distutils.sysconfig import get_python_lib; print(get_python_lib())')"
  mkdir -p "$JIT_HOOK_BUILD"
  pushd "$JIT_HOOK_BUILD"
  cmake "$JIT_HOOK_TEST" -DCMAKE_PREFIX_PATH="$SITE_PACKAGES/torch" -DPython_EXECUTABLE="$(which python)" \
  cmake "$JIT_HOOK_TEST" -DCMAKE_PREFIX_PATH="$SITE_PACKAGES/torch;$SITE_PACKAGES" -DPython_EXECUTABLE="$(which python)" \
    -DCMAKE_MODULE_PATH="$CUSTOM_TEST_MODULE_PATH" -DUSE_ROCM="$CUSTOM_TEST_USE_ROCM"
  make VERBOSE=1
  popd

@@ -355,7 +370,7 @@ else
  python --version
  mkdir -p "$CUSTOM_BACKEND_BUILD"
  pushd "$CUSTOM_BACKEND_BUILD"
  cmake "$CUSTOM_BACKEND_TEST" -DCMAKE_PREFIX_PATH="$SITE_PACKAGES/torch" -DPython_EXECUTABLE="$(which python)" \
  cmake "$CUSTOM_BACKEND_TEST" -DCMAKE_PREFIX_PATH="$SITE_PACKAGES/torch;$SITE_PACKAGES" -DPython_EXECUTABLE="$(which python)" \
    -DCMAKE_MODULE_PATH="$CUSTOM_TEST_MODULE_PATH" -DUSE_ROCM="$CUSTOM_TEST_USE_ROCM"
  make VERBOSE=1
  popd
@@ -56,9 +56,29 @@ function assert_git_not_dirty() {

function pip_install_whl() {
  # This is used to install PyTorch and other build artifacts wheel locally
  # without using any network connection
  python3 -mpip install --no-index --no-deps "$@"

  # Convert the input arguments into an array
  local args=("$@")

  # Check if the first argument contains multiple paths separated by spaces
  if [[ "${args[0]}" == *" "* ]]; then
    # Split the string by spaces into an array
    IFS=' ' read -r -a paths <<< "${args[0]}"
    # Loop through each path and install individually
    for path in "${paths[@]}"; do
      echo "Installing $path"
      python3 -mpip install --no-index --no-deps "$path"
    done
  else
    # Loop through each argument and install individually
    for path in "${args[@]}"; do
      echo "Installing $path"
      python3 -mpip install --no-index --no-deps "$path"
    done
  fi
}


function pip_install() {
  # retry 3 times
  # old versions of pip don't have the "--progress-bar" flag
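The rewritten pip_install_whl above still installs strictly offline, but now accepts either separate wheel arguments or one space-separated string (the `$(echo dist/*.whl)` form used earlier in build.sh). A rough, slightly simplified Python equivalent, for illustration only:

```python
# Illustration only; the CI uses the bash helper above. Paths here are examples.
import subprocess

def pip_install_whl(*args: str) -> None:
    paths = list(args)
    # A first argument containing spaces is treated as several wheel paths.
    if paths and " " in paths[0]:
        paths = paths[0].split() + paths[1:]
    for path in paths:
        print(f"Installing {path}")
        # --no-index/--no-deps keep the install fully offline, matching the bash version.
        subprocess.check_call(
            ["python3", "-mpip", "install", "--no-index", "--no-deps", path]
        )

# e.g. pip_install_whl("dist/a.whl dist/b.whl") or pip_install_whl("dist/a.whl", "dist/b.whl")
```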
@@ -289,6 +289,9 @@ test_python_shard() {

  # Bare --include flag is not supported and quoting for lint ends up with flag not being interpreted correctly
  # shellcheck disable=SC2086

  # modify LD_LIBRARY_PATH to ensure it has the conda env.
  # This set of tests has been shown to be buggy without it for the split-build
  time python test/run_test.py --exclude-jit-executor --exclude-distributed-tests $INCLUDE_CLAUSE --shard "$1" "$NUM_TEST_SHARDS" --verbose $PYTHON_TEST_EXTRA_OPTION

  assert_git_not_dirty
@@ -1174,15 +1177,21 @@ test_executorch() {

  pushd /executorch

  # NB: We need to build ExecuTorch runner here and not inside the Docker image
  # because it depends on PyTorch
  export PYTHON_EXECUTABLE=python
  export EXECUTORCH_BUILD_PYBIND=ON
  export CMAKE_ARGS="-DEXECUTORCH_BUILD_XNNPACK=ON -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON"

  # NB: We need to rebuild ExecuTorch runner here because it depends on PyTorch
  # from the PR
  # shellcheck disable=SC1091
  source .ci/scripts/utils.sh
  build_executorch_runner "cmake"
  source .ci/scripts/setup-linux.sh cmake

  echo "Run ExecuTorch unit tests"
  pytest -v -n auto
  # shellcheck disable=SC1091
  LLVM_PROFDATA=llvm-profdata-12 LLVM_COV=llvm-cov-12 bash test/run_oss_cpp_tests.sh

  echo "Run ExecuTorch regression tests for some models"
  # NB: This is a sample model, more can be added here
  export PYTHON_EXECUTABLE=python
  # TODO(huydhn): Add more coverage here using ExecuTorch's gather models script
  # shellcheck disable=SC1091
  source .ci/scripts/test.sh mv3 cmake xnnpack-quantization-delegation ''
@@ -33,9 +33,9 @@ if [[ -z "$DOCKER_IMAGE" ]]; then
  if [[ "$PACKAGE_TYPE" == conda ]]; then
    export DOCKER_IMAGE="pytorch/conda-cuda"
  elif [[ "$DESIRED_CUDA" == cpu ]]; then
    export DOCKER_IMAGE="pytorch/manylinux-cpu"
    export DOCKER_IMAGE="pytorch/manylinux:cpu"
  else
    export DOCKER_IMAGE="pytorch/manylinux-cuda${DESIRED_CUDA:2}"
    export DOCKER_IMAGE="pytorch/manylinux-builder:${DESIRED_CUDA:2}"
  fi
fi
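The hunk above swaps the fallback images; in each pair the old value appears first and its replacement directly below. A small Python sketch of the resulting selection logic, where `desired_cuda[2:]` mirrors bash's `${DESIRED_CUDA:2}` (e.g. "cu121" becomes "121"); the function name is illustrative, not from the repo:

```python
# Sketch of the default-image selection after this change (illustrative only).
def default_docker_image(package_type: str, desired_cuda: str) -> str:
    if package_type == "conda":
        return "pytorch/conda-cuda"
    if desired_cuda == "cpu":
        return "pytorch/manylinux:cpu"
    # ${DESIRED_CUDA:2} drops the leading "cu", so "cu121" selects the "121" tag.
    return f"pytorch/manylinux-builder:{desired_cuda[2:]}"

assert default_docker_image("manywheel", "cpu") == "pytorch/manylinux:cpu"
assert default_docker_image("manywheel", "cu121") == "pytorch/manylinux-builder:121"
```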
@@ -75,9 +75,9 @@ export PYTORCH_BUILD_NUMBER=1
TRITON_VERSION=$(cat $PYTORCH_ROOT/.ci/docker/triton_version.txt)

# Here PYTORCH_EXTRA_INSTALL_REQUIREMENTS is already set for the all the wheel builds hence append TRITON_CONSTRAINT
TRITON_CONSTRAINT="platform_system == 'Linux' and platform_machine == 'x86_64' and python_version < '3.13'"
if [[ "$PACKAGE_TYPE" =~ .*wheel.* && -n "${PYTORCH_EXTRA_INSTALL_REQUIREMENTS:-}" ]]; then
  # Only linux Python < 3.13 are supported wheels for triton
  TRITON_CONSTRAINT="platform_system == 'Linux' and platform_machine == 'x86_64' and python_version < '3.13'"
  TRITON_REQUIREMENT="triton==${TRITON_VERSION}; ${TRITON_CONSTRAINT}"
  if [[ -n "$PYTORCH_BUILD_VERSION" && "$PYTORCH_BUILD_VERSION" =~ .*dev.* ]]; then
    TRITON_SHORTHASH=$(cut -c1-10 $PYTORCH_ROOT/.ci/docker/ci_commit_pins/triton.txt)

@@ -87,11 +87,11 @@ if [[ "$PACKAGE_TYPE" =~ .*wheel.* && -n "${PYTORCH_EXTRA_INSTALL_REQUIREMENTS:
fi

# Set triton via PYTORCH_EXTRA_INSTALL_REQUIREMENTS for triton rocm package
if [[ "$PACKAGE_TYPE" =~ .*wheel.* && -n "$PYTORCH_BUILD_VERSION" && "$PYTORCH_BUILD_VERSION" =~ .*rocm.* && $(uname) == "Linux" && "$DESIRED_PYTHON" != "3.12" ]]; then
  TRITON_REQUIREMENT="pytorch-triton-rocm==${TRITON_VERSION}"
if [[ "$PACKAGE_TYPE" =~ .*wheel.* && -n "$PYTORCH_BUILD_VERSION" && "$PYTORCH_BUILD_VERSION" =~ .*rocm.* && $(uname) == "Linux" ]]; then
  TRITON_REQUIREMENT="pytorch-triton-rocm==${TRITON_VERSION}; ${TRITON_CONSTRAINT}"
  if [[ -n "$PYTORCH_BUILD_VERSION" && "$PYTORCH_BUILD_VERSION" =~ .*dev.* ]]; then
    TRITON_SHORTHASH=$(cut -c1-10 $PYTORCH_ROOT/.ci/docker/ci_commit_pins/triton-rocm.txt)
    TRITON_REQUIREMENT="pytorch-triton-rocm==${TRITON_VERSION}+${TRITON_SHORTHASH}"
    TRITON_REQUIREMENT="pytorch-triton-rocm==${TRITON_VERSION}+${TRITON_SHORTHASH}; ${TRITON_CONSTRAINT}"
  fi
  if [[ -z "${PYTORCH_EXTRA_INSTALL_REQUIREMENTS:-}" ]]; then
    export PYTORCH_EXTRA_INSTALL_REQUIREMENTS="${TRITON_REQUIREMENT}"
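The requirement strings assembled above use PEP 508 environment markers, so pip only pulls the triton package on matching platforms. A quick illustration with the `packaging` library (the version number below is a placeholder; in CI it comes from triton_version.txt):

```python
# Illustration, assuming the `packaging` library is installed.
from packaging.requirements import Requirement

req = Requirement(
    "pytorch-triton-rocm==3.0.0; "
    "platform_system == 'Linux' and platform_machine == 'x86_64' and python_version < '3.13'"
)
print(req.name, req.specifier)  # pytorch-triton-rocm ==3.0.0
# evaluate() checks the marker against the current interpreter and platform,
# which is the filtering pip applies when it reads the requirement.
print(req.marker.evaluate())
```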
.github/actions/linux-build/action.yml (vendored, 21 changed lines)

@@ -52,6 +52,13 @@ inputs:
    description: Hugging Face Hub token
    required: false
    default: ""
  use_split_build:
    description: |
      [Experimental] Build a libtorch only wheel and build pytorch such that
      are built from the libtorch wheel.
    required: false
    type: boolean
    default: false
outputs:
  docker-image:
    value: ${{ steps.calculate-docker-image.outputs.docker-image }}

@@ -144,6 +151,7 @@ runs:
        DEBUG: ${{ inputs.build-with-debug == 'true' && '1' || '0' }}
        OUR_GITHUB_JOB_ID: ${{ steps.get-job-id.outputs.job-id }}
        HUGGING_FACE_HUB_TOKEN: ${{ inputs.HUGGING_FACE_HUB_TOKEN }}
        USE_SPLIT_BUILD: ${{ inputs.use_split_build }}
      shell: bash
      run: |
        # detached container should get cleaned up by teardown_ec2_linux

@@ -163,6 +171,7 @@ runs:
          -e PR_LABELS \
          -e OUR_GITHUB_JOB_ID \
          -e HUGGING_FACE_HUB_TOKEN \
          -e USE_SPLIT_BUILD \
          --env-file="/tmp/github_env_${GITHUB_RUN_ID}" \
          --security-opt seccomp=unconfined \
          --cap-add=SYS_PTRACE \

@@ -183,7 +192,7 @@ runs:

    - name: Store PyTorch Build Artifacts on S3
      uses: seemethere/upload-artifact-s3@v5
      if: inputs.build-generates-artifacts == 'true' && steps.build.outcome != 'skipped'
      if: inputs.build-generates-artifacts == 'true' && steps.build.outcome != 'skipped' && inputs.use_split_build != 'true'
      with:
        name: ${{ inputs.build-environment }}
        retention-days: 14

@@ -191,6 +200,16 @@ runs:
        path: artifacts.zip
        s3-bucket: ${{ inputs.s3-bucket }}

    - name: Store PyTorch Build Artifacts on S3 for split build
      uses: seemethere/upload-artifact-s3@v5
      if: inputs.build-generates-artifacts == 'true' && steps.build.outcome != 'skipped' && inputs.use_split_build == 'true'
      with:
        name: ${{ inputs.build-environment }}-experimental-split-build
        retention-days: 14
        if-no-files-found: error
        path: artifacts.zip
        s3-bucket: ${{ inputs.s3-bucket }}

    - name: Upload sccache stats
      if: steps.build.outcome != 'skipped'
      uses: seemethere/upload-artifact-s3@v5
.github/ci_commit_pins/torchbench.txt (vendored, 2 changed lines)

@@ -1 +1 @@
0dab1dd97709096e8129f8a08115ee83f64f2194
23512dbebd44a11eb84afbf53c3c071dd105297e
.github/scripts/cherry_pick.py (vendored, 114 changed lines)

@@ -3,11 +3,11 @@
import json
import os
import re
from typing import Any, Optional
from typing import Any, cast, Dict, List, Optional

from urllib.error import HTTPError

from github_utils import gh_fetch_url, gh_post_pr_comment
from github_utils import gh_fetch_url, gh_post_pr_comment, gh_query_issues_by_labels

from gitutils import get_git_remote_name, get_git_repo_dir, GitRepo
from trymerge import get_pr_commit_sha, GitHubPR

@@ -19,6 +19,7 @@ REQUIRES_ISSUE = {
    "critical",
    "fixnewfeature",
}
RELEASE_BRANCH_REGEX = re.compile(r"release/(?P<version>.+)")


def parse_args() -> Any:

@@ -58,6 +59,33 @@ def get_merge_commit_sha(repo: GitRepo, pr: GitHubPR) -> Optional[str]:
    return commit_sha if pr.is_closed() else None


def get_release_version(onto_branch: str) -> Optional[str]:
    """
    Return the release version if the target branch is a release branch
    """
    m = re.match(RELEASE_BRANCH_REGEX, onto_branch)
    return m.group("version") if m else ""


def get_tracker_issues(
    org: str, project: str, onto_branch: str
) -> List[Dict[str, Any]]:
    """
    Find the tracker issue from the repo. The tracker issue needs to have the title
    like [VERSION] Release Tracker following the convention on PyTorch
    """
    version = get_release_version(onto_branch)
    if not version:
        return []

    tracker_issues = gh_query_issues_by_labels(org, project, labels=["release tracker"])
    if not tracker_issues:
        return []

    # Figure out the tracker issue from the list by looking at the title
    return [issue for issue in tracker_issues if version in issue.get("title", "")]


def cherry_pick(
    github_actor: str,
    repo: GitRepo,
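To make the intent of the two new helpers concrete, here is a small, self-contained illustration (the branch name and issue titles are made-up examples, not data from the repo):

```python
import re

RELEASE_BRANCH_REGEX = re.compile(r"release/(?P<version>.+)")

def get_release_version(onto_branch: str) -> str:
    m = re.match(RELEASE_BRANCH_REGEX, onto_branch)
    return m.group("version") if m else ""

print(get_release_version("release/2.4"))  # -> "2.4"
print(get_release_version("main"))         # -> "" (no tracker lookup happens)

# The tracker lookup then keeps only "release tracker" issues whose title
# mentions that version, mirroring get_tracker_issues() above.
issues = [
    {"number": 101, "title": "[2.4] Release Tracker"},
    {"number": 87, "title": "[2.3] Release Tracker"},
]
version = "2.4"
print([i for i in issues if version in i.get("title", "")])  # only the 2.4 tracker
```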
@@ -77,17 +105,49 @@ def cherry_pick(
    )

    try:
        org, project = repo.gh_owner_and_name()

        cherry_pick_pr = ""
        if not dry_run:
            org, project = repo.gh_owner_and_name()
            cherry_pick_pr = submit_pr(repo, pr, cherry_pick_branch, onto_branch)

        msg = f"The cherry pick PR is at {cherry_pick_pr}"
        if fixes:
            msg += f" and it is linked with issue {fixes}"
        elif classification in REQUIRES_ISSUE:
            msg += f" and it is recommended to link a {classification} cherry pick PR with an issue"
        tracker_issues_comments = []
        tracker_issues = get_tracker_issues(org, project, onto_branch)
        for issue in tracker_issues:
            issue_number = int(str(issue.get("number", "0")))
            if not issue_number:
                continue

        post_comment(org, project, pr.pr_num, msg)
            res = cast(
                Dict[str, Any],
                post_tracker_issue_comment(
                    org,
                    project,
                    issue_number,
                    pr.pr_num,
                    cherry_pick_pr,
                    classification,
                    fixes,
                    dry_run,
                ),
            )

            comment_url = res.get("html_url", "")
            if comment_url:
                tracker_issues_comments.append(comment_url)

        msg = f"The cherry pick PR is at {cherry_pick_pr}"
        if fixes:
            msg += f" and it is linked with issue {fixes}."
        elif classification in REQUIRES_ISSUE:
            msg += f" and it is recommended to link a {classification} cherry pick PR with an issue."

        if tracker_issues_comments:
            msg += " The following tracker issues are updated:\n"
            for tracker_issues_comment in tracker_issues_comments:
                msg += f"* {tracker_issues_comment}\n"

        post_pr_comment(org, project, pr.pr_num, msg, dry_run)

    finally:
        if current_branch:

@@ -159,7 +219,9 @@ def submit_pr(
        raise RuntimeError(msg) from error


def post_comment(org: str, project: str, pr_num: int, msg: str) -> None:
def post_pr_comment(
    org: str, project: str, pr_num: int, msg: str, dry_run: bool = False
) -> List[Dict[str, Any]]:
    """
    Post a comment on the PR itself to point to the cherry picking PR when success
    or print the error when failure

@@ -182,7 +244,35 @@ def post_comment(org: str, project: str, pr_num: int, msg: str) -> None:
    comment = "\n".join(
        (f"### Cherry picking #{pr_num}", f"{msg}", "", f"{internal_debugging}")
    )
    gh_post_pr_comment(org, project, pr_num, comment)
    return gh_post_pr_comment(org, project, pr_num, comment, dry_run)


def post_tracker_issue_comment(
    org: str,
    project: str,
    issue_num: int,
    pr_num: int,
    cherry_pick_pr: str,
    classification: str,
    fixes: str,
    dry_run: bool = False,
) -> List[Dict[str, Any]]:
    """
    Post a comment on the tracker issue (if any) to record the cherry pick
    """
    comment = "\n".join(
        (
            "Link to landed trunk PR (if applicable):",
            f"* https://github.com/{org}/{project}/pull/{pr_num}",
            "",
            "Link to release branch PR:",
            f"* {cherry_pick_pr}",
            "",
            "Criteria Category:",
            " - ".join((classification.capitalize(), fixes.capitalize())),
        )
    )
    return gh_post_pr_comment(org, project, issue_num, comment, dry_run)


def main() -> None:

@@ -214,7 +304,7 @@ def main() -> None:

    except RuntimeError as error:
        if not args.dry_run:
            post_comment(org, project, pr_num, str(error))
            post_pr_comment(org, project, pr_num, str(error))
        else:
            raise error
@@ -347,10 +347,6 @@ def generate_wheels_matrix(
    for python_version in python_versions:
        for arch_version in arches:
            gpu_arch_type = arch_type(arch_version)
            # Disable py3.12 builds for ROCm because of triton dependency
            # on llnl-hatchet, which doesn't have py3.12 wheels available
            if gpu_arch_type == "rocm" and python_version == "3.12":
                continue
            gpu_arch_version = (
                ""
                if arch_version == "cpu"
.github/scripts/get_workflow_type.py (vendored, 47 changed lines)

@@ -1,6 +1,6 @@
import json
from argparse import ArgumentParser
from typing import Any
from typing import Any, Tuple

from github import Auth, Github
from github.Issue import Issue

@@ -9,6 +9,8 @@ from github.Issue import Issue
WORKFLOW_LABEL_META = ""  # use meta runners
WORKFLOW_LABEL_LF = "lf."  # use runners from the linux foundation
LABEL_TYPE_KEY = "label_type"
MESSAGE_KEY = "message"
MESSAGE = ""  # Debug message to return to the caller


def parse_args() -> Any:

@@ -48,45 +50,50 @@ def is_exception_branch(branch: str) -> bool:
    return branch.split("/")[0] in {"main", "nightly", "release", "landchecks"}


def get_workflow_type(issue: Issue, username: str) -> str:
def get_workflow_type(issue: Issue, username: str) -> Tuple[str, str]:
    try:
        user_list = issue.get_comments()[0].body.split()

        if user_list[0] == "!":
            print("LF Workflows are disabled for everyone. Using meta runners.")
            return WORKFLOW_LABEL_META
            MESSAGE = "LF Workflows are disabled for everyone. Using meta runners."
            return WORKFLOW_LABEL_META, MESSAGE
        elif user_list[0] == "*":
            print("LF Workflows are enabled for everyone. Using LF runners.")
            return WORKFLOW_LABEL_LF
            MESSAGE = "LF Workflows are enabled for everyone. Using LF runners."
            return WORKFLOW_LABEL_LF, MESSAGE
        elif username in user_list:
            print(f"LF Workflows are enabled for {username}. Using LF runners.")
            return WORKFLOW_LABEL_LF
            MESSAGE = f"LF Workflows are enabled for {username}. Using LF runners."
            return WORKFLOW_LABEL_LF, MESSAGE
        else:
            print(f"LF Workflows are disabled for {username}. Using meta runners.")
            return WORKFLOW_LABEL_META
            MESSAGE = f"LF Workflows are disabled for {username}. Using meta runners."
            return WORKFLOW_LABEL_META, MESSAGE
    except Exception as e:
        print(
            f"Failed to get determine workflow type. Falling back to meta runners. Exception: {e}"
        )
        return WORKFLOW_LABEL_META
        MESSAGE = f"Failed to get determine workflow type. Falling back to meta runners. Exception: {e}"
        return WORKFLOW_LABEL_META, MESSAGE


def main() -> None:
    args = parse_args()

    if is_exception_branch(args.github_branch):
        print(f"Exception branch: '{args.github_branch}', using meta runners")
        output = {LABEL_TYPE_KEY: WORKFLOW_LABEL_META}
        output = {
            LABEL_TYPE_KEY: WORKFLOW_LABEL_META,
            MESSAGE_KEY: f"Exception branch: '{args.github_branch}', using meta runners",
        }
    else:
        try:
            gh = get_gh_client(args.github_token)
            # The default issue we use - https://github.com/pytorch/test-infra/issues/5132
            issue = get_issue(gh, args.github_repo, args.github_issue)

            output = {LABEL_TYPE_KEY: get_workflow_type(issue, args.github_user)}
            label_type, message = get_workflow_type(issue, args.github_user)
            output = {
                LABEL_TYPE_KEY: label_type,
                MESSAGE_KEY: message,
            }
        except Exception as e:
            print(f"Failed to get issue. Falling back to meta runners. Exception: {e}")
            output = {LABEL_TYPE_KEY: WORKFLOW_LABEL_META}
            output = {
                LABEL_TYPE_KEY: WORKFLOW_LABEL_META,
                MESSAGE_KEY: f"Failed to get issue. Falling back to meta runners. Exception: {e}",
            }

    json_output = json.dumps(output)
    print(json_output)
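The net effect is that the script now emits both the runner-label prefix and a human-readable reason in its JSON output. A tiny sketch of that contract (the username in the message is a placeholder):

```python
import json

LABEL_TYPE_KEY = "label_type"
MESSAGE_KEY = "message"

label_type, message = "lf.", "LF Workflows are enabled for some-user. Using LF runners."
print(json.dumps({LABEL_TYPE_KEY: label_type, MESSAGE_KEY: message}))
# {"label_type": "lf.", "message": "LF Workflows are enabled for some-user. Using LF runners."}
```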
.github/scripts/github_utils.py (vendored, 9 changed lines)

@@ -202,3 +202,12 @@ def gh_update_pr_state(org: str, repo: str, pr_num: int, state: str = "open") ->
        )
    else:
        raise


def gh_query_issues_by_labels(
    org: str, repo: str, labels: List[str], state: str = "open"
) -> List[Dict[str, Any]]:
    url = f"{GITHUB_API_URL}/repos/{org}/{repo}/issues"
    return gh_fetch_json(
        url, method="GET", params={"labels": ",".join(labels), "state": state}
    )
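gh_query_issues_by_labels wraps the standard GitHub REST endpoint for listing repository issues filtered by label and state. For reference, the equivalent raw call with urllib looks roughly like this (the repo's own gh_fetch_json helper also handles auth and error handling; the token handling below is a simplified assumption):

```python
import json
import urllib.parse
import urllib.request

def query_issues_by_labels(org, repo, labels, state="open", token=""):
    params = urllib.parse.urlencode({"labels": ",".join(labels), "state": state})
    url = f"https://api.github.com/repos/{org}/{repo}/issues?{params}"
    headers = {"Accept": "application/vnd.github.v3+json"}
    if token:
        headers["Authorization"] = f"token {token}"
    request = urllib.request.Request(url, headers=headers)
    with urllib.request.urlopen(request) as response:
        return json.loads(response.read())

# e.g. query_issues_by_labels("pytorch", "pytorch", ["release tracker"])
```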
.github/workflows/_linux-build-label.yml (vendored, 8 changed lines)

@@ -56,6 +56,13 @@ on:
      required: false
      type: string
      default: ""
    use_split_build:
      description: |
        [Experimental] Build a libtorch only wheel and build pytorch such that
        are built from the libtorch wheel.
      required: false
      type: boolean
      default: false
  secrets:
    HUGGING_FACE_HUB_TOKEN:
      required: false

@@ -107,3 +114,4 @@ jobs:
      aws-role-to-assume: ${{ inputs.aws-role-to-assume }}
      GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
      HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
      use_split_build: ${{ inputs.use_split_build }}
.github/workflows/_linux-build.yml (vendored, 10 changed lines)

@@ -64,6 +64,14 @@ on:
      required: false
      type: string
      default: ""
    use_split_build:
      description: |
        [Experimental] Build a libtorch only wheel and build pytorch such that
        are built from the libtorch wheel.
      required: false
      type: boolean
      default: false

  secrets:
    HUGGING_FACE_HUB_TOKEN:
      required: false

@@ -181,6 +189,7 @@ jobs:
        DEBUG: ${{ inputs.build-with-debug && '1' || '0' }}
        OUR_GITHUB_JOB_ID: ${{ steps.get-job-id.outputs.job-id }}
        HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
        USE_SPLIT_BUILD: ${{ inputs.use_split_build }}
      run: |
        # detached container should get cleaned up by teardown_ec2_linux
        container_name=$(docker run \

@@ -199,6 +208,7 @@ jobs:
          -e PR_LABELS \
          -e OUR_GITHUB_JOB_ID \
          -e HUGGING_FACE_HUB_TOKEN \
          -e USE_SPLIT_BUILD \
          --env-file="/tmp/github_env_${GITHUB_RUN_ID}" \
          --security-opt seccomp=unconfined \
          --cap-add=SYS_PTRACE \
.github/workflows/generated-linux-binary-manywheel-nightly.yml (generated, vendored, 206 changed lines)
@ -2410,3 +2410,209 @@ jobs:
|
||||
conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }}
|
||||
conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }}
|
||||
uses: ./.github/workflows/_binary-upload.yml
|
||||
|
||||
manywheel-py3_12-rocm6_0-build:
|
||||
if: ${{ github.repository_owner == 'pytorch' }}
|
||||
uses: ./.github/workflows/_binary-build-linux.yml
|
||||
with:
|
||||
PYTORCH_ROOT: /pytorch
|
||||
BUILDER_ROOT: /builder
|
||||
PACKAGE_TYPE: manywheel
|
||||
# TODO: This is a legacy variable that we eventually want to get rid of in
|
||||
# favor of GPU_ARCH_VERSION
|
||||
DESIRED_CUDA: rocm6.0
|
||||
GPU_ARCH_VERSION: 6.0
|
||||
GPU_ARCH_TYPE: rocm
|
||||
DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.0-main
|
||||
DESIRED_PYTHON: "3.12"
|
||||
build_name: manywheel-py3_12-rocm6_0
|
||||
build_environment: linux-binary-manywheel
|
||||
secrets:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
manywheel-py3_12-rocm6_0-test: # Testing
|
||||
if: ${{ github.repository_owner == 'pytorch' }}
|
||||
needs: manywheel-py3_12-rocm6_0-build
|
||||
runs-on: linux.rocm.gpu
|
||||
timeout-minutes: 240
|
||||
env:
|
||||
PYTORCH_ROOT: /pytorch
|
||||
BUILDER_ROOT: /builder
|
||||
PACKAGE_TYPE: manywheel
|
||||
# TODO: This is a legacy variable that we eventually want to get rid of in
|
||||
# favor of GPU_ARCH_VERSION
|
||||
DESIRED_CUDA: rocm6.0
|
||||
GPU_ARCH_VERSION: 6.0
|
||||
GPU_ARCH_TYPE: rocm
|
||||
SKIP_ALL_TESTS: 1
|
||||
DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.0-main
|
||||
DESIRED_PYTHON: "3.12"
|
||||
steps:
|
||||
- name: Setup ROCm
|
||||
uses: ./.github/actions/setup-rocm
|
||||
- uses: actions/download-artifact@v3
|
||||
name: Download Build Artifacts
|
||||
with:
|
||||
name: manywheel-py3_12-rocm6_0
|
||||
path: "${{ runner.temp }}/artifacts/"
|
||||
- name: Checkout PyTorch
|
||||
uses: malfet/checkout@silent-checkout
|
||||
with:
|
||||
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
|
||||
submodules: recursive
|
||||
path: pytorch
|
||||
quiet-checkout: true
|
||||
- name: Clean PyTorch checkout
|
||||
run: |
|
||||
# Remove any artifacts from the previous checkouts
|
||||
git clean -fxd
|
||||
working-directory: pytorch
|
||||
- name: Checkout pytorch/builder
|
||||
uses: malfet/checkout@silent-checkout
|
||||
with:
|
||||
ref: main
|
||||
submodules: recursive
|
||||
repository: pytorch/builder
|
||||
path: builder
|
||||
quiet-checkout: true
|
||||
- name: Clean pytorch/builder checkout
|
||||
run: |
|
||||
# Remove any artifacts from the previous checkouts
|
||||
git clean -fxd
|
||||
working-directory: builder
|
||||
- name: ROCm set GPU_FLAG
|
||||
run: |
|
||||
echo "GPU_FLAG=--device=/dev/mem --device=/dev/kfd --device=/dev/dri --group-add video --group-add daemon" >> "${GITHUB_ENV}"
|
||||
- name: Pull Docker image
|
||||
uses: pytorch/test-infra/.github/actions/pull-docker-image@main
|
||||
with:
|
||||
docker-image: pytorch/manylinux-builder:rocm6.0-main
|
||||
- name: Test Pytorch binary
|
||||
uses: ./pytorch/.github/actions/test-pytorch-binary
|
||||
- name: Teardown ROCm
|
||||
uses: ./.github/actions/teardown-rocm
|
||||
manywheel-py3_12-rocm6_0-upload: # Uploading
|
||||
if: ${{ github.repository_owner == 'pytorch' }}
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
needs: manywheel-py3_12-rocm6_0-test
|
||||
with:
|
||||
PYTORCH_ROOT: /pytorch
|
||||
BUILDER_ROOT: /builder
|
||||
PACKAGE_TYPE: manywheel
|
||||
# TODO: This is a legacy variable that we eventually want to get rid of in
|
||||
# favor of GPU_ARCH_VERSION
|
||||
DESIRED_CUDA: rocm6.0
|
||||
GPU_ARCH_VERSION: 6.0
|
||||
GPU_ARCH_TYPE: rocm
|
||||
DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.0-main
|
||||
DESIRED_PYTHON: "3.12"
|
||||
build_name: manywheel-py3_12-rocm6_0
|
||||
secrets:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }}
|
||||
conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }}
|
||||
uses: ./.github/workflows/_binary-upload.yml
|
||||
|
||||
manywheel-py3_12-rocm6_1-build:
|
||||
if: ${{ github.repository_owner == 'pytorch' }}
|
||||
uses: ./.github/workflows/_binary-build-linux.yml
|
||||
with:
|
||||
PYTORCH_ROOT: /pytorch
|
||||
BUILDER_ROOT: /builder
|
||||
PACKAGE_TYPE: manywheel
|
||||
# TODO: This is a legacy variable that we eventually want to get rid of in
|
||||
# favor of GPU_ARCH_VERSION
|
||||
DESIRED_CUDA: rocm6.1
|
||||
GPU_ARCH_VERSION: 6.1
|
||||
GPU_ARCH_TYPE: rocm
|
||||
DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main
|
||||
DESIRED_PYTHON: "3.12"
|
||||
build_name: manywheel-py3_12-rocm6_1
|
||||
build_environment: linux-binary-manywheel
|
||||
secrets:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
manywheel-py3_12-rocm6_1-test: # Testing
|
||||
if: ${{ github.repository_owner == 'pytorch' }}
|
||||
needs: manywheel-py3_12-rocm6_1-build
|
||||
runs-on: linux.rocm.gpu
|
||||
timeout-minutes: 240
|
||||
env:
|
||||
PYTORCH_ROOT: /pytorch
|
||||
BUILDER_ROOT: /builder
|
||||
PACKAGE_TYPE: manywheel
|
||||
# TODO: This is a legacy variable that we eventually want to get rid of in
|
||||
# favor of GPU_ARCH_VERSION
|
||||
DESIRED_CUDA: rocm6.1
|
||||
GPU_ARCH_VERSION: 6.1
|
||||
GPU_ARCH_TYPE: rocm
|
||||
SKIP_ALL_TESTS: 1
|
||||
DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main
|
||||
DESIRED_PYTHON: "3.12"
|
||||
steps:
|
||||
- name: Setup ROCm
|
||||
uses: ./.github/actions/setup-rocm
|
||||
- uses: actions/download-artifact@v3
|
||||
name: Download Build Artifacts
|
||||
with:
|
||||
name: manywheel-py3_12-rocm6_1
|
||||
path: "${{ runner.temp }}/artifacts/"
|
||||
- name: Checkout PyTorch
|
||||
uses: malfet/checkout@silent-checkout
|
||||
with:
|
||||
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
|
||||
submodules: recursive
|
||||
path: pytorch
|
||||
quiet-checkout: true
|
||||
- name: Clean PyTorch checkout
|
||||
run: |
|
||||
# Remove any artifacts from the previous checkouts
|
||||
git clean -fxd
|
||||
working-directory: pytorch
|
||||
- name: Checkout pytorch/builder
|
||||
uses: malfet/checkout@silent-checkout
|
||||
with:
|
||||
ref: main
|
||||
submodules: recursive
|
||||
repository: pytorch/builder
|
||||
path: builder
|
||||
quiet-checkout: true
|
||||
- name: Clean pytorch/builder checkout
|
||||
run: |
|
||||
# Remove any artifacts from the previous checkouts
|
||||
git clean -fxd
|
||||
working-directory: builder
|
||||
- name: ROCm set GPU_FLAG
|
||||
run: |
|
||||
echo "GPU_FLAG=--device=/dev/mem --device=/dev/kfd --device=/dev/dri --group-add video --group-add daemon" >> "${GITHUB_ENV}"
|
||||
- name: Pull Docker image
|
||||
uses: pytorch/test-infra/.github/actions/pull-docker-image@main
|
||||
with:
|
||||
docker-image: pytorch/manylinux-builder:rocm6.1-main
|
||||
- name: Test Pytorch binary
|
||||
uses: ./pytorch/.github/actions/test-pytorch-binary
|
||||
- name: Teardown ROCm
|
||||
uses: ./.github/actions/teardown-rocm
|
||||
manywheel-py3_12-rocm6_1-upload: # Uploading
|
||||
if: ${{ github.repository_owner == 'pytorch' }}
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
needs: manywheel-py3_12-rocm6_1-test
|
||||
with:
|
||||
PYTORCH_ROOT: /pytorch
|
||||
BUILDER_ROOT: /builder
|
||||
PACKAGE_TYPE: manywheel
|
||||
# TODO: This is a legacy variable that we eventually want to get rid of in
|
||||
# favor of GPU_ARCH_VERSION
|
||||
DESIRED_CUDA: rocm6.1
|
||||
GPU_ARCH_VERSION: 6.1
|
||||
GPU_ARCH_TYPE: rocm
|
||||
DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main
|
||||
DESIRED_PYTHON: "3.12"
|
||||
build_name: manywheel-py3_12-rocm6_1
|
||||
secrets:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }}
|
||||
conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }}
|
||||
uses: ./.github/workflows/_binary-upload.yml
|
||||
|
||||
.github/workflows/inductor-cu124.yml (vendored, 6 changed lines)
@ -28,7 +28,8 @@ jobs:
|
||||
cuda-arch-list: '8.6'
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "inductor", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" },
|
||||
{ config: "inductor", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" },
|
||||
{ config: "inductor", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" },
|
||||
{ config: "inductor_distributed", shard: 1, num_shards: 1, runner: "linux.g5.12xlarge.nvidia.gpu" },
|
||||
{ config: "inductor_huggingface", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" },
|
||||
{ config: "inductor_torchbench", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" },
|
||||
@ -95,7 +96,8 @@ jobs:
|
||||
cuda-arch-list: '8.6'
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "inductor", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" },
|
||||
{ config: "inductor", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" },
|
||||
{ config: "inductor", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" },
|
||||
]}
|
||||
|
||||
linux-focal-cuda12_4-py3_12-gcc9-inductor-test:
|
||||
|
||||
.github/workflows/inductor.yml (vendored, 6 changed lines)
@ -48,7 +48,8 @@ jobs:
|
||||
cuda-arch-list: '8.6'
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "inductor", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" },
|
||||
{ config: "inductor", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" },
|
||||
{ config: "inductor", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" },
|
||||
{ config: "inductor_distributed", shard: 1, num_shards: 1, runner: "linux.g5.12xlarge.nvidia.gpu" },
|
||||
{ config: "inductor_huggingface", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" },
|
||||
{ config: "inductor_timm", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" },
|
||||
@ -90,7 +91,8 @@ jobs:
|
||||
cuda-arch-list: '8.6'
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "inductor", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" },
|
||||
{ config: "inductor", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" },
|
||||
{ config: "inductor", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" },
|
||||
]}
|
||||
|
||||
linux-focal-cuda12_1-py3_12-gcc9-inductor-test:
|
||||
|
||||
.github/workflows/lint.yml (vendored, 8 changed lines)
@ -19,10 +19,10 @@ jobs:
|
||||
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
|
||||
with:
|
||||
timeout: 120
|
||||
runner: linux.2xlarge
|
||||
runner: lf.linux.2xlarge
|
||||
docker-image: pytorch-linux-jammy-cuda11.8-cudnn9-py3.9-linter
|
||||
# NB: A shallow checkout won't work here because calculate-docker-image requires a full checkout
|
||||
# to run git rev-parse HEAD~:.ci/docker when a new image is needed
|
||||
# to run git rev-parse HEAD~:.ci/docker when a new image is needed.
|
||||
fetch-depth: 0
|
||||
submodules: true
|
||||
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
|
||||
@ -35,7 +35,7 @@ jobs:
|
||||
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
|
||||
with:
|
||||
timeout: 120
|
||||
runner: linux.2xlarge
|
||||
runner: lf.linux.2xlarge
|
||||
docker-image: pytorch-linux-jammy-cuda11.8-cudnn9-py3.9-linter
|
||||
# NB: A shallow checkout won't work here because calculate-docker-image requires a full checkout
|
||||
# to run git rev-parse HEAD~:.ci/docker when a new image is needed
|
||||
@ -49,7 +49,7 @@ jobs:
|
||||
quick-checks:
|
||||
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
|
||||
with:
|
||||
runner: linux.2xlarge
|
||||
runner: lf.linux.2xlarge
|
||||
docker-image: pytorch-linux-focal-linter
|
||||
fetch-depth: 0
|
||||
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
|
||||
|
||||
.github/workflows/periodic.yml (vendored, 1 changed line)

@@ -73,7 +73,6 @@ jobs:
          { config: "default", shard: 3, num_shards: 5, runner: "linux.4xlarge.nvidia.gpu" },
          { config: "default", shard: 4, num_shards: 5, runner: "linux.4xlarge.nvidia.gpu" },
          { config: "default", shard: 5, num_shards: 5, runner: "linux.4xlarge.nvidia.gpu" },
          { config: "deploy", shard: 1, num_shards: 1, runner: "linux.4xlarge.nvidia.gpu" },
          { config: "nogpu_AVX512", shard: 1, num_shards: 1, runner: "linux.2xlarge" },
          { config: "nogpu_NO_AVX2", shard: 1, num_shards: 1, runner: "linux.2xlarge" },
          { config: "jit_legacy", shard: 1, num_shards: 1, runner: "linux.4xlarge.nvidia.gpu" },
.github/workflows/pull.yml (vendored, 28 changed lines)
@ -487,3 +487,31 @@ jobs:
|
||||
build-environment: linux-jammy-py3-clang12-executorch
|
||||
docker-image: ${{ needs.linux-jammy-py3-clang12-executorch-build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.linux-jammy-py3-clang12-executorch-build.outputs.test-matrix }}
|
||||
|
||||
linux-focal-cuda12_1-py3_10-gcc9-experimental-split-build:
|
||||
name: linux-focal-cuda12.1-py3.10-gcc9-experimental-split-build
|
||||
uses: ./.github/workflows/_linux-build-label.yml
|
||||
with:
|
||||
use_split_build: true
|
||||
build-environment: linux-focal-cuda12.1-py3.10-gcc9
|
||||
docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "default", shard: 1, num_shards: 5, runner: "linux.4xlarge.nvidia.gpu" },
|
||||
{ config: "default", shard: 2, num_shards: 5, runner: "linux.4xlarge.nvidia.gpu" },
|
||||
{ config: "default", shard: 3, num_shards: 5, runner: "linux.4xlarge.nvidia.gpu" },
|
||||
{ config: "default", shard: 4, num_shards: 5, runner: "linux.4xlarge.nvidia.gpu" },
|
||||
{ config: "default", shard: 5, num_shards: 5, runner: "linux.4xlarge.nvidia.gpu" },
|
||||
]}
|
||||
|
||||
linux-focal-cuda12_4-py3_10-gcc9-experimental-split-build-test:
|
||||
name: linux-focal-cuda12.1-py3.10-gcc9-experimental-split-build
|
||||
uses: ./.github/workflows/_linux-test.yml
|
||||
needs:
|
||||
- linux-focal-cuda12_1-py3_10-gcc9-experimental-split-build
|
||||
- target-determination
|
||||
with:
|
||||
timeout-minutes: 360
|
||||
build-environment: linux-focal-cuda12.1-py3.10-gcc9-experimental-split-build
|
||||
docker-image: ${{ needs.linux-focal-cuda12_1-py3_10-gcc9-experimental-split-build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.linux-focal-cuda12_1-py3_10-gcc9-experimental-split-build.outputs.test-matrix }}
|
||||
|
||||
.github/workflows/slow.yml (vendored, 6 changed lines)
@ -97,7 +97,8 @@ jobs:
|
||||
docker-image-name: pytorch-linux-focal-py3.8-clang10
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "slow", shard: 1, num_shards: 1, runner: "linux.2xlarge" },
|
||||
{ config: "slow", shard: 1, num_shards: 2, runner: "linux.2xlarge" },
|
||||
{ config: "slow", shard: 2, num_shards: 2, runner: "linux.2xlarge" },
|
||||
]}
|
||||
|
||||
linux-focal-py3_8-clang10-test:
|
||||
@ -119,7 +120,8 @@ jobs:
|
||||
docker-image-name: pytorch-linux-focal-rocm-n-py3
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "slow", shard: 1, num_shards: 1, runner: "linux.rocm.gpu" },
|
||||
{ config: "slow", shard: 1, num_shards: 2, runner: "linux.rocm.gpu" },
|
||||
{ config: "slow", shard: 2, num_shards: 2, runner: "linux.rocm.gpu" },
|
||||
]}
|
||||
|
||||
linux-focal-rocm6_1-py3_8-test:
|
||||
|
||||
.lintrunner.toml (163 changed lines)
@ -1390,169 +1390,6 @@ exclude_patterns = [
|
||||
'torch/contrib/_tensorboard_vis.py',
|
||||
"torch/cuda/_gpu_trace.py",
|
||||
'torch/cuda/_memory_viz.py', # mypy: Value of type "object" is not indexable
|
||||
'torch/distributed/__init__.py',
|
||||
'torch/distributed/_composable_state.py',
|
||||
'torch/distributed/_shard/__init__.py',
|
||||
'torch/distributed/_shard/_utils.py',
|
||||
'torch/distributed/_shard/api.py',
|
||||
'torch/distributed/_shard/checkpoint/__init__.py',
|
||||
'torch/distributed/_shard/common_op_utils.py',
|
||||
'torch/distributed/_shard/metadata.py',
|
||||
'torch/distributed/_shard/op_registry_utils.py',
|
||||
'torch/distributed/_shard/sharded_optim/__init__.py',
|
||||
'torch/distributed/_shard/sharded_optim/api.py',
|
||||
'torch/distributed/_shard/sharded_tensor/__init__.py',
|
||||
'torch/distributed/_shard/sharded_tensor/_ops/__init__.py',
|
||||
'torch/distributed/_shard/sharded_tensor/_ops/_common.py',
|
||||
'torch/distributed/_shard/sharded_tensor/_ops/binary_cmp.py',
|
||||
'torch/distributed/_shard/sharded_tensor/_ops/init.py',
|
||||
'torch/distributed/_shard/sharded_tensor/_ops/misc_ops.py',
|
||||
'torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py',
|
||||
'torch/distributed/_shard/sharded_tensor/api.py',
|
||||
'torch/distributed/_shard/sharded_tensor/logger.py',
|
||||
'torch/distributed/_shard/sharded_tensor/logging_handlers.py',
|
||||
'torch/distributed/_shard/sharded_tensor/metadata.py',
|
||||
'torch/distributed/_shard/sharded_tensor/reshard.py',
|
||||
'torch/distributed/_shard/sharded_tensor/shard.py',
|
||||
'torch/distributed/_shard/sharded_tensor/utils.py',
|
||||
'torch/distributed/_shard/sharder.py',
|
||||
'torch/distributed/_shard/sharding_plan/__init__.py',
|
||||
'torch/distributed/_shard/sharding_plan/api.py',
|
||||
'torch/distributed/_shard/sharding_spec/__init__.py',
|
||||
'torch/distributed/_shard/sharding_spec/_internals.py',
|
||||
'torch/distributed/_shard/sharding_spec/api.py',
|
||||
'torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py',
|
||||
'torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__init__.py',
|
||||
'torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/_common.py',
|
||||
'torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding.py',
|
||||
'torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding_bag.py',
|
||||
'torch/distributed/_sharded_tensor/__init__.py',
|
||||
'torch/distributed/_sharding_spec/__init__.py',
|
||||
'torch/distributed/_tools/__init__.py',
|
||||
'torch/distributed/_tools/memory_tracker.py',
|
||||
'torch/distributed/algorithms/__init__.py',
|
||||
'torch/distributed/algorithms/_checkpoint/__init__.py',
|
||||
'torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py',
|
||||
'torch/distributed/algorithms/_comm_hooks/__init__.py',
|
||||
'torch/distributed/algorithms/_comm_hooks/default_hooks.py',
|
||||
'torch/distributed/algorithms/_optimizer_overlap/__init__.py',
|
||||
'torch/distributed/algorithms/_optimizer_overlap/optimizer_overlap.py',
|
||||
'torch/distributed/algorithms/_quantization/__init__.py',
|
||||
'torch/distributed/algorithms/_quantization/quantization.py',
|
||||
'torch/distributed/algorithms/ddp_comm_hooks/__init__.py',
|
||||
'torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py',
|
||||
'torch/distributed/algorithms/ddp_comm_hooks/debugging_hooks.py',
|
||||
'torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py',
|
||||
'torch/distributed/algorithms/ddp_comm_hooks/mixed_precision_hooks.py',
|
||||
'torch/distributed/algorithms/ddp_comm_hooks/optimizer_overlap_hooks.py',
|
||||
'torch/distributed/algorithms/ddp_comm_hooks/post_localSGD_hook.py',
|
||||
'torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py',
|
||||
'torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py',
|
||||
'torch/distributed/algorithms/join.py',
|
||||
'torch/distributed/algorithms/model_averaging/__init__.py',
|
||||
'torch/distributed/algorithms/model_averaging/averagers.py',
|
||||
'torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py',
|
||||
'torch/distributed/algorithms/model_averaging/utils.py',
|
||||
'torch/distributed/argparse_util.py',
|
||||
'torch/distributed/autograd/__init__.py',
|
||||
'torch/distributed/benchmarks/benchmark_ddp_rpc.py',
|
||||
'torch/distributed/c10d_logger.py',
|
||||
'torch/distributed/collective_utils.py',
|
||||
'torch/distributed/constants.py',
|
||||
'torch/distributed/distributed_c10d.py',
|
||||
'torch/distributed/elastic/__init__.py',
|
||||
'torch/distributed/elastic/agent/__init__.py',
|
||||
'torch/distributed/elastic/agent/server/__init__.py',
|
||||
'torch/distributed/elastic/agent/server/api.py',
|
||||
'torch/distributed/elastic/agent/server/local_elastic_agent.py',
|
||||
'torch/distributed/elastic/events/__init__.py',
|
||||
'torch/distributed/elastic/events/api.py',
|
||||
'torch/distributed/elastic/events/handlers.py',
|
||||
'torch/distributed/elastic/metrics/__init__.py',
|
||||
'torch/distributed/elastic/metrics/api.py',
|
||||
'torch/distributed/elastic/multiprocessing/__init__.py',
|
||||
'torch/distributed/elastic/multiprocessing/api.py',
|
||||
'torch/distributed/elastic/multiprocessing/errors/__init__.py',
|
||||
'torch/distributed/elastic/multiprocessing/errors/error_handler.py',
|
||||
'torch/distributed/elastic/multiprocessing/errors/handlers.py',
|
||||
'torch/distributed/elastic/multiprocessing/redirects.py',
|
||||
'torch/distributed/elastic/multiprocessing/tail_log.py',
|
||||
'torch/distributed/elastic/rendezvous/__init__.py',
|
||||
'torch/distributed/elastic/rendezvous/api.py',
|
||||
'torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py',
|
||||
'torch/distributed/elastic/rendezvous/dynamic_rendezvous.py',
|
||||
'torch/distributed/elastic/rendezvous/etcd_rendezvous.py',
|
||||
'torch/distributed/elastic/rendezvous/etcd_rendezvous_backend.py',
|
||||
'torch/distributed/elastic/rendezvous/etcd_server.py',
|
||||
'torch/distributed/elastic/rendezvous/etcd_store.py',
|
||||
'torch/distributed/elastic/rendezvous/registry.py',
|
||||
'torch/distributed/elastic/rendezvous/static_tcp_rendezvous.py',
|
||||
'torch/distributed/elastic/rendezvous/utils.py',
|
||||
'torch/distributed/elastic/timer/__init__.py',
|
||||
'torch/distributed/elastic/timer/api.py',
|
||||
'torch/distributed/elastic/timer/file_based_local_timer.py',
|
||||
'torch/distributed/elastic/timer/local_timer.py',
|
||||
'torch/distributed/elastic/utils/__init__.py',
|
||||
'torch/distributed/elastic/utils/api.py',
|
||||
'torch/distributed/elastic/utils/data/__init__.py',
|
||||
'torch/distributed/elastic/utils/data/cycling_iterator.py',
|
||||
'torch/distributed/elastic/utils/data/elastic_distributed_sampler.py',
|
||||
'torch/distributed/elastic/utils/distributed.py',
|
||||
'torch/distributed/elastic/utils/log_level.py',
|
||||
'torch/distributed/elastic/utils/logging.py',
|
||||
'torch/distributed/elastic/utils/store.py',
|
||||
'torch/distributed/examples/memory_tracker_example.py',
|
||||
'torch/distributed/launch.py',
|
||||
'torch/distributed/launcher/__init__.py',
|
||||
'torch/distributed/launcher/api.py',
|
||||
'torch/distributed/logging_handlers.py',
|
||||
'torch/distributed/nn/__init__.py',
|
||||
'torch/distributed/nn/api/__init__.py',
|
||||
'torch/distributed/nn/api/remote_module.py',
|
||||
'torch/distributed/nn/functional.py',
|
||||
'torch/distributed/nn/jit/__init__.py',
|
||||
'torch/distributed/nn/jit/instantiator.py',
|
||||
'torch/distributed/nn/jit/templates/__init__.py',
|
||||
'torch/distributed/nn/jit/templates/remote_module_template.py',
|
||||
'torch/distributed/optim/__init__.py',
|
||||
'torch/distributed/optim/apply_optimizer_in_backward.py',
|
||||
'torch/distributed/optim/functional_adadelta.py',
|
||||
'torch/distributed/optim/functional_adagrad.py',
|
||||
'torch/distributed/optim/functional_adam.py',
|
||||
'torch/distributed/optim/functional_adamax.py',
|
||||
'torch/distributed/optim/functional_adamw.py',
|
||||
'torch/distributed/optim/functional_rmsprop.py',
|
||||
'torch/distributed/optim/functional_rprop.py',
|
||||
'torch/distributed/optim/functional_sgd.py',
|
||||
'torch/distributed/optim/named_optimizer.py',
|
||||
'torch/distributed/optim/optimizer.py',
|
||||
'torch/distributed/optim/post_localSGD_optimizer.py',
|
||||
'torch/distributed/optim/utils.py',
|
||||
'torch/distributed/optim/zero_redundancy_optimizer.py',
|
||||
'torch/distributed/remote_device.py',
|
||||
'torch/distributed/rendezvous.py',
|
||||
'torch/distributed/rpc/__init__.py',
|
||||
'torch/distributed/rpc/_testing/__init__.py',
|
||||
'torch/distributed/rpc/_testing/faulty_agent_backend_registry.py',
|
||||
'torch/distributed/rpc/_utils.py',
|
||||
'torch/distributed/rpc/api.py',
|
||||
'torch/distributed/rpc/backend_registry.py',
|
||||
'torch/distributed/rpc/constants.py',
|
||||
'torch/distributed/rpc/functions.py',
|
||||
'torch/distributed/rpc/internal.py',
|
||||
'torch/distributed/rpc/options.py',
|
||||
'torch/distributed/rpc/rref_proxy.py',
|
||||
'torch/distributed/rpc/server_process_global_profiler.py',
|
||||
'torch/distributed/run.py',
|
||||
'torch/distributed/tensor/__init__.py',
|
||||
'torch/distributed/tensor/parallel/__init__.py',
|
||||
'torch/distributed/tensor/parallel/_utils.py',
|
||||
'torch/distributed/tensor/parallel/_view_with_dim_change.py',
|
||||
'torch/distributed/tensor/parallel/api.py',
|
||||
'torch/distributed/tensor/parallel/fsdp.py',
|
||||
'torch/distributed/tensor/parallel/input_reshard.py',
|
||||
'torch/distributed/tensor/parallel/multihead_attention_tp.py',
|
||||
'torch/distributed/tensor/parallel/style.py',
|
||||
'torch/fft/__init__.py',
|
||||
'torch/func/__init__.py',
|
||||
'torch/futures/__init__.py',
|
||||
|
||||
@@ -290,7 +290,7 @@ After the final RC is created. The following tasks should be performed :

* Create validation issue for the release, see for example [Validations for 2.1.2 release](https://github.com/pytorch/pytorch/issues/114904) and perform required validations.

* Run performance tests in [benchmark repository](https://github.com/pytorch/benchmark). Make sure there are no prerformance regressions.
* Run performance tests in [benchmark repository](https://github.com/pytorch/benchmark). Make sure there are no performance regressions.

* Prepare and stage PyPI binaries for promotion. This is done with this script:
[`pytorch/builder:release/pypi/promote_pypi_to_staging.sh`](https://github.com/pytorch/builder/blob/main/release/pypi/promote_pypi_to_staging.sh)

@@ -429,12 +429,12 @@ need to support these particular versions of software.

## Operating Systems
Supported OS flavors are summarized in the table below:
| Operating System family | Architectrue | Notes |
| Operating System family | Architecture | Notes |
| --- | --- | --- |
| Linux | aarch64, x86_64 | Wheels are manylinux2014 compatible, i.e. they should be runnable on any Linux system with glibc-2.17 or above. |
| MacOS | arm64 | Builds should be compatible with MacOS 11 (Big Sur) or newer, but are actively tested against MacOS 14 (Sonoma). |
| MacOS | x86_64 | Requires MacOS Catalina or above, not supported after 2.2, see https://github.com/pytorch/pytorch/issues/114602 |
| Windows | x86_64 | Buils are compatible with Windows-10 or newer. |
| Windows | x86_64 | Builds are compatible with Windows-10 or newer. |

# Submitting Tutorials
@@ -473,6 +473,7 @@ endif()

if(USE_CUDA AND NOT USE_ROCM)
  list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/include)
  list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/tools/util/include)
  if($ENV{ATEN_STATIC_CUDA})
    list(APPEND ATen_CUDA_DEPENDENCY_LIBS
      ${CUDA_LIBRARIES}
@@ -303,7 +303,7 @@ Tensor FunctionalInverses::_nested_view_from_buffer_inverse(const Tensor& base,
  return Tensor();
}

Tensor FunctionalInverses::_nested_view_from_jagged_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, const Tensor& offsets, const Tensor& dummy, const std::optional<Tensor>& lengths, int64_t ragged_idx, const c10::optional<Tensor>& min_seqlen, const c10::optional<Tensor>& max_seqlen) {
Tensor FunctionalInverses::_nested_view_from_jagged_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, const Tensor& offsets, const Tensor& dummy, const std::optional<Tensor>& lengths, int64_t ragged_idx) {
  auto values = at::_nested_get_values(mutated_view);
  if (inverse_return_mode != InverseReturnMode::NeverView) {
    return values;

@@ -317,12 +317,7 @@ Tensor FunctionalInverses::_nested_get_values_inverse(const Tensor& base, const
  auto lengths = at::_nested_get_lengths(base);
  auto ragged_idx = at::_nested_get_ragged_idx(base);
  auto dummy = at::_nested_get_jagged_dummy(base);
  auto min_seqlen = at::_nested_get_min_seqlen(base);
  auto max_seqlen = at::_nested_get_max_seqlen(base);
  auto nt = at::_nested_view_from_jagged(
      mutated_view, offsets, dummy, lengths, ragged_idx,
      (min_seqlen.defined() ? c10::optional<Tensor>(min_seqlen) : c10::nullopt),
      (max_seqlen.defined() ? c10::optional<Tensor>(max_seqlen) : c10::nullopt));
  auto nt = at::_nested_view_from_jagged(mutated_view, offsets, dummy, lengths, ragged_idx);

  if (inverse_return_mode != InverseReturnMode::NeverView) {
    return nt;
@ -765,115 +765,10 @@ struct Vectorized<T, std::enable_if_t<is_zarch_implemented<T>()>> {
|
||||
const ElementType& operator[](int idx) const = delete;
|
||||
ElementType& operator[](int idx) = delete;
|
||||
|
||||
Vectorized<T> C10_ALWAYS_INLINE operator+(const Vectorized<T>& other) const {
|
||||
return Vectorized<T>{_vec0 + other._vec0, _vec1 + other._vec1};
|
||||
}
|
||||
|
||||
Vectorized<T> C10_ALWAYS_INLINE operator-(const Vectorized<T>& other) const {
|
||||
return Vectorized<T>{_vec0 - other._vec0, _vec1 - other._vec1};
|
||||
}
|
||||
|
||||
Vectorized<T> C10_ALWAYS_INLINE operator*(const Vectorized<T>& other) const {
|
||||
return Vectorized<T>{_vec0 * other._vec0, _vec1 * other._vec1};
|
||||
}
|
||||
|
||||
Vectorized<T> C10_ALWAYS_INLINE operator/(const Vectorized<T>& other) const {
|
||||
return Vectorized<T>{_vec0 / other._vec0, _vec1 / other._vec1};
|
||||
}
|
||||
|
||||
Vectorized<T> C10_ALWAYS_INLINE operator&(const Vectorized<T>& other) const {
|
||||
return Vectorized<T>{
|
||||
(vtype)(vecb0() & other.vecb0()), (vtype)(vecb1() & other.vecb1())};
|
||||
}
|
||||
|
||||
Vectorized<T> C10_ALWAYS_INLINE operator|(const Vectorized<T>& other) const {
|
||||
return Vectorized<T>{
|
||||
(vtype)(vecb0() | other.vecb0()), (vtype)(vecb1() | other.vecb1())};
|
||||
}
|
||||
|
||||
Vectorized<T> C10_ALWAYS_INLINE operator^(const Vectorized<T>& other) const {
|
||||
return Vectorized<T>{
|
||||
(vtype)(vecb0() ^ other.vecb0()), (vtype)(vecb1() ^ other.vecb1())};
|
||||
}
|
||||
|
||||
Vectorized<T> C10_ALWAYS_INLINE operator<<(const Vectorized<T> &other) const {
|
||||
constexpr ElementType max_shift = sizeof(ElementType) * CHAR_BIT;
|
||||
|
||||
ElementType a_array[Vectorized<T>::size()];
|
||||
ElementType b_array[Vectorized<T>::size()];
|
||||
ElementType c_array[Vectorized<T>::size()];
|
||||
|
||||
store(a_array);
|
||||
other.store(b_array);
|
||||
|
||||
for (int i = 0; i != Vectorized<T>::size(); i++) {
|
||||
T shift = b_array[i];
|
||||
if ((static_cast<std::make_signed_t<T>>(shift) < 0) || (shift >= max_shift)) {
|
||||
c_array[i] = 0;
|
||||
} else {
|
||||
c_array[i] = static_cast<std::make_unsigned_t<T>>(a_array[i]) << shift;
|
||||
}
|
||||
}
|
||||
|
||||
return loadu(c_array);
|
||||
}
|
||||
|
||||
Vectorized<T> C10_ALWAYS_INLINE operator>>(const Vectorized<T> &other) const {
|
||||
// right shift value to retain sign bit for signed and no bits for unsigned
|
||||
constexpr ElementType max_shift = sizeof(T) * CHAR_BIT - std::is_signed_v<T>;
|
||||
|
||||
ElementType a_array[Vectorized<T>::size()];
|
||||
ElementType b_array[Vectorized<T>::size()];
|
||||
ElementType c_array[Vectorized<T>::size()];
|
||||
|
||||
store(a_array);
|
||||
other.store(b_array);
|
||||
|
||||
for (int i = 0; i != Vectorized<T>::size(); i++) {
|
||||
T shift = b_array[i];
|
||||
if ((static_cast<std::make_signed_t<T>>(shift) < 0) || (shift >= max_shift)) {
|
||||
c_array[i] = a_array[i] >> max_shift;
|
||||
} else {
|
||||
c_array[i] = a_array[i] >> shift;
|
||||
}
|
||||
}
|
||||
|
||||
return loadu(c_array);
|
||||
}
|
||||
|
||||
Vectorized<T> _not() const {
|
||||
return {(vtype)vec_nor(vecb0(), vecb0()), (vtype)vec_nor(vecb1(), vecb1())};
|
||||
}
|
||||
|
||||
Vectorized<T> C10_ALWAYS_INLINE operator==(const Vectorized<T>& other) const {
|
||||
return Vectorized<T>{
|
||||
vec_cmpeq(_vec0, other._vec0), vec_cmpeq(_vec1, other._vec1)};
|
||||
}
|
||||
|
||||
Vectorized<T> C10_ALWAYS_INLINE operator!=(const Vectorized<T>& other) const {
|
||||
return Vectorized<T>{
|
||||
vec_cmpeq(_vec0, other._vec0), vec_cmpeq(_vec1, other._vec1)}
|
||||
._not();
|
||||
}
|
||||
Vectorized<T> C10_ALWAYS_INLINE operator>(const Vectorized<T>& other) const {
|
||||
return Vectorized<T>{
|
||||
vec_cmpgt(_vec0, other._vec0), vec_cmpgt(_vec1, other._vec1)};
|
||||
}
|
||||
Vectorized<T> C10_ALWAYS_INLINE operator>=(const Vectorized<T>& other) const {
|
||||
return Vectorized<T>{
|
||||
vec_cmpge(_vec0, other._vec0), vec_cmpge(_vec1, other._vec1)};
|
||||
}
|
||||
|
||||
Vectorized<T> C10_ALWAYS_INLINE operator<(const Vectorized<T>& other) const {
|
||||
return Vectorized<T>{
|
||||
vec_cmplt(_vec0, other._vec0), vec_cmplt(_vec1, other._vec1)};
|
||||
}
|
||||
|
||||
Vectorized<T> C10_ALWAYS_INLINE operator<=(const Vectorized<T>& other) const {
|
||||
return Vectorized<T>{
|
||||
vec_cmple(_vec0, other._vec0), vec_cmple(_vec1, other._vec1)};
|
||||
}
|
||||
|
||||
Vectorized<T> C10_ALWAYS_INLINE eq(const Vectorized<T>& other) const {
|
||||
return (*this == other) & Vectorized<T>((T)1.0);
|
||||
}
|
||||
@ -1410,30 +1305,153 @@ struct Vectorized<T, std::enable_if_t<is_zarch_implemented<T>()>> {
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
inline Vectorized<int64_t> operator~(const Vectorized<int64_t>& a) {
|
||||
return a._not();
|
||||
}
|
||||
#define ZVECTOR_OPERATORS(typex) \
|
||||
template <> \
|
||||
Vectorized<typex> C10_ALWAYS_INLINE operator+(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
|
||||
return Vectorized<typex>{a.vec0() + b.vec0(), a.vec1() + b.vec1()}; \
|
||||
} \
|
||||
\
|
||||
template <> \
|
||||
Vectorized<typex> C10_ALWAYS_INLINE operator-(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
|
||||
return Vectorized<typex>{a.vec0() - b.vec0(), a.vec1() - b.vec1()}; \
|
||||
} \
|
||||
\
|
||||
template <> \
|
||||
Vectorized<typex> C10_ALWAYS_INLINE operator*(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
|
||||
return Vectorized<typex>{a.vec0() * b.vec0(), a.vec1() * b.vec1()}; \
|
||||
} \
|
||||
\
|
||||
template <> \
|
||||
Vectorized<typex> C10_ALWAYS_INLINE operator/(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
|
||||
return Vectorized<typex>{a.vec0() / b.vec0(), a.vec1() / b.vec1()}; \
|
||||
} \
|
||||
\
|
||||
template <> \
|
||||
Vectorized<typex> C10_ALWAYS_INLINE operator&(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
|
||||
return Vectorized<typex>{ \
|
||||
(Vectorized<typex>::vtype)(a.vecb0() & b.vecb0()), \
|
||||
(Vectorized<typex>::vtype)(a.vecb1() & b.vecb1())}; \
|
||||
} \
|
||||
\
|
||||
template <> \
|
||||
Vectorized<typex> C10_ALWAYS_INLINE operator|(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
|
||||
return Vectorized<typex>{ \
|
||||
(Vectorized<typex>::vtype)(a.vecb0() | b.vecb0()), \
|
||||
(Vectorized<typex>::vtype)(a.vecb1() | b.vecb1())}; \
|
||||
} \
|
||||
\
|
||||
template <> \
|
||||
Vectorized<typex> C10_ALWAYS_INLINE operator^(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
|
||||
return Vectorized<typex>{ \
|
||||
(Vectorized<typex>::vtype)(a.vecb0() ^ b.vecb0()), \
|
||||
(Vectorized<typex>::vtype)(a.vecb1() ^ b.vecb1())}; \
|
||||
} \
|
||||
\
|
||||
Vectorized<typex> C10_ALWAYS_INLINE operator==(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
|
||||
return Vectorized<typex>{ \
|
||||
vec_cmpeq(a.vec0(), b.vec0()), vec_cmpeq(a.vec1(), b.vec1())}; \
|
||||
} \
|
||||
\
|
||||
Vectorized<typex> C10_ALWAYS_INLINE operator!=(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
|
||||
return Vectorized<typex>{ \
|
||||
vec_cmpeq(a.vec0(), b.vec0()), vec_cmpeq(a.vec1(), b.vec1())} \
|
||||
._not(); \
|
||||
} \
|
||||
\
|
||||
Vectorized<typex> C10_ALWAYS_INLINE operator>(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
|
||||
return Vectorized<typex>{ \
|
||||
vec_cmpgt(a.vec0(), b.vec0()), vec_cmpgt(a.vec1(), b.vec1())}; \
|
||||
} \
|
||||
\
|
||||
Vectorized<typex> C10_ALWAYS_INLINE operator>=(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
|
||||
return Vectorized<typex>{ \
|
||||
vec_cmpge(a.vec0(), b.vec0()), vec_cmpge(a.vec1(), b.vec1())}; \
|
||||
} \
|
||||
\
|
||||
Vectorized<typex> C10_ALWAYS_INLINE operator<(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
|
||||
return Vectorized<typex>{ \
|
||||
vec_cmplt(a.vec0(), b.vec0()), vec_cmplt(a.vec1(), b.vec1())}; \
|
||||
} \
|
||||
\
|
||||
Vectorized<typex> C10_ALWAYS_INLINE operator<=(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
|
||||
return Vectorized<typex>{ \
|
||||
vec_cmple(a.vec0(), b.vec0()), vec_cmple(a.vec1(), b.vec1())}; \
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Vectorized<int32_t> operator~(const Vectorized<int32_t>& a) {
|
||||
return a._not();
|
||||
}
|
||||
ZVECTOR_OPERATORS(float)
|
||||
ZVECTOR_OPERATORS(double)
|
||||
ZVECTOR_OPERATORS(int8_t)
|
||||
ZVECTOR_OPERATORS(uint8_t)
|
||||
ZVECTOR_OPERATORS(uint16_t)
|
||||
ZVECTOR_OPERATORS(int16_t)
|
||||
ZVECTOR_OPERATORS(int32_t)
|
||||
ZVECTOR_OPERATORS(int64_t)
|
||||
|
||||
template <>
|
||||
inline Vectorized<int16_t> operator~(const Vectorized<int16_t>& a) {
|
||||
return a._not();
|
||||
}
|
||||
#undef ZVECTOR_OPERATORS
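The hunk above replaces the member operators of the zarch Vectorized<T> specialization with free-function specializations stamped out once per element type by a single ZVECTOR_OPERATORS(typex) macro. A minimal, self-contained sketch of that member-to-free-operator pattern (the Pair type and PAIR_OPERATORS macro are illustrative stand-ins, not part of the diff):

#include <iostream>

template <typename T>
struct Pair {
  T lo, hi;
  T vec0() const { return lo; }
  T vec1() const { return hi; }
};

// One macro generates the full set of free operators for each element type.
#define PAIR_OPERATORS(typex)                                                  \
  inline Pair<typex> operator+(const Pair<typex>& a, const Pair<typex>& b) {   \
    return Pair<typex>{a.vec0() + b.vec0(), a.vec1() + b.vec1()};              \
  }                                                                            \
  inline Pair<typex> operator*(const Pair<typex>& a, const Pair<typex>& b) {   \
    return Pair<typex>{a.vec0() * b.vec0(), a.vec1() * b.vec1()};              \
  }

PAIR_OPERATORS(float)
PAIR_OPERATORS(double)
#undef PAIR_OPERATORS

int main() {
  Pair<float> x{1.f, 2.f}, y{3.f, 4.f};
  auto z = x + y;                            // resolves to the generated free function
  std::cout << z.lo << " " << z.hi << "\n";  // prints: 4 6
}

Presumably the point of the refactor is that one macro expansion per element type replaces a near-identical operator set inside every Vectorized specialization; the diff applies the same trick again below for the shift, quantized, and complex specializations.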
|
||||
|
||||
template <>
|
||||
inline Vectorized<int8_t> operator~(const Vectorized<int8_t>& a) {
|
||||
return a._not();
|
||||
}
|
||||
#define ZVECTOR_OPERATORS(typex) \
|
||||
template <> \
|
||||
Vectorized<typex> C10_ALWAYS_INLINE operator<<(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
|
||||
constexpr Vectorized<typex>::ElementType max_shift \
|
||||
= sizeof(Vectorized<typex>::ElementType) * CHAR_BIT; \
|
||||
\
|
||||
Vectorized<typex>::ElementType a_array[Vectorized<typex>::size()]; \
|
||||
Vectorized<typex>::ElementType b_array[Vectorized<typex>::size()]; \
|
||||
Vectorized<typex>::ElementType c_array[Vectorized<typex>::size()]; \
|
||||
\
|
||||
a.store(a_array); \
|
||||
b.store(b_array); \
|
||||
\
|
||||
for (int i = 0; i != Vectorized<typex>::size(); i++) { \
|
||||
typex shift = b_array[i]; \
|
||||
if ((static_cast<std::make_signed_t<typex>>(shift) < 0) || (shift >= max_shift)) { \
|
||||
c_array[i] = 0; \
|
||||
} else { \
|
||||
c_array[i] = static_cast<std::make_unsigned_t<typex>>(a_array[i]) << shift; \
|
||||
} \
|
||||
} \
|
||||
\
|
||||
return Vectorized<typex>::loadu(c_array); \
|
||||
} \
|
||||
\
|
||||
template <> \
|
||||
Vectorized<typex> C10_ALWAYS_INLINE operator>>(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
|
||||
/* right shift value to retain sign bit for signed and no bits for unsigned */ \
|
||||
constexpr Vectorized<typex>::ElementType max_shift \
|
||||
= sizeof(typex) * CHAR_BIT - std::is_signed_v<typex>; \
|
||||
\
|
||||
Vectorized<typex>::ElementType a_array[Vectorized<typex>::size()]; \
|
||||
Vectorized<typex>::ElementType b_array[Vectorized<typex>::size()]; \
|
||||
Vectorized<typex>::ElementType c_array[Vectorized<typex>::size()]; \
|
||||
\
|
||||
a.store(a_array); \
|
||||
b.store(b_array); \
|
||||
\
|
||||
for (int i = 0; i != Vectorized<typex>::size(); i++) { \
|
||||
typex shift = b_array[i]; \
|
||||
if ((static_cast<std::make_signed_t<typex>>(shift) < 0) || (shift >= max_shift)) { \
|
||||
c_array[i] = a_array[i] >> max_shift; \
|
||||
} else { \
|
||||
c_array[i] = a_array[i] >> shift; \
|
||||
} \
|
||||
} \
|
||||
\
|
||||
return Vectorized<typex>::loadu(c_array); \
|
||||
} \
|
||||
\
|
||||
template <> \
|
||||
inline Vectorized<typex> operator~(const Vectorized<typex>& a) { \
|
||||
return a._not(); \
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Vectorized<uint8_t> operator~(const Vectorized<uint8_t>& a) {
|
||||
return a._not();
|
||||
}
|
||||
ZVECTOR_OPERATORS(int8_t)
|
||||
ZVECTOR_OPERATORS(uint8_t)
|
||||
ZVECTOR_OPERATORS(uint16_t)
|
||||
ZVECTOR_OPERATORS(int16_t)
|
||||
ZVECTOR_OPERATORS(int32_t)
|
||||
ZVECTOR_OPERATORS(int64_t)
|
||||
|
||||
#undef ZVECTOR_OPERATORS
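The second macro block above keeps the scalar-loop fallback for shifts, whose edge case is easy to miss: shift counts that are negative or at least the bit width yield 0 for << and a sign-fill (all zeros for unsigned types) for >>. A constexpr scalar reference of that rule, for illustration only:

#include <climits>
#include <cstdint>
#include <type_traits>

template <typename T>
constexpr T safe_shl(T a, T shift) {
  constexpr T max_shift = sizeof(T) * CHAR_BIT;
  if (static_cast<std::make_signed_t<T>>(shift) < 0 || shift >= max_shift)
    return 0;                                               // over-shifting left gives 0
  return static_cast<std::make_unsigned_t<T>>(a) << shift;  // shift the unsigned bit pattern to avoid UB
}

template <typename T>
constexpr T safe_shr(T a, T shift) {
  // retain the sign bit for signed types, no bits for unsigned ones
  constexpr T max_shift = sizeof(T) * CHAR_BIT - std::is_signed_v<T>;
  if (static_cast<std::make_signed_t<T>>(shift) < 0 || shift >= max_shift)
    return a >> max_shift;                                  // saturates to all-sign (or all-zero) bits
  return a >> shift;
}

static_assert(safe_shl<uint8_t>(1, 9) == 0);     // 9 >= 8, so the result is 0
static_assert(safe_shr<uint8_t>(200, 50) == 0);  // unsigned: no bits survive an over-shift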
|
||||
|
||||
#define DEFINE_MAXMIN_FUNCS(operand_type) \
|
||||
template <> \
|
||||
@ -1976,55 +1994,6 @@ struct Vectorized<T, std::enable_if_t<is_zarch_implemented_quant<T>()>> {
|
||||
return Vectorized<U>{ret};
|
||||
}
|
||||
|
||||
Vectorized<T> C10_ALWAYS_INLINE operator+(const Vectorized<T>& other) const {
|
||||
return Vectorized<T>{_vec + other._vec};
|
||||
}
|
||||
|
||||
Vectorized<T> C10_ALWAYS_INLINE operator-(const Vectorized<T>& other) const {
|
||||
return Vectorized<T>{_vec - other._vec};
|
||||
}
|
||||
|
||||
Vectorized<T> C10_ALWAYS_INLINE operator*(const Vectorized<T>& other) const {
|
||||
return Vectorized<T>{_vec * other._vec};
|
||||
}
|
||||
|
||||
Vectorized<T> C10_ALWAYS_INLINE operator/(const Vectorized<T>& other) const {
|
||||
return Vectorized<T>{_vec / other._vec};
|
||||
}
|
||||
|
||||
Vectorized<T> C10_ALWAYS_INLINE operator&(const Vectorized<T>& other) const {
|
||||
return Vectorized<T>{_vec & other._vec};
|
||||
}
|
||||
|
||||
Vectorized<T> C10_ALWAYS_INLINE operator|(const Vectorized<T>& other) const {
|
||||
return Vectorized<T>{_vec | other._vec};
|
||||
}
|
||||
|
||||
Vectorized<T> C10_ALWAYS_INLINE operator^(const Vectorized<T>& other) const {
|
||||
return Vectorized<T>{_vec ^ other._vec};
|
||||
}
|
||||
Vectorized<T> C10_ALWAYS_INLINE operator==(const Vectorized<T>& other) const {
|
||||
return Vectorized<T>{_vec == other._vec};
|
||||
}
|
||||
|
||||
Vectorized<T> C10_ALWAYS_INLINE operator!=(const Vectorized<T>& other) const {
|
||||
return Vectorized<T>{_vec != other._vec};
|
||||
}
|
||||
Vectorized<T> C10_ALWAYS_INLINE operator>(const Vectorized<T>& other) const {
|
||||
return Vectorized<T>{_vec > other._vec};
|
||||
}
|
||||
Vectorized<T> C10_ALWAYS_INLINE operator>=(const Vectorized<T>& other) const {
|
||||
return Vectorized<T>{_vec >= other._vec};
|
||||
}
|
||||
|
||||
Vectorized<T> C10_ALWAYS_INLINE operator<(const Vectorized<T>& other) const {
|
||||
return Vectorized<T>{_vec < other._vec};
|
||||
}
|
||||
|
||||
Vectorized<T> C10_ALWAYS_INLINE operator<=(const Vectorized<T>& other) const {
|
||||
return Vectorized<T>{_vec <= other._vec};
|
||||
}
|
||||
|
||||
Vectorized<T> C10_ALWAYS_INLINE eq(const Vectorized<T>& other) const {
|
||||
return Vectorized<T>{_vec.eq(other._vec)};
|
||||
}
|
||||
@ -2061,6 +2030,72 @@ struct Vectorized<T, std::enable_if_t<is_zarch_implemented_quant<T>()>> {
|
||||
}
|
||||
};
|
||||
|
||||
#define ZVECTOR_OPERATORS(typex) \
|
||||
template <> \
|
||||
Vectorized<typex> C10_ALWAYS_INLINE operator+(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
|
||||
return Vectorized<typex>{a.vec() + b.vec()}; \
|
||||
} \
|
||||
\
|
||||
template <> \
|
||||
Vectorized<typex> C10_ALWAYS_INLINE operator-(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
|
||||
return Vectorized<typex>{a.vec() - b.vec()}; \
|
||||
} \
|
||||
\
|
||||
template <> \
|
||||
Vectorized<typex> C10_ALWAYS_INLINE operator*(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
|
||||
return Vectorized<typex>{a.vec() * b.vec()}; \
|
||||
} \
|
||||
\
|
||||
template <> \
|
||||
Vectorized<typex> C10_ALWAYS_INLINE operator/(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
|
||||
return Vectorized<typex>{a.vec() / b.vec()}; \
|
||||
} \
|
||||
\
|
||||
template <> \
|
||||
Vectorized<typex> C10_ALWAYS_INLINE operator&(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
|
||||
return Vectorized<typex>{a.vec() & b.vec()}; \
|
||||
} \
|
||||
\
|
||||
template <> \
|
||||
Vectorized<typex> C10_ALWAYS_INLINE operator|(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
|
||||
return Vectorized<typex>{a.vec() | b.vec()}; \
|
||||
} \
|
||||
\
|
||||
template <> \
|
||||
Vectorized<typex> C10_ALWAYS_INLINE operator^(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
|
||||
return Vectorized<typex>{a.vec() ^ b.vec()}; \
|
||||
} \
|
||||
\
|
||||
Vectorized<typex> C10_ALWAYS_INLINE operator==(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
|
||||
return Vectorized<typex>{a.vec() == b.vec()}; \
|
||||
} \
|
||||
\
|
||||
Vectorized<typex> C10_ALWAYS_INLINE operator!=(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
|
||||
return Vectorized<typex>{a.vec() != b.vec()}; \
|
||||
} \
|
||||
\
|
||||
Vectorized<typex> C10_ALWAYS_INLINE operator>(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
|
||||
return Vectorized<typex>{a.vec() > b.vec()}; \
|
||||
} \
|
||||
\
|
||||
Vectorized<typex> C10_ALWAYS_INLINE operator>=(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
|
||||
return Vectorized<typex>{a.vec() >= b.vec()}; \
|
||||
} \
|
||||
\
|
||||
Vectorized<typex> C10_ALWAYS_INLINE operator<(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
|
||||
return Vectorized<typex>{a.vec() < b.vec()}; \
|
||||
} \
|
||||
\
|
||||
Vectorized<typex> C10_ALWAYS_INLINE operator<=(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
|
||||
return Vectorized<typex>{a.vec() <= b.vec()}; \
|
||||
}
|
||||
|
||||
ZVECTOR_OPERATORS(c10::qint32)
|
||||
ZVECTOR_OPERATORS(c10::qint8)
|
||||
ZVECTOR_OPERATORS(c10::quint8)
|
||||
|
||||
#undef ZVECTOR_OPERATORS
|
||||
|
||||
DEFINE_CLAMP_MAXMIN_FUNCS(c10::quint8)
|
||||
DEFINE_CLAMP_MAXMIN_FUNCS(c10::qint8)
|
||||
DEFINE_CLAMP_MAXMIN_FUNCS(c10::qint32)
|
||||
@ -2364,35 +2399,6 @@ struct Vectorized<T, std::enable_if_t<is_zarch_implemented_complex<T>()>> {
|
||||
return Vectorized<T>{a00, a01};
|
||||
}
|
||||
|
||||
Vectorized<T> C10_ALWAYS_INLINE operator+(const Vectorized<T>& other) const {
|
||||
return Vectorized<T>{_vec + other._vec};
|
||||
}
|
||||
|
||||
Vectorized<T> C10_ALWAYS_INLINE operator-(const Vectorized<T>& other) const {
|
||||
return Vectorized<T>{_vec - other._vec};
|
||||
}
|
||||
|
||||
Vectorized<T> inline operator*(const Vectorized<T>& b) const {
|
||||
//(a + bi) * (c + di) = (ac - bd) + (ad + bc)i
|
||||
vinner_type bv = b.vec();
|
||||
#if !defined(ZVECTOR_SIMULATE_X86_MULT)
|
||||
// this is more z arch friendly than simulating horizontal from x86
|
||||
vinner_type vi = bv.mergeo();
|
||||
vinner_type vr = bv.mergee();
|
||||
vi = vi ^ rsign_mask<underline_type>();
|
||||
vinner_type ret = _vec * vr;
|
||||
vinner_type vx_swapped = _vec.swapped();
|
||||
ret = fmadd(vx_swapped, vi, ret);
|
||||
#else
|
||||
vinner_type ac_bd = _vec * b;
|
||||
vinner_type d_c = bv.swapped();
|
||||
d_c = d_c ^ isign_mask<underline_type>();
|
||||
vinner_type ad_bc = _vec * d_c;
|
||||
vinner_type ret = vinner_type::horizontal_sub_perm(ac_bd, ad_bc);
|
||||
#endif
|
||||
return Vectorized<T>{ret};
|
||||
}
|
||||
|
||||
template <
|
||||
typename U = T,
|
||||
std::enable_if_t<std::is_same<U, c10::complex<float>>::value, int> = 0>
|
||||
@ -2418,29 +2424,6 @@ struct Vectorized<T, std::enable_if_t<is_zarch_implemented_complex<T>()>> {
|
||||
return { v0, v1 };
|
||||
}
|
||||
|
||||
Vectorized<T> inline operator/(const Vectorized<T>& b) const {
|
||||
// Unfortunately, this breaks some tests
|
||||
// Implement it like it's done for avx2
|
||||
auto fabs_cd = b.vec().abs(); // |c| |d|
|
||||
auto fabs_dc = fabs_cd.swapped(); // |d| |c|
|
||||
auto scale = vinner_type {1.0} / maximum(fabs_cd, fabs_dc); // 1/sc 1/sc
|
||||
auto a2 = vec() * scale; // a/sc b/sc
|
||||
auto b2 = b.vec() * scale; // c/sc d/sc
|
||||
auto acbd2 = a2 * b2; // ac/sc^2 bd/sc^2
|
||||
|
||||
auto dc2 = b2.swapped(); // d/sc c/sc
|
||||
dc2 = Vectorized<T>::real_neg(dc2); // -d/|c,d| c/sc
|
||||
auto adbc2 = a2 * dc2; // -ad/sc^2 bc/sc^2
|
||||
auto sum1 = acbd2 + acbd2.swapped(); // (ac+bd)/sc^2 (ac+bd)/sc^2
|
||||
auto sum2 = adbc2 + adbc2.swapped(); // (bc-ad)/sc^2 (bc-ad)/sc^2
|
||||
auto res2 = vinner_type::mergee(sum1, sum2); // (ac+bd)/sc^2 (bc-ad)/sc^2
|
||||
|
||||
// get the denominator
|
||||
auto denom2 = Vectorized<T>{b2}.abs_2_(); // (c^2+d^2)/sc^2 (c^2+d^2)/sc^2
|
||||
res2 = res2 / denom2;
|
||||
return Vectorized<T>{ res2 };
|
||||
}
|
||||
|
||||
Vectorized<T> angle2_() const {
|
||||
auto b_a = _vec.swapped(); // b a
|
||||
return Vectorized<T>{_vec.atan2(b_a).swapped()};
|
||||
@ -2528,25 +2511,6 @@ struct Vectorized<T, std::enable_if_t<is_zarch_implemented_complex<T>()>> {
|
||||
return Vectorized<T>{_vec.trunc()};
|
||||
}
|
||||
|
||||
Vectorized<T> C10_ALWAYS_INLINE operator&(const Vectorized<T>& other) const {
|
||||
return Vectorized<T>{_vec & other._vec};
|
||||
}
|
||||
|
||||
Vectorized<T> C10_ALWAYS_INLINE operator|(const Vectorized<T>& other) const {
|
||||
return Vectorized<T>{_vec | other._vec};
|
||||
}
|
||||
|
||||
Vectorized<T> C10_ALWAYS_INLINE operator^(const Vectorized<T>& other) const {
|
||||
return Vectorized<T>{_vec ^ other._vec};
|
||||
}
|
||||
Vectorized<T> C10_ALWAYS_INLINE operator==(const Vectorized<T>& other) const {
|
||||
return Vectorized<T>{_vec == other._vec};
|
||||
}
|
||||
|
||||
Vectorized<T> C10_ALWAYS_INLINE operator!=(const Vectorized<T>& other) const {
|
||||
return Vectorized<T>{_vec != other._vec};
|
||||
}
|
||||
|
||||
Vectorized<T> C10_ALWAYS_INLINE eq(const Vectorized<T>& other) const {
|
||||
auto eq = _vec.eq(other._vec); // compares real and imag individually
|
||||
// If both real numbers and imag numbers are equal, then the complex numbers are equal
|
||||
@ -2648,22 +2612,6 @@ struct Vectorized<T, std::enable_if_t<is_zarch_implemented_complex<T>()>> {
|
||||
return sqrt().reciprocal();
|
||||
}
|
||||
|
||||
Vectorized<T> operator<(const Vectorized<T>& other) const {
|
||||
TORCH_CHECK(false, "not supported for complex numbers");
|
||||
}
|
||||
|
||||
Vectorized<T> operator<=(const Vectorized<T>& other) const {
|
||||
TORCH_CHECK(false, "not supported for complex numbers");
|
||||
}
|
||||
|
||||
Vectorized<T> operator>(const Vectorized<T>& other) const {
|
||||
TORCH_CHECK(false, "not supported for complex numbers");
|
||||
}
|
||||
|
||||
Vectorized<T> operator>=(const Vectorized<T>& other) const {
|
||||
TORCH_CHECK(false, "not supported for complex numbers");
|
||||
}
|
||||
|
||||
Vectorized<T> lt(const Vectorized<T>& other) const {
|
||||
TORCH_CHECK(false, "not supported for complex numbers");
|
||||
}
|
||||
@ -2681,6 +2629,101 @@ struct Vectorized<T, std::enable_if_t<is_zarch_implemented_complex<T>()>> {
|
||||
}
|
||||
};
|
||||
|
||||
#define ZVECTOR_OPERATORS(typex) \
|
||||
template <> \
|
||||
Vectorized<typex> C10_ALWAYS_INLINE operator+(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
|
||||
return Vectorized<typex>{a.vec() + b.vec()}; \
|
||||
} \
|
||||
\
|
||||
template <> \
|
||||
Vectorized<typex> C10_ALWAYS_INLINE operator-(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
|
||||
return Vectorized<typex>{a.vec() - b.vec()}; \
|
||||
} \
|
||||
\
|
||||
template <> \
|
||||
Vectorized<typex> inline operator*(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
|
||||
/* (a + bi) * (c + di) = (ac - bd) + (ad + bc)i */ \
|
||||
Vectorized<typex>::vinner_type bv = b.vec(); \
|
||||
\
|
||||
/* this is more z arch friendly than simulating horizontal from x86 */ \
|
||||
Vectorized<typex>::vinner_type vi = bv.mergeo(); \
|
||||
Vectorized<typex>::vinner_type vr = bv.mergee(); \
|
||||
vi = vi ^ Vectorized<typex>::vinner_type(rsign_mask<Vectorized<typex>::underline_type>()); \
|
||||
Vectorized<typex>::vinner_type ret = a.vec() * vr; \
|
||||
Vectorized<typex>::vinner_type vx_swapped = a.vec().swapped(); \
|
||||
ret = fmadd(vx_swapped, vi, ret); \
|
||||
\
|
||||
return Vectorized<typex>{ret}; \
|
||||
} \
|
||||
\
|
||||
template <> \
|
||||
Vectorized<typex> inline operator/(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
|
||||
/* Unfortunately, this breaks some tests */ \
|
||||
/* Implement it like it's done for avx2 */ \
|
||||
auto fabs_cd = b.vec().abs(); /* |c| |d| */ \
|
||||
auto fabs_dc = fabs_cd.swapped(); /* |d| |c| */ \
|
||||
auto scale = Vectorized<typex>::vinner_type {1.0} / maximum(fabs_cd, fabs_dc); /* 1/sc 1/sc */ \
|
||||
auto a2 = a.vec() * scale; /* a/sc b/sc */ \
|
||||
auto b2 = b.vec() * scale; /* c/sc d/sc */ \
|
||||
auto acbd2 = a2 * b2; /* ac/sc^2 bd/sc^2 */ \
|
||||
\
|
||||
auto dc2 = b2.swapped(); /* d/sc c/sc */ \
|
||||
dc2 = Vectorized<typex>::real_neg(dc2); /* -d/|c,d| c/sc */ \
|
||||
auto adbc2 = a2 * dc2; /* -ad/sc^2 bc/sc^2 */ \
|
||||
auto sum1 = acbd2 + acbd2.swapped(); /* (ac+bd)/sc^2 (ac+bd)/sc^2 */ \
|
||||
auto sum2 = adbc2 + adbc2.swapped(); /* (bc-ad)/sc^2 (bc-ad)/sc^2 */ \
|
||||
auto res2 = Vectorized<typex>::vinner_type::mergee(sum1, sum2); /* (ac+bd)/sc^2 (bc-ad)/sc^2 */ \
|
||||
\
|
||||
/* get the denominator */ \
|
||||
Vectorized<typex>::vinner_type denom2 = Vectorized<typex>{b2}.abs_2_(); /* (c^2+d^2)/sc^2 (c^2+d^2)/sc^2 */ \
|
||||
res2 = res2 / denom2; \
|
||||
return Vectorized<typex>{ res2 }; \
|
||||
} \
|
||||
\
|
||||
template <> \
|
||||
Vectorized<typex> C10_ALWAYS_INLINE operator&(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
|
||||
return Vectorized<typex>{a.vec() & b.vec()}; \
|
||||
} \
|
||||
\
|
||||
template <> \
|
||||
Vectorized<typex> C10_ALWAYS_INLINE operator|(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
|
||||
return Vectorized<typex>{a.vec() | b.vec()}; \
|
||||
} \
|
||||
\
|
||||
template <> \
|
||||
Vectorized<typex> C10_ALWAYS_INLINE operator^(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
|
||||
return Vectorized<typex>{a.vec() ^ b.vec()}; \
|
||||
} \
|
||||
\
|
||||
Vectorized<typex> C10_ALWAYS_INLINE operator==(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
|
||||
return Vectorized<typex>{a.vec() == b.vec()}; \
|
||||
} \
|
||||
\
|
||||
Vectorized<typex> C10_ALWAYS_INLINE operator!=(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
|
||||
return Vectorized<typex>{a.vec() != b.vec()}; \
|
||||
} \
|
||||
\
|
||||
Vectorized<typex> C10_ALWAYS_INLINE operator<(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
|
||||
TORCH_CHECK(false, "not supported for complex numbers"); \
|
||||
} \
|
||||
\
|
||||
Vectorized<typex> C10_ALWAYS_INLINE operator<=(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
|
||||
TORCH_CHECK(false, "not supported for complex numbers"); \
|
||||
} \
|
||||
\
|
||||
Vectorized<typex> C10_ALWAYS_INLINE operator>(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
|
||||
TORCH_CHECK(false, "not supported for complex numbers"); \
|
||||
} \
|
||||
\
|
||||
Vectorized<typex> C10_ALWAYS_INLINE operator>=(const Vectorized<typex>& a, const Vectorized<typex>& b) { \
|
||||
TORCH_CHECK(false, "not supported for complex numbers"); \
|
||||
}
|
||||
|
||||
ZVECTOR_OPERATORS(c10::complex<float>)
|
||||
ZVECTOR_OPERATORS(c10::complex<double>)
|
||||
|
||||
#undef ZVECTOR_OPERATORS
|
||||
|
||||
template <typename T, std::enable_if_t<(sizeof(T) == 8), int> = 0>
|
||||
std::pair<Vectorized<T>, Vectorized<T>> inline inner_interleave2(
|
||||
const Vectorized<T>& a,
|
||||
|
||||
@ -334,7 +334,13 @@ static inline __device__ void gpuAtomicAddNoReturn(double *address, double val)
|
||||
|
||||
/* Special case fp32 atomic. */
|
||||
#if defined(USE_ROCM)
|
||||
static inline __device__ void gpuAtomicAddNoReturn(float *address, float val) { atomicAddNoRet(address, val); }
|
||||
static inline __device__ void gpuAtomicAddNoReturn(float *address, float val) {
|
||||
#if defined(__gfx908__)
|
||||
atomicAddNoRet(address, val);
|
||||
#else
|
||||
(void)unsafeAtomicAdd(address, val);
|
||||
#endif
|
||||
}
|
||||
#else
|
||||
static inline __device__ void gpuAtomicAddNoReturn(float *address, float val) { gpuAtomicAdd(address, val); }
|
||||
#endif
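With the hunk above, the ROCm build lowers gpuAtomicAddNoReturn(float*, float) to unsafeAtomicAdd (a hardware fp32 atomic that never returns the old value) everywhere except gfx908, which keeps the older atomicAddNoRet path. A typical call site looks like the sketch below; the kernel and the 256-slot destination buffer are assumptions for illustration, not part of the diff:

// Fire-and-forget accumulation: the old value is never read back, so the
// NoReturn variant is free to pick the cheaper atomic on ROCm.
__global__ void accumulate_kernel(const float* src, float* dst, int64_t n) {
  int64_t i = blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
  if (i < n) {
    // dst is assumed to hold at least 256 floats; many threads may hit the same slot
    gpuAtomicAddNoReturn(dst + (i % 256), src[i]);
  }
}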
|
||||
|
||||
@ -152,9 +152,6 @@ void CUDAGeneratorState::register_graph(cuda::CUDAGraph* graph) {
|
||||
* Unregisters a CUDA graph from the RNG state.
|
||||
*/
|
||||
void CUDAGeneratorState::unregister_graph(cuda::CUDAGraph* graph) {
|
||||
// Ensures that the RNG state is not currently being captured.
|
||||
at::cuda::assertNotCapturing(
|
||||
"Cannot unregister the state during capturing stage.");
|
||||
// Verify the graph was previously registered.
|
||||
TORCH_CHECK(
|
||||
registered_graphs_.find(graph) != registered_graphs_.end(),
|
||||
|
||||
@ -170,6 +170,43 @@ CUDA_STUB3(cuLinkComplete, CUlinkState, void **, size_t *);
|
||||
CUDA_STUB3(cuFuncSetAttribute, CUfunction, CUfunction_attribute, int);
|
||||
CUDA_STUB3(cuFuncGetAttribute, int*, CUfunction_attribute, CUfunction);
|
||||
|
||||
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12000
|
||||
CUresult CUDAAPI
|
||||
cuTensorMapEncodeTiled(
|
||||
CUtensorMap* tensorMap,
|
||||
CUtensorMapDataType tensorDataType,
|
||||
cuuint32_t tensorRank,
|
||||
void* globalAddress,
|
||||
const cuuint64_t* globalDim,
|
||||
const cuuint64_t* globalStrides,
|
||||
const cuuint32_t* boxDim,
|
||||
const cuuint32_t* elementStrides,
|
||||
CUtensorMapInterleave interleave,
|
||||
CUtensorMapSwizzle swizzle,
|
||||
CUtensorMapL2promotion l2Promotion,
|
||||
CUtensorMapFloatOOBfill oobFill) {
|
||||
auto fn = reinterpret_cast<decltype(&cuTensorMapEncodeTiled)>(
|
||||
getCUDALibrary().sym(__func__));
|
||||
if (!fn)
|
||||
throw std::runtime_error("Can't get cuTensorMapEncodeTiled");
|
||||
lazyNVRTC.cuTensorMapEncodeTiled = fn;
|
||||
return fn(
|
||||
tensorMap,
|
||||
tensorDataType,
|
||||
tensorRank,
|
||||
globalAddress,
|
||||
globalDim,
|
||||
globalStrides,
|
||||
boxDim,
|
||||
elementStrides,
|
||||
interleave,
|
||||
swizzle,
|
||||
l2Promotion,
|
||||
oobFill);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
// Irregularly shaped functions
|
||||
CUresult CUDAAPI cuLaunchKernel(CUfunction f,
|
||||
unsigned int gridDimX,
|
||||
|
||||
@ -59,16 +59,25 @@ namespace at { namespace cuda {
|
||||
_(cuLinkAddData) \
|
||||
_(cuLinkComplete) \
|
||||
_(cuFuncSetAttribute) \
|
||||
_(cuFuncGetAttribute)
|
||||
_(cuFuncGetAttribute) \
|
||||
|
||||
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12000
|
||||
#define AT_FORALL_NVRTC_EXTENDED(_) \
|
||||
AT_FORALL_NVRTC_BASE(_) \
|
||||
_(cuTensorMapEncodeTiled)
|
||||
#else
|
||||
#define AT_FORALL_NVRTC_EXTENDED(_) \
|
||||
AT_FORALL_NVRTC_BASE(_)
|
||||
#endif
|
||||
|
||||
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11010
|
||||
#define AT_FORALL_NVRTC(_) \
|
||||
AT_FORALL_NVRTC_BASE(_) \
|
||||
AT_FORALL_NVRTC_EXTENDED(_) \
|
||||
_(nvrtcGetCUBINSize) \
|
||||
_(nvrtcGetCUBIN)
|
||||
#else
|
||||
#define AT_FORALL_NVRTC(_) \
|
||||
AT_FORALL_NVRTC_BASE(_)
|
||||
AT_FORALL_NVRTC_EXTENDED(_)
|
||||
#endif
|
||||
|
||||
#else
|
||||
|
||||
@ -1,3 +1,7 @@
|
||||
#include <cstdint>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/core/Scalar.h>
|
||||
#include <c10/core/ScalarType.h>
|
||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
||||
#include <ATen/core/Tensor.h>
|
||||
#include <ATen/core/NamedTensor.h>
|
||||
@ -10,6 +14,7 @@
|
||||
#include <ATen/cuda/tunable/TunableGemm.h>
|
||||
#include <ATen/native/Resize.h>
|
||||
#include <c10/util/MaybeOwned.h>
|
||||
#include <ATen/native/cuda/RowwiseScaledMM.h>
|
||||
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
#include <ATen/Functions.h>
|
||||
@ -819,24 +824,106 @@ static bool _scaled_mm_allowed_device() {
|
||||
#endif
|
||||
}
|
||||
|
||||
namespace{
|
||||
|
||||
enum class ScalingType {
|
||||
TensorWise,
|
||||
RowWise,
|
||||
Error
|
||||
};
|
||||
/*
|
||||
* Scaling Type Determination:
|
||||
* ---------------------------
|
||||
* Conditions and corresponding Scaling Types:
|
||||
*
|
||||
* - If scale_a.numel() == 1 && scale_b.numel() == 1:
|
||||
* - Returns TensorWise.
|
||||
*
|
||||
* - Else if scale_a.dim() == 1 && scale_a.size(0) == dim_m && scale_b.size(0) == dim_n:
|
||||
* - Returns RowWise.
|
||||
*
|
||||
* - Otherwise:
|
||||
* - Returns Error.
|
||||
*/
|
||||
|
||||
// Validates the scale tensors to scaled_mm
|
||||
// And returns the type of scaling/which kernel to use
|
||||
ScalingType get_scaling_type(
|
||||
const at::Tensor& scale_a,
|
||||
const at::Tensor& scale_b,
|
||||
int64_t dim_m,
|
||||
int64_t dim_n) {
|
||||
// Both Per-Tensor and Row-wise scaling expect fp32 tensors
|
||||
TORCH_CHECK(
|
||||
scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat,
|
||||
"Both scale_a and scale_b must be float (fp32) tensors.");
|
||||
|
||||
|
||||
// Check the singular scale case for per-tensor scaling
|
||||
if (scale_a.numel() == 1 && scale_b.numel() == 1) {
|
||||
return ScalingType::TensorWise;
|
||||
} else if (scale_a.dim() == 1 && scale_a.size(0) == dim_m) {
|
||||
// Check the per-row scaling case
|
||||
#if !defined(USE_ROCM) && !defined(_MSC_VER) || \
|
||||
(defined(USE_ROCM) && ROCM_VERSION >= 60000)
|
||||
TORCH_CHECK(
|
||||
scale_a.dim() == 1 && scale_b.dim() == 1,
|
||||
"Both scale_a and scale_b must be 1-dimensional tensors");
|
||||
TORCH_CHECK(
|
||||
scale_b.size(0) == dim_n,
|
||||
"For row-wise scaling, scale_b must have size ",
|
||||
dim_n,
|
||||
" but got ",
|
||||
scale_b.size(0),
|
||||
".");
|
||||
TORCH_CHECK(
|
||||
scale_a.is_contiguous() && scale_b.is_contiguous(),
|
||||
"Both scale_a and scale_b must be contiguous.");
|
||||
return ScalingType::RowWise;
|
||||
#else
|
||||
TORCH_CHECK(false, "Per-row scaling is not supported for this platform!");
|
||||
return ScalingType::Error;
|
||||
#endif // !defined(USE_ROCM) && !defined(_MSC_VER) || (defined(USE_ROCM) &&
|
||||
// ROCM_VERSION >= 60000)
|
||||
} else {
|
||||
// Prettier Error Case messaging
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"For row-wise scaling, scale_a must be size ",
|
||||
dim_m,
|
||||
" but got ",
|
||||
scale_a.numel(),
|
||||
" and scale_b must be size ",
|
||||
dim_n,
|
||||
" but got ",
|
||||
scale_b.numel(),
|
||||
".");
|
||||
// Unreachable
|
||||
return ScalingType::RowWise;
|
||||
}
|
||||
return ScalingType::Error;
|
||||
}
|
||||
|
||||
} // namespace
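In practice the branch taken by get_scaling_type is determined entirely by the shapes of the two scale tensors relative to an (M x K) @ (K x N) problem. A small sketch of scale tensors that land in each branch (the helper and its sizes are illustrative, not part of the diff):

#include <ATen/ATen.h>

void make_example_scales(int64_t M, int64_t N) {
  auto opts = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA);
  // numel() == 1 on both sides -> ScalingType::TensorWise
  auto tensorwise_a = at::ones({1}, opts);
  auto tensorwise_b = at::ones({1}, opts);
  // contiguous fp32 vectors of length M (rows of mat1) and N (columns of mat2)
  // -> ScalingType::RowWise, provided the platform supports it
  auto rowwise_a = at::ones({M}, opts);
  auto rowwise_b = at::ones({N}, opts);
  // Any other shape or dtype combination trips one of the TORCH_CHECKs above.
}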
|
||||
|
||||
// Computes matrix multiply + bias while applying scaling to input and output matrices and computes amax
|
||||
// Scales are only applicable when matrices are of Float8 type and assumed to be equal to 1.0 by default.
|
||||
// If output matrix type is 16 or 32-bit type, neither scale_result is applied nor amax is computed.
|
||||
// Known limitations:
|
||||
// - Only works if mat1 is row-major and mat2 is column-major
|
||||
// - Only works if matrices sizes are divisible by 32
|
||||
//
|
||||
// - If 1-dimensional tensors are used then scale_a should be size = mat1.size(0)
|
||||
// and scale_b should have size = to mat2.size(1)
|
||||
// Arguments:
|
||||
// - `mat1`: the first operand of the matrix multiply, can be type `torch.float8_e4m3fn` or `torch.float8_e5m2`
|
||||
// - `mat2`: the second operand of the matrix multiply, can be type `torch.float8_e4m3fn` or `torch.float8_e5m2`
|
||||
// - `bias`: the bias, can be type `torch.float16` or `torch.bfloat16`
|
||||
// - `out_dtype`: the output dtype, can either be a float8 or a higher precision floating point type
|
||||
// - `scale_a`: a scalar tensor with the inverse scale of `mat1`, only needed if `mat1` is a float8 type
|
||||
// - `scale_b`: a scalar tensor with the inverse scale of `mat2`, only needed if `mat2` is a float8 type
|
||||
// - `scale_result`: a scalar tensor with the scale of the output, only set if the output is a float8 type
|
||||
// - `scale_a`: a scalar or 1-dimensional tensor with the inverse scale of `mat1`, only needed if `mat1` is a float8 type
|
||||
// - `scale_b`: a scalar or 1-dimensional tensor with the inverse scale of `mat2`, only needed if `mat2` is a float8 type
|
||||
// - `scale_result`: a scalar tensor with the scale of the output, only utilized if the output is a float8 type
|
||||
// - `use_fast_accum`: if true, enables fast float8 accumulation
|
||||
// - `out`: a reference to the output tensor
|
||||
// - `amax`: a reference to the amax tensor of the output, only needed if the output is a float8 type and will be updated inplace
|
||||
|
||||
Tensor&
|
||||
_scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
|
||||
@ -855,10 +942,11 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
|
||||
TORCH_CHECK(
|
||||
mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (",
|
||||
mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");
|
||||
TORCH_CHECK((scale_a.numel() == 1 && scale_a.scalar_type() == kFloat),
|
||||
"scale_a must be float scalar");
|
||||
TORCH_CHECK((scale_b.numel() == 1 && scale_b.scalar_type() == kFloat),
|
||||
"scale_b must be a float scalar");
|
||||
|
||||
// Check what type of scaling we are doing based on inputs
|
||||
ScalingType scaling_choice = get_scaling_type(scale_a, scale_b, mat1.size(0), mat2.size(1));
|
||||
TORCH_INTERNAL_ASSERT(scaling_choice != ScalingType::Error, "Scaling type not supported");
|
||||
|
||||
TORCH_CHECK(!scale_result || (scale_result->numel() == 1 && scale_result->scalar_type() == kFloat),
|
||||
"scale_result must be a float scalar");
|
||||
TORCH_CHECK(!bias || bias->numel() == mat2.sizes()[1], "Bias must be size ", mat2.sizes()[1],
|
||||
@ -899,11 +987,25 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
|
||||
{scale_result_, "scale_result", 6}};
|
||||
checkAllSameGPU(__func__, targs);
|
||||
}
|
||||
|
||||
// Validation checks have passed lets resize the output to actual size
|
||||
IntArrayRef mat1_sizes = mat1.sizes();
|
||||
IntArrayRef mat2_sizes = mat2.sizes();
|
||||
at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]});
|
||||
|
||||
// We are doing row-wise scaling
|
||||
if (scaling_choice == ScalingType::RowWise) {
|
||||
TORCH_CHECK(out.dtype() == kBFloat16, "Only bf16 high precision output types are supported for row-wise scaling.");
|
||||
at::cuda::detail::f8f8bf16_rowwise(
|
||||
mat1,
|
||||
mat2,
|
||||
scale_a,
|
||||
scale_b,
|
||||
bias,
|
||||
use_fast_accum,
|
||||
out);
|
||||
return out;
|
||||
}
|
||||
|
||||
cublasCommonArgs args(mat1, mat2, out);
|
||||
const auto out_dtype_ = args.result->scalar_type();
|
||||
TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");
|
||||
|
||||
536
aten/src/ATen/native/cuda/RowwiseScaledMM.cu
Normal file
536
aten/src/ATen/native/cuda/RowwiseScaledMM.cu
Normal file
@ -0,0 +1,536 @@
|
||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/core/Tensor.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
|
||||
|
||||
// Determine if the architecture supports rowwise scaled mm
|
||||
// Currently failing on windows with: https://github.com/NVIDIA/cutlass/issues/1571
|
||||
#if !defined(USE_ROCM) && !defined(_WIN32) && defined(CUDA_VERSION) && CUDA_VERSION >= 12000
|
||||
|
||||
#define BUILD_ROWWISE_FP8_KERNEL
|
||||
#endif
|
||||
|
||||
#if defined(BUILD_ROWWISE_FP8_KERNEL)
|
||||
|
||||
// We are going to override the cuTensorMapEncodeTiled driver api with our lazy loader
|
||||
static CUresult CUDAAPI nvrtc_cuTensorMapEncodeTiled(
|
||||
CUtensorMap* tensorMap,
|
||||
CUtensorMapDataType tensorDataType,
|
||||
cuuint32_t tensorRank,
|
||||
void* globalAddress,
|
||||
const cuuint64_t* globalDim,
|
||||
const cuuint64_t* globalStrides,
|
||||
const cuuint32_t* boxDim,
|
||||
const cuuint32_t* elementStrides,
|
||||
CUtensorMapInterleave interleave,
|
||||
CUtensorMapSwizzle swizzle,
|
||||
CUtensorMapL2promotion l2Promotion,
|
||||
CUtensorMapFloatOOBfill oobFill) {
|
||||
return at::globalContext().getNVRTC().cuTensorMapEncodeTiled(
|
||||
tensorMap,
|
||||
tensorDataType,
|
||||
tensorRank,
|
||||
globalAddress,
|
||||
globalDim,
|
||||
globalStrides,
|
||||
boxDim,
|
||||
elementStrides,
|
||||
interleave,
|
||||
swizzle,
|
||||
l2Promotion,
|
||||
oobFill);
|
||||
}
|
||||
|
||||
|
||||
#include <cutlass/core_io.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/gemm/device/gemm.h>
|
||||
#include <cutlass/half.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
#include <cutlass/trace.h>
|
||||
#include <cutlass/util/host_tensor.h>
|
||||
|
||||
// Rename the global function symbol
|
||||
#define cuTensorMapEncodeTiled nvrtc_cuTensorMapEncodeTiled
|
||||
#include <cute/tensor.hpp>
|
||||
#undef cuTensorMapEncodeTiled
|
||||
// Set everything back to normal
|
||||
|
||||
#include <cutlass/gemm/collective/collective_builder.hpp>
|
||||
#include <cutlass/gemm/device/gemm_universal_adapter.h>
|
||||
#include <cutlass/epilogue/collective/collective_builder.hpp>
|
||||
|
||||
#include <cute/atom/mma_atom.hpp>
|
||||
#include <cutlass/gemm/dispatch_policy.hpp>
|
||||
#include <cutlass/gemm/kernel/gemm_universal.hpp>
|
||||
#include <cutlass/util/packed_stride.hpp>
|
||||
|
||||
|
||||
namespace {
|
||||
// Cutlass rowwise kernel
|
||||
template <
|
||||
int TB_M,
|
||||
int TB_N,
|
||||
int TB_K,
|
||||
int TBS_M,
|
||||
int TBS_N,
|
||||
int TBS_K,
|
||||
bool PONG,
|
||||
bool FAST_ACCUM,
|
||||
bool USE_BIAS,
|
||||
typename INPUT_DTYPE,
|
||||
typename BIAS_DTYPE>
|
||||
void f8f8bf16_rowwise_impl(
|
||||
at::Tensor XQ, // FP8
|
||||
at::Tensor WQ, // FP8
|
||||
at::Tensor x_scale,
|
||||
at::Tensor w_scale,
|
||||
c10::optional<at::Tensor> bias,
|
||||
at::Tensor out) {
|
||||
int M = XQ.size(0);
|
||||
int N = WQ.size(1);
|
||||
int K = XQ.size(1);
|
||||
|
||||
TORCH_CHECK(XQ.is_cuda() && XQ.is_contiguous());
|
||||
TORCH_CHECK(
|
||||
WQ.is_cuda() && WQ.ndimension() == 2 && WQ.stride(1) == WQ.size(0) &&
|
||||
WQ.stride(0) == 1);
|
||||
|
||||
// auto Y = at::empty({M, N}, XQ.options().dtype(at::kBFloat16));
|
||||
|
||||
using ElementInputA = INPUT_DTYPE;
|
||||
using LayoutInputA = cutlass::layout::RowMajor;
|
||||
constexpr int AlignmentInputA = 16 / sizeof(ElementInputA);
|
||||
|
||||
using ElementInputB = cutlass::float_e4m3_t;
|
||||
using LayoutInputB = cutlass::layout::ColumnMajor;
|
||||
constexpr int AlignmentInputB = 16 / sizeof(ElementInputB);
|
||||
|
||||
using ElementBias = BIAS_DTYPE;
|
||||
|
||||
using ElementOutput = cutlass::bfloat16_t;
|
||||
using LayoutOutput = cutlass::layout::RowMajor;
|
||||
constexpr int AlignmentOutput = 16 / sizeof(ElementOutput);
|
||||
|
||||
using ElementAccumulator = float;
|
||||
using ElementComputeEpilogue = float;
|
||||
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that
|
||||
// supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
||||
using TileShape = cute::Shape<
|
||||
cute::Int<TB_M>,
|
||||
cute::Int<TB_N>,
|
||||
cute::Int<TB_K>>; // Threadblock-level
|
||||
// tile size
|
||||
using ClusterShape = cute::Shape<
|
||||
cute::Int<TBS_M>,
|
||||
cute::Int<TBS_N>,
|
||||
cute::Int<TBS_K>>; // Shape of the
|
||||
// threadblocks in a
|
||||
// cluster
|
||||
using KernelSchedule = cutlass::gemm::collective::
|
||||
KernelScheduleAuto; // Kernel to launch based on the default setting in
|
||||
// the Collective Builder
|
||||
|
||||
// Implement rowwise scaling epilogue.
|
||||
using XScale = cutlass::epilogue::fusion::Sm90ColBroadcast<
|
||||
0,
|
||||
TileShape,
|
||||
ElementComputeEpilogue,
|
||||
cute::Stride<cute::Int<1>, cute::Int<0>, cute::Int<0>>>;
|
||||
|
||||
using WScale = cutlass::epilogue::fusion::Sm90RowBroadcast<
|
||||
PONG ? 2 : 1,
|
||||
TileShape,
|
||||
ElementComputeEpilogue,
|
||||
cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;
|
||||
|
||||
using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<
|
||||
PONG ? 2 : 1,
|
||||
TileShape,
|
||||
ElementBias,
|
||||
cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;
|
||||
|
||||
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
|
||||
|
||||
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiplies,
|
||||
ElementComputeEpilogue, // First stage output type.
|
||||
ElementComputeEpilogue, // First stage input types.
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTCompute0 =
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute0, WScale, Accum>;
|
||||
|
||||
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiplies,
|
||||
cute::conditional_t< // Second stage output type.
|
||||
USE_BIAS,
|
||||
ElementBias,
|
||||
ElementOutput>,
|
||||
ElementComputeEpilogue, // Second stage input types.
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTCompute1 =
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute1, XScale, EVTCompute0>;
|
||||
|
||||
using ComputeBias = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::plus,
|
||||
ElementOutput, // Final (optional) stage output type.
|
||||
ElementBias, // Final stage input types.
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeBias =
|
||||
cutlass::epilogue::fusion::Sm90EVT<ComputeBias, Bias, EVTCompute1>;
|
||||
|
||||
using EpilogueEVT =
|
||||
cute::conditional_t<USE_BIAS, EVTComputeBias, EVTCompute1>;
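Reading the nested Sm90EVT types from the inside out, the epilogue computes, per output element, roughly the following (a paraphrase of the composition above, not source from the diff):

//   EVTCompute0    : t0  = w_scale[n] * accum[m][n]   (WScale is reused for every row)
//   EVTCompute1    : t1  = x_scale[m] * t0            (XScale is reused for every column)
//   EVTComputeBias : out = bias[n]    + t1            (only when USE_BIAS is true)
// i.e. out[m][n] = x_scale[m] * w_scale[n] * accum[m][n] (+ bias[n]), which applies the
// per-row and per-column dequantization scales while the tile is still on-chip.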
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
TileShape,
|
||||
ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator,
|
||||
ElementComputeEpilogue,
|
||||
ElementOutput,
|
||||
LayoutOutput,
|
||||
AlignmentOutput,
|
||||
ElementOutput,
|
||||
LayoutOutput,
|
||||
AlignmentOutput,
|
||||
cutlass::epilogue::TmaWarpSpecialized,
|
||||
EpilogueEVT>::CollectiveOp;
|
||||
|
||||
using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecialized;
|
||||
using PongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
|
||||
using FastDefaultSchedule =
|
||||
cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
|
||||
using FastPongSchedule =
|
||||
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
|
||||
using SlowAccum = cute::conditional_t<PONG, PongSchedule, DefaultSchedule>;
|
||||
using FastAccum =
|
||||
cute::conditional_t<PONG, FastPongSchedule, FastDefaultSchedule>;
|
||||
using MainLoopSchedule =
|
||||
cute::conditional_t<FAST_ACCUM, FastAccum, SlowAccum>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OperatorClass,
|
||||
ElementInputA,
|
||||
LayoutInputA,
|
||||
AlignmentInputA,
|
||||
ElementInputB,
|
||||
LayoutInputB,
|
||||
AlignmentInputB,
|
||||
ElementAccumulator,
|
||||
TileShape,
|
||||
ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
MainLoopSchedule>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int, int, int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
using StrideInputA = typename Gemm::GemmKernel::StrideA;
|
||||
using StrideInputB = typename Gemm::GemmKernel::StrideB;
|
||||
using StrideOutput = typename Gemm::GemmKernel::StrideC;
|
||||
|
||||
StrideInputA stride_a = cutlass::make_cute_packed_stride(
|
||||
StrideInputA{}, cute::make_shape(M, K, 1));
|
||||
StrideInputB stride_b = cutlass::make_cute_packed_stride(
|
||||
StrideInputB{}, cute::make_shape(N, K, 1));
|
||||
StrideOutput stride_output = cutlass::make_cute_packed_stride(
|
||||
StrideOutput{}, cute::make_shape(M, N, 1));
|
||||
|
||||
typename Gemm::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{M, N, K},
|
||||
{reinterpret_cast<ElementInputA*>(XQ.data_ptr()),
|
||||
stride_a,
|
||||
reinterpret_cast<ElementInputB*>(WQ.data_ptr()),
|
||||
stride_b},
|
||||
{{}, // Epilogue thread we populate below.
|
||||
(ElementOutput*)out.data_ptr<at::BFloat16>(),
|
||||
stride_output,
|
||||
(ElementOutput*)out.data_ptr<at::BFloat16>(),
|
||||
stride_output}};
|
||||
|
||||
if constexpr (USE_BIAS) {
|
||||
arguments.epilogue.thread = {
|
||||
{reinterpret_cast<ElementBias*>(bias.value().data_ptr())}, // bias
|
||||
// compute_1
|
||||
{
|
||||
{reinterpret_cast<ElementComputeEpilogue*>(
|
||||
x_scale.data_ptr())}, // x_scale
|
||||
// compute_0
|
||||
{
|
||||
{reinterpret_cast<ElementComputeEpilogue*>(
|
||||
w_scale.data_ptr())}, // w_scale
|
||||
{}, // Accumulator
|
||||
{} // Multiplies
|
||||
},
|
||||
{}, // Multiplies
|
||||
},
|
||||
{}, // Plus
|
||||
};
|
||||
} else {
|
||||
arguments.epilogue.thread = {
|
||||
{reinterpret_cast<ElementComputeEpilogue*>(
|
||||
x_scale.data_ptr())}, // x_scale
|
||||
// compute_0
|
||||
{
|
||||
{reinterpret_cast<ElementComputeEpilogue*>(
|
||||
w_scale.data_ptr())}, // w_scale
|
||||
{}, // Accumulator
|
||||
{} // Multiplies
|
||||
},
|
||||
{}, // Multiplies
|
||||
};
|
||||
}
|
||||
|
||||
Gemm gemm;
|
||||
|
||||
// Using the arguments, query for extra workspace required for matrix
|
||||
// multiplication computation
|
||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
// Check the problem size is supported or not
|
||||
cutlass::Status status = gemm.can_implement(arguments);
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
throw std::runtime_error("cutlass cannot implement");
|
||||
}
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
status = gemm.initialize(arguments, workspace.get());
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
throw std::runtime_error("cutlass cannot initialize");
|
||||
}
|
||||
|
||||
status = gemm(at::cuda::getCurrentCUDAStream());
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
throw std::runtime_error(
|
||||
std::string("cutlass cannot run") +
|
||||
cutlass::cutlassGetStatusString(status));
|
||||
}
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
}
|
||||
|
||||
// FP8 Rowwise Cutlass kernel dispatch.
|
||||
enum class KernelMode { Small, Large, Default };
|
||||
|
||||
KernelMode get_kernel_mode(at::Tensor XQ, at::Tensor WQ) {
|
||||
auto M = XQ.size(0);
|
||||
auto K = XQ.size(1);
|
||||
auto N = WQ.size(0);
|
||||
// Use a large kernel if at least two shapes are large....
|
||||
bool use_large_kernel =
|
||||
((M >= 2048 && K >= 2048) || (M >= 2048 && N >= 2048) ||
|
||||
(K >= 2048 && N >= 2048));
|
||||
if (M <= 128 || N <= 128) {
|
||||
return KernelMode::Small;
|
||||
} else if (use_large_kernel) {
|
||||
return KernelMode::Large;
|
||||
} else {
|
||||
return KernelMode::Default;
|
||||
}
|
||||
}
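For concreteness, a few shapes and the mode the heuristic above picks for them (values assumed purely for illustration; recall that here M = XQ.size(0), K = XQ.size(1), N = WQ.size(0)):

//   M = 64,   K = 4096, N = 4096  -> KernelMode::Small   (M <= 128 takes priority)
//   M = 4096, K = 4096, N = 4096  -> KernelMode::Large   (at least two of M, N, K are >= 2048)
//   M = 512,  K = 1024, N = 512   -> KernelMode::Default (neither condition holds)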
|
||||
|
||||
template <typename InputDType, bool FastAccum, bool UseBias, typename BiasDType>
|
||||
void dispatch_fp8_rowwise_kernel(
|
||||
at::Tensor XQ,
|
||||
at::Tensor WQ,
|
||||
at::Tensor x_scale,
|
||||
at::Tensor w_scale,
|
||||
c10::optional<at::Tensor> bias,
|
||||
at::Tensor out) {
|
||||
KernelMode kernel = get_kernel_mode(XQ, WQ);
|
||||
if (kernel == KernelMode::Small) {
|
||||
return f8f8bf16_rowwise_impl<
|
||||
64,
|
||||
128,
|
||||
128,
|
||||
2,
|
||||
1,
|
||||
1,
|
||||
false,
|
||||
FastAccum,
|
||||
UseBias,
|
||||
InputDType,
|
||||
BiasDType>(XQ, WQ, x_scale, w_scale, bias, out);
|
||||
} else if (kernel == KernelMode::Large) {
|
||||
return f8f8bf16_rowwise_impl<
|
||||
128,
|
||||
128,
|
||||
128,
|
||||
2,
|
||||
1,
|
||||
1,
|
||||
true,
|
||||
FastAccum,
|
||||
UseBias,
|
||||
InputDType,
|
||||
BiasDType>(XQ, WQ, x_scale, w_scale, bias, out);
|
||||
} else {
|
||||
return f8f8bf16_rowwise_impl<
|
||||
128,
|
||||
128,
|
||||
128,
|
||||
1,
|
||||
2,
|
||||
1,
|
||||
false,
|
||||
FastAccum,
|
||||
UseBias,
|
||||
InputDType,
|
||||
BiasDType>(XQ, WQ, x_scale, w_scale, bias, out);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
#endif // !defined(USE_ROCM)
|
||||
|
||||
namespace at::cuda::detail {
|
||||
void f8f8bf16_rowwise(
|
||||
at::Tensor XQ, // FP8
|
||||
at::Tensor WQ, // FP8
|
||||
at::Tensor x_scale, // FP32
|
||||
at::Tensor w_scale, // FP32
|
||||
c10::optional<at::Tensor> bias, // BF16
|
||||
bool use_fast_accum,
|
||||
at::Tensor& out) {
|
||||
#if defined(BUILD_ROWWISE_FP8_KERNEL)
|
||||
// Check datatypes.
|
||||
TORCH_CHECK(
|
||||
x_scale.dtype() == at::kFloat && w_scale.dtype() == at::kFloat,
|
||||
"Scale tensors must be float32.");
|
||||
if (bias.has_value()) {
|
||||
TORCH_CHECK(
|
||||
bias.value().dtype() == at::kFloat ||
|
||||
bias.value().dtype() == at::kBFloat16,
|
||||
"Bias type must be bfloat16 or float32 if provided.");
|
||||
}
|
||||
// Extract problem size.
|
||||
int M = XQ.size(0);
|
||||
int N = WQ.size(1);
|
||||
int K = XQ.size(1);
|
||||
|
||||
bool use_bias = bias.has_value();
|
||||
bool bf16_bias = use_bias && bias.value().dtype() == at::kBFloat16;
|
||||
|
||||
// Templatize based on input dtype.
|
||||
bool use_e5m2 = XQ.dtype() == at::kFloat8_e5m2;
|
||||
TORCH_CHECK(WQ.dtype() == at::kFloat8_e4m3fn, "For row-wise scaling the second input is required to be a float8_e4m3fn dtype.");
|
||||
|
||||
if (use_bias) {
|
||||
if (bf16_bias) {
|
||||
if (use_fast_accum) {
|
||||
if (use_e5m2) {
|
||||
return dispatch_fp8_rowwise_kernel<
|
||||
cutlass::float_e5m2_t,
|
||||
true,
|
||||
true,
|
||||
cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, out);
|
||||
} else {
|
||||
return dispatch_fp8_rowwise_kernel<
|
||||
cutlass::float_e4m3_t,
|
||||
true,
|
||||
true,
|
||||
cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, out);
|
||||
}
|
||||
} else {
|
||||
if (use_e5m2) {
|
||||
return dispatch_fp8_rowwise_kernel<
|
||||
cutlass::float_e5m2_t,
|
||||
false,
|
||||
true,
|
||||
cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, out);
|
||||
} else {
|
||||
return dispatch_fp8_rowwise_kernel<
|
||||
cutlass::float_e4m3_t,
|
||||
false,
|
||||
true,
|
||||
cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, out);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (use_fast_accum) {
|
||||
if (use_e5m2) {
|
||||
return dispatch_fp8_rowwise_kernel<
|
||||
cutlass::float_e5m2_t,
|
||||
true,
|
||||
true,
|
||||
float>(XQ, WQ, x_scale, w_scale, bias, out);
|
||||
} else {
|
||||
return dispatch_fp8_rowwise_kernel<
|
||||
cutlass::float_e4m3_t,
|
||||
true,
|
||||
true,
|
||||
float>(XQ, WQ, x_scale, w_scale, bias, out);
|
||||
}
|
||||
} else {
|
||||
if (use_e5m2) {
|
||||
return dispatch_fp8_rowwise_kernel<
|
||||
cutlass::float_e5m2_t,
|
||||
false,
|
||||
true,
|
||||
float>(XQ, WQ, x_scale, w_scale, bias, out);
|
||||
} else {
|
||||
return dispatch_fp8_rowwise_kernel<
|
||||
cutlass::float_e4m3_t,
|
||||
false,
|
||||
true,
|
||||
float>(XQ, WQ, x_scale, w_scale, bias, out);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (use_fast_accum) {
|
||||
if (use_e5m2) {
|
||||
return dispatch_fp8_rowwise_kernel<
|
||||
cutlass::float_e5m2_t,
|
||||
true,
|
||||
false,
|
||||
float>(XQ, WQ, x_scale, w_scale, bias, out);
|
||||
} else {
|
||||
return dispatch_fp8_rowwise_kernel<
|
||||
cutlass::float_e4m3_t,
|
||||
true,
|
||||
false,
|
||||
float>(XQ, WQ, x_scale, w_scale, bias, out);
|
||||
}
|
||||
} else {
|
||||
if (use_e5m2) {
|
||||
return dispatch_fp8_rowwise_kernel<
|
||||
cutlass::float_e5m2_t,
|
||||
false,
|
||||
false,
|
||||
float>(XQ, WQ, x_scale, w_scale, bias, out);
|
||||
} else {
|
||||
return dispatch_fp8_rowwise_kernel<
|
||||
cutlass::float_e4m3_t,
|
||||
false,
|
||||
false,
|
||||
float>(XQ, WQ, x_scale, w_scale, bias, out);
|
||||
}
|
||||
}
|
||||
}
|
||||
#else // BUILD_ROWWISE_FP8_KERNEL
|
||||
TORCH_CHECK(false, "Rowwise scaling is not currently supported on your device");
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace at::cuda::detail
|
||||
aten/src/ATen/native/cuda/RowwiseScaledMM.h (15 lines, new file)
@ -0,0 +1,15 @@
|
||||
#pragma once
|
||||
#include <ATen/core/TensorBase.h>
|
||||
#include <c10/util/Optional.h>
|
||||
|
||||
|
||||
namespace at::cuda::detail {
|
||||
TORCH_API void f8f8bf16_rowwise(
|
||||
at::Tensor XQ, // FP8
|
||||
at::Tensor WQ, // FP8
|
||||
at::Tensor x_scale, // FP32
|
||||
at::Tensor w_scale, // FP32
|
||||
c10::optional<at::Tensor> bias, // BF16
|
||||
bool use_fast_accum,
|
||||
at::Tensor& out);
|
||||
} // at::cuda::detail
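A hedged usage sketch of the declaration above, matching its signature as written; shapes, dtypes, and the wrapper function are assumptions for illustration, and the layout constraints mirror the TORCH_CHECKs in RowwiseScaledMM.cu (XQ row-major contiguous, WQ column-major):

#include <ATen/ATen.h>
#include <ATen/native/cuda/RowwiseScaledMM.h>

at::Tensor rowwise_fp8_mm_example(int64_t M, int64_t N, int64_t K) {
  auto cuda = at::TensorOptions().device(at::kCUDA);
  at::Tensor XQ = at::randn({M, K}, cuda).to(at::kFloat8_e4m3fn);      // row-major fp8 activations
  at::Tensor WQ = at::randn({N, K}, cuda).to(at::kFloat8_e4m3fn).t();  // {K, N} column-major fp8 weights
  at::Tensor x_scale = at::ones({M}, cuda.dtype(at::kFloat));          // one fp32 scale per row of XQ
  at::Tensor w_scale = at::ones({N}, cuda.dtype(at::kFloat));          // one fp32 scale per column of WQ
  at::Tensor out = at::empty({M, N}, cuda.dtype(at::kBFloat16));       // bf16 output, as the kernel requires
  at::cuda::detail::f8f8bf16_rowwise(XQ, WQ, x_scale, w_scale,
                                     /*bias=*/c10::nullopt,
                                     /*use_fast_accum=*/false, out);
  return out;
}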
|
||||
@ -14,6 +14,7 @@ using namespace at::cuda::detail;
|
||||
|
||||
// Kernel for fast unfold+copy on volumes
|
||||
template <typename T>
|
||||
C10_LAUNCH_BOUNDS_1(1024)
|
||||
__global__ void vol2col_kernel(
|
||||
const int64_t n,
|
||||
const T* data_vol,
|
||||
|
||||
@ -614,8 +614,6 @@ void add_projection_weights(
|
||||
/*linLayerMatDesc=*/lin_layer_mat_desc.mut_desc(),
|
||||
/*linLayerMat=*/&matrix_pointer));
|
||||
#else
|
||||
void* unused_pointer;
|
||||
TensorDescriptor unused_desc;
|
||||
TensorDescriptor lin_layer_mat_desc;
|
||||
AT_CUDNN_CHECK(cudnnGetRNNWeightParams(
|
||||
/*handle=*/handle,
|
||||
@ -626,8 +624,8 @@ void add_projection_weights(
|
||||
/*linLayerID=*/linear_id,
|
||||
/*linLayerMatDesc=*/lin_layer_mat_desc.mut_desc(),
|
||||
/*linLayerMat=*/&matrix_pointer,
|
||||
unused_desc.mut_desc(),
|
||||
&unused_pointer));
|
||||
nullptr,
|
||||
nullptr));
|
||||
#endif
|
||||
|
||||
cudnnDataType_t data_type;
|
||||
@ -735,8 +733,6 @@ get_parameters(
|
||||
lin_layer_mat_desc.mut_desc(),
|
||||
&matrix_pointer));
|
||||
#else
|
||||
void* unused_pointer = nullptr;
|
||||
TensorDescriptor unused_desc;
|
||||
TensorDescriptor lin_layer_mat_desc;
|
||||
for (int stateless = 0; stateless < 100; stateless++) {
|
||||
if (cudnn_method) { // matrix
|
||||
@ -749,8 +745,8 @@ get_parameters(
|
||||
linear_id,
|
||||
lin_layer_mat_desc.mut_desc(),
|
||||
&matrix_pointer,
|
||||
unused_desc.mut_desc(),
|
||||
&unused_pointer));
|
||||
nullptr,
|
||||
nullptr));
|
||||
} else { // bias
|
||||
AT_CUDNN_CHECK(cudnnGetRNNWeightParams(
|
||||
handle,
|
||||
@ -759,8 +755,8 @@ get_parameters(
|
||||
weight_buf.numel() * weight_buf.element_size(),
|
||||
weight_buf.data_ptr(),
|
||||
linear_id,
|
||||
unused_desc.mut_desc(),
|
||||
&unused_pointer,
|
||||
nullptr,
|
||||
nullptr,
|
||||
lin_layer_mat_desc.mut_desc(),
|
||||
&matrix_pointer));
|
||||
}
|
||||
@ -922,8 +918,6 @@ std::vector<void*> get_expected_data_ptrs(
|
||||
lin_layer_mat_desc.mut_desc(),
|
||||
&matrix_pointer));
|
||||
#else
|
||||
void* unused_pointer = nullptr;
|
||||
TensorDescriptor unused_desc;
|
||||
TensorDescriptor lin_layer_mat_desc;
|
||||
if (cudnn_method) { // matrix
|
||||
AT_CUDNN_CHECK(cudnnGetRNNWeightParams(
|
||||
@ -935,8 +929,8 @@ std::vector<void*> get_expected_data_ptrs(
|
||||
linear_id,
|
||||
lin_layer_mat_desc.mut_desc(),
|
||||
&matrix_pointer,
|
||||
unused_desc.mut_desc(),
|
||||
&unused_pointer));
|
||||
nullptr,
|
||||
nullptr));
|
||||
} else { // bias
|
||||
AT_CUDNN_CHECK(cudnnGetRNNWeightParams(
|
||||
handle,
|
||||
@ -945,8 +939,8 @@ std::vector<void*> get_expected_data_ptrs(
|
||||
weight_buf.numel() * weight_buf.element_size(),
|
||||
weight_buf.data_ptr(),
|
||||
linear_id,
|
||||
unused_desc.mut_desc(),
|
||||
&unused_pointer,
|
||||
nullptr,
|
||||
nullptr,
|
||||
lin_layer_mat_desc.mut_desc(),
|
||||
&matrix_pointer));
|
||||
}
|
||||
@ -972,8 +966,6 @@ std::vector<void*> get_expected_data_ptrs(
|
||||
lin_layer_mat_desc.mut_desc(),
|
||||
&matrix_pointer));
|
||||
#else
|
||||
void* unused_pointer;
|
||||
TensorDescriptor unused_desc;
|
||||
TensorDescriptor lin_layer_mat_desc;
|
||||
|
||||
AT_CUDNN_CHECK(cudnnGetRNNWeightParams(
|
||||
@ -985,8 +977,8 @@ std::vector<void*> get_expected_data_ptrs(
|
||||
linear_id,
|
||||
lin_layer_mat_desc.mut_desc(),
|
||||
&matrix_pointer,
|
||||
unused_desc.mut_desc(),
|
||||
&unused_pointer));
|
||||
nullptr,
|
||||
nullptr));
|
||||
#endif
|
||||
data_ptrs.push_back(matrix_pointer);
|
||||
}
|
||||
|
||||
@ -421,17 +421,6 @@ TORCH_LIBRARY_IMPL(mkl, MkldnnCPU, m) {
|
||||
m.impl(TORCH_SELECTIVE_NAME("mkl::_mkl_linear"), TORCH_FN(mkl_linear));
|
||||
}
|
||||
|
||||
#else // AT_MKL_ENABLED
|
||||
|
||||
static Tensor mkl_linear(
|
||||
const Tensor& self,
|
||||
const Tensor& mkl_weight_t,
|
||||
const Tensor& origin_weight_t,
|
||||
const std::optional<Tensor>& bias_opt,
|
||||
const int64_t prepack_batch_size) {
|
||||
TORCH_CHECK(false, "mkl_linear: ATen not compiled with MKL support");
|
||||
}
|
||||
|
||||
#endif// AT_MKL_ENABLED
|
||||
|
||||
TORCH_LIBRARY_IMPL(mkldnn, CPU, m) {
|
||||
|
||||
@ -336,25 +336,34 @@ inline bool is_dense_in_storage(const at::Tensor& t) {
|
||||
|
||||
class MetalShaderLibrary {
|
||||
public:
|
||||
MetalShaderLibrary(const std::string& src, unsigned nparams_ = 0): shaderSource(src), nparams(nparams_) {}
|
||||
MetalShaderLibrary(const std::string& src): shaderSource(src), nparams(0), compile_options(nullptr){}
|
||||
MetalShaderLibrary(const std::string& src, unsigned nparams_): shaderSource(src), nparams(nparams_), compile_options(nullptr){}
|
||||
MetalShaderLibrary(const std::string& src, unsigned nparams_, MTLCompileOptions* compile_options_): shaderSource(src), nparams(nparams_), compile_options(compile_options_) {}
|
||||
MetalShaderLibrary(const MetalShaderLibrary&) = delete;
|
||||
inline id<MTLComputePipelineState> getPipelineStateForFunc(const std::string& fname) {
|
||||
return getLibraryPipelineState(getLibrary(), fname);
|
||||
return getLibraryPipelineState(getLibrary(), fname).first;
|
||||
}
|
||||
id<MTLComputePipelineState> getPipelineStateForFunc(const std::string& fname, const std::initializer_list<std::string>& params) {
|
||||
return getLibraryPipelineState(getLibrary(params), fname);
|
||||
return getLibraryPipelineState(getLibrary(params), fname).first;
|
||||
}
|
||||
inline id<MTLFunction> getMTLFunction(const std::string& fname) {
|
||||
return getLibraryPipelineState(getLibrary(), fname).second;
|
||||
}
|
||||
id<MTLFunction> getMTLFunction(const std::string& fname, const std::initializer_list<std::string>& params) {
|
||||
return getLibraryPipelineState(getLibrary(params), fname).second;
|
||||
}
|
||||
private:
|
||||
id<MTLComputePipelineState> getLibraryPipelineState(id<MTLLibrary> lib, const std::string& fname);
|
||||
std::pair<id<MTLComputePipelineState>, id<MTLFunction>> getLibraryPipelineState(id<MTLLibrary> lib, const std::string& fname);
|
||||
id<MTLLibrary> getLibrary();
|
||||
id<MTLLibrary> getLibrary(const std::initializer_list<std::string>& params);
|
||||
|
||||
id<MTLLibrary> compileLibrary(const std::string& src);
|
||||
std::string shaderSource;
|
||||
unsigned nparams;
|
||||
MTLCompileOptions* compile_options;
|
||||
id<MTLLibrary> library = nil;
|
||||
std::unordered_map<std::string, id<MTLLibrary>> libMap;
|
||||
std::unordered_map<std::string, id<MTLComputePipelineState>> cplMap;
|
||||
std::unordered_map<std::string, std::pair<id<MTLComputePipelineState>, id<MTLFunction>>> cplMap;
|
||||
};
|
||||
|
||||
static inline void mtl_setBuffer(id<MTLComputeCommandEncoder> encoder, const Tensor& t, unsigned idx) {
|
||||
|
||||
@ -656,31 +656,38 @@ id<MTLLibrary> MetalShaderLibrary::getLibrary(const std::initializer_list<std::s
|
||||
|
||||
id<MTLLibrary> MetalShaderLibrary::compileLibrary(const std::string& src) {
|
||||
NSError* error = nil;
|
||||
MTLCompileOptions* options = [[MTLCompileOptions new] autorelease];
|
||||
[options setLanguageVersion:is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS) ? MTLLanguageVersion3_1
|
||||
: MTLLanguageVersion2_3];
|
||||
// [options setFastMathEnabled: NO];
|
||||
auto str = [NSString stringWithCString:src.c_str() encoding:NSASCIIStringEncoding];
|
||||
MTLCompileOptions* options = compile_options;
|
||||
if (!options) {
|
||||
options = [[MTLCompileOptions new] autorelease];
|
||||
[options setLanguageVersion:is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS) ? MTLLanguageVersion3_1
|
||||
: MTLLanguageVersion2_3];
|
||||
[options setFastMathEnabled:NO];
|
||||
}
|
||||
|
||||
const auto str = [NSString stringWithCString:src.c_str() encoding:NSASCIIStringEncoding];
|
||||
auto device = MPSDevice::getInstance()->device();
|
||||
library = [device newLibraryWithSource:str options:options error:&error];
|
||||
TORCH_CHECK(library, "Failed to create metal library, error: ", [[error description] UTF8String]);
|
||||
return library;
|
||||
}
|
||||
|
||||
id<MTLComputePipelineState> MetalShaderLibrary::getLibraryPipelineState(id<MTLLibrary> lib, const std::string& fname) {
|
||||
auto key = fmt::format("{}:{}", reinterpret_cast<void*>(lib), fname);
|
||||
auto cpl = cplMap[key];
|
||||
if (cpl) {
|
||||
return cpl;
|
||||
std::pair<id<MTLComputePipelineState>, id<MTLFunction>> MetalShaderLibrary::getLibraryPipelineState(
|
||||
id<MTLLibrary> lib,
|
||||
const std::string& fname) {
|
||||
const auto key = fmt::format("{}:{}", reinterpret_cast<void*>(lib), fname);
|
||||
auto found_cpl = cplMap.find(key);
|
||||
if (found_cpl != cplMap.end()) {
|
||||
return found_cpl->second;
|
||||
}
|
||||
|
||||
NSError* error = nil;
|
||||
id<MTLFunction> func = [lib newFunctionWithName:[NSString stringWithUTF8String:fname.c_str()]];
|
||||
TORCH_CHECK(func, "Failed to create function state object for: ", fname);
|
||||
cpl = [[lib device] newComputePipelineStateWithFunction:func error:&error];
|
||||
auto cpl = [[lib device] newComputePipelineStateWithFunction:func error:&error];
|
||||
TORCH_CHECK(cpl, "Failed to created pipeline state object, error: ", [[error description] UTF8String]);
|
||||
|
||||
return cplMap[key] = cpl;
|
||||
cplMap[key] = std::make_pair(cpl, func);
|
||||
return cplMap[key];
|
||||
}
|
||||
|
||||
} // namespace at::native::mps
|
||||
|
||||
@ -0,0 +1,24 @@
|
||||
#pragma once
|
||||
#include <ATen/core/Tensor.h>
|
||||
|
||||
namespace at::native {
|
||||
namespace mps {
|
||||
|
||||
void _fused_adam_amsgrad_mps_impl_(
|
||||
at::TensorList params,
|
||||
at::TensorList grads,
|
||||
at::TensorList exp_avgs,
|
||||
at::TensorList exp_avg_sqs,
|
||||
at::TensorList max_exp_avg_sqs,
|
||||
at::TensorList state_steps,
|
||||
const double lr,
|
||||
const double beta1,
|
||||
const double beta2,
|
||||
const double weight_decay,
|
||||
const double eps,
|
||||
const bool maximize,
|
||||
const c10::optional<at::Tensor>& grad_scale,
|
||||
const c10::optional<at::Tensor>& found_inf
|
||||
);
|
||||
} //namespace mps
|
||||
}// namespace at::native
|
||||
@ -0,0 +1,37 @@
|
||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
||||
#include <ATen/native/mps/operations/FusedAdamAmsgradKernelImpl.h>
|
||||
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/native/ForeachUtils.h>
|
||||
#include <ATen/native/mps/operations/FusedOptimizerOps.h>
|
||||
#include <ATen/native/mps/operations/MultiTensorApply.h>
|
||||
#include <vector>
|
||||
|
||||
namespace at::native {
|
||||
namespace mps {
|
||||
|
||||
void _fused_adam_amsgrad_mps_impl_(at::TensorList params,
|
||||
at::TensorList grads,
|
||||
at::TensorList exp_avgs,
|
||||
at::TensorList exp_avg_sqs,
|
||||
at::TensorList max_exp_avg_sqs,
|
||||
at::TensorList state_steps,
|
||||
const double lr,
|
||||
const double beta1,
|
||||
const double beta2,
|
||||
const double weight_decay,
|
||||
const double eps,
|
||||
const bool maximize,
|
||||
const c10::optional<at::Tensor>& grad_scale,
|
||||
const c10::optional<at::Tensor>& found_inf) {
|
||||
std::vector<std::vector<at::Tensor>> tensor_lists{
|
||||
params.vec(), grads.vec(), exp_avgs.vec(), exp_avg_sqs.vec(), max_exp_avg_sqs.vec()};
|
||||
|
||||
const std::string kernel_name = "fused_adam_amsgrad_" + scalarToMetalTypeString(params[0].scalar_type()) + "_" +
|
||||
scalarToMetalTypeString(state_steps[0].scalar_type());
|
||||
|
||||
multi_tensor_apply_for_fused_adam<5, 512>(
|
||||
kernel_name, tensor_lists, state_steps, lr, beta1, beta2, weight_decay, eps, maximize);
|
||||
}
|
||||
} // namespace mps
|
||||
} // namespace at::native
|
||||
69
aten/src/ATen/native/mps/operations/FusedAdamKernel.mm
Normal file
@ -0,0 +1,69 @@
|
||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
||||
#include <ATen/TypeDefault.h>
|
||||
#include <ATen/native/ForeachUtils.h>
|
||||
#include <ATen/native/mps/operations/FusedAdamAmsgradKernelImpl.h>
|
||||
#include <ATen/native/mps/operations/FusedAdamKernelImpl.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <iostream>
|
||||
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
#include <ATen/Functions.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#else
|
||||
#include <ATen/ops/_fused_adam_native.h>
|
||||
#endif
|
||||
|
||||
namespace at::native {
|
||||
|
||||
void _fused_adam_kernel_mps_(at::TensorList params,
|
||||
at::TensorList grads,
|
||||
at::TensorList exp_avgs,
|
||||
at::TensorList exp_avg_sqs,
|
||||
at::TensorList max_exp_avg_sqs,
|
||||
at::TensorList state_steps,
|
||||
const double lr,
|
||||
const double beta1,
|
||||
const double beta2,
|
||||
const double weight_decay,
|
||||
const double eps,
|
||||
const bool amsgrad,
|
||||
const bool maximize,
|
||||
const c10::optional<at::Tensor>& grad_scale,
|
||||
const c10::optional<at::Tensor>& found_inf) {
|
||||
if (amsgrad) {
|
||||
TORCH_CHECK(at::native::check_fast_path_restrictions({params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs}),
|
||||
"params, grads, exp_avgs, exp_avg_sqs, and max_exp_avg_sqs must have same dtype, device, and layout");
|
||||
mps::_fused_adam_amsgrad_mps_impl_(params,
|
||||
grads,
|
||||
exp_avgs,
|
||||
exp_avg_sqs,
|
||||
max_exp_avg_sqs,
|
||||
state_steps,
|
||||
lr,
|
||||
beta1,
|
||||
beta2,
|
||||
weight_decay,
|
||||
eps,
|
||||
maximize,
|
||||
grad_scale,
|
||||
found_inf);
|
||||
} else {
|
||||
TORCH_CHECK(at::native::check_fast_path_restrictions({params, grads, exp_avgs, exp_avg_sqs}),
|
||||
"params, grads, exp_avgs, and exp_avg_sqs must have same dtype, device, and layout");
|
||||
mps::_fused_adam_mps_impl_(params,
|
||||
grads,
|
||||
exp_avgs,
|
||||
exp_avg_sqs,
|
||||
state_steps,
|
||||
lr,
|
||||
beta1,
|
||||
beta2,
|
||||
weight_decay,
|
||||
eps,
|
||||
maximize,
|
||||
grad_scale,
|
||||
found_inf);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace at::native
|
||||
23
aten/src/ATen/native/mps/operations/FusedAdamKernelImpl.h
Normal file
@ -0,0 +1,23 @@
|
||||
#pragma once
|
||||
#include <ATen/core/Tensor.h>
|
||||
|
||||
namespace at::native {
|
||||
namespace mps {
|
||||
|
||||
void _fused_adam_mps_impl_(
|
||||
at::TensorList params,
|
||||
at::TensorList grads,
|
||||
at::TensorList exp_avgs,
|
||||
at::TensorList exp_avg_sqs,
|
||||
at::TensorList state_steps,
|
||||
const double lr,
|
||||
const double beta1,
|
||||
const double beta2,
|
||||
const double weight_decay,
|
||||
const double eps,
|
||||
const bool maximize,
|
||||
const c10::optional<at::Tensor>& grad_scale,
|
||||
const c10::optional<at::Tensor>& found_inf
|
||||
);
|
||||
} //namespace mps
|
||||
}// namespace at::native
|
||||
35
aten/src/ATen/native/mps/operations/FusedAdamKernelImpl.mm
Normal file
@ -0,0 +1,35 @@
|
||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
||||
#include <ATen/native/mps/operations/FusedAdamKernelImpl.h>
|
||||
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/native/ForeachUtils.h>
|
||||
#include <ATen/native/mps/operations/FusedOptimizerOps.h>
|
||||
#include <ATen/native/mps/operations/MultiTensorApply.h>
|
||||
#include <vector>
|
||||
|
||||
namespace at::native {
|
||||
namespace mps {
|
||||
|
||||
void _fused_adam_mps_impl_(at::TensorList params,
|
||||
at::TensorList grads,
|
||||
at::TensorList exp_avgs,
|
||||
at::TensorList exp_avg_sqs,
|
||||
at::TensorList state_steps,
|
||||
const double lr,
|
||||
const double beta1,
|
||||
const double beta2,
|
||||
const double weight_decay,
|
||||
const double eps,
|
||||
const bool maximize,
|
||||
const c10::optional<at::Tensor>& grad_scale,
|
||||
const c10::optional<at::Tensor>& found_inf) {
|
||||
std::vector<std::vector<at::Tensor>> tensor_lists{params.vec(), grads.vec(), exp_avgs.vec(), exp_avg_sqs.vec()};
|
||||
|
||||
const std::string kernel_name = "fused_adam_" + scalarToMetalTypeString(params[0].scalar_type()) + "_" +
|
||||
scalarToMetalTypeString(state_steps[0].scalar_type());
|
||||
|
||||
multi_tensor_apply_for_fused_adam<4, 512>(
|
||||
kernel_name, tensor_lists, state_steps, lr, beta1, beta2, weight_decay, eps, maximize);
|
||||
}
|
||||
} // namespace mps
|
||||
} // namespace at::native
|
||||
@ -0,0 +1,24 @@
|
||||
#pragma once
|
||||
#include <ATen/core/Tensor.h>
|
||||
|
||||
namespace at::native {
|
||||
namespace mps {
|
||||
|
||||
void _fused_adamw_amsgrad_mps_impl_(
|
||||
at::TensorList params,
|
||||
at::TensorList grads,
|
||||
at::TensorList exp_avgs,
|
||||
at::TensorList exp_avg_sqs,
|
||||
at::TensorList max_exp_avg_sqs,
|
||||
at::TensorList state_steps,
|
||||
const double lr,
|
||||
const double beta1,
|
||||
const double beta2,
|
||||
const double weight_decay,
|
||||
const double eps,
|
||||
const bool maximize,
|
||||
const c10::optional<at::Tensor>& grad_scale,
|
||||
const c10::optional<at::Tensor>& found_inf
|
||||
);
|
||||
} //namespace mps
|
||||
}// namespace at::native
|
||||
@ -0,0 +1,37 @@
|
||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
||||
#include <ATen/native/mps/operations/FusedAdamWAmsgradKernelImpl.h>
|
||||
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/native/ForeachUtils.h>
|
||||
#include <ATen/native/mps/operations/FusedOptimizerOps.h>
|
||||
#include <ATen/native/mps/operations/MultiTensorApply.h>
|
||||
#include <vector>
|
||||
|
||||
namespace at::native {
|
||||
namespace mps {
|
||||
|
||||
void _fused_adamw_amsgrad_mps_impl_(at::TensorList params,
|
||||
at::TensorList grads,
|
||||
at::TensorList exp_avgs,
|
||||
at::TensorList exp_avg_sqs,
|
||||
at::TensorList max_exp_avg_sqs,
|
||||
at::TensorList state_steps,
|
||||
const double lr,
|
||||
const double beta1,
|
||||
const double beta2,
|
||||
const double weight_decay,
|
||||
const double eps,
|
||||
const bool maximize,
|
||||
const c10::optional<at::Tensor>& grad_scale,
|
||||
const c10::optional<at::Tensor>& found_inf) {
|
||||
std::vector<std::vector<at::Tensor>> tensor_lists{
|
||||
params.vec(), grads.vec(), exp_avgs.vec(), exp_avg_sqs.vec(), max_exp_avg_sqs.vec()};
|
||||
|
||||
const std::string kernel_name = "fused_adamw_amsgrad_" + scalarToMetalTypeString(params[0].scalar_type()) + "_" +
|
||||
scalarToMetalTypeString(state_steps[0].scalar_type());
|
||||
|
||||
multi_tensor_apply_for_fused_adam<5, 512>(
|
||||
kernel_name, tensor_lists, state_steps, lr, beta1, beta2, weight_decay, eps, maximize);
|
||||
}
|
||||
} // namespace mps
|
||||
} // namespace at::native
|
||||
68
aten/src/ATen/native/mps/operations/FusedAdamWKernel.mm
Normal file
@ -0,0 +1,68 @@
|
||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
||||
#include <ATen/TypeDefault.h>
|
||||
#include <ATen/native/ForeachUtils.h>
|
||||
#include <ATen/native/mps/operations/FusedAdamWAmsgradKernelImpl.h>
|
||||
#include <ATen/native/mps/operations/FusedAdamWKernelImpl.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <iostream>
|
||||
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
#include <ATen/Functions.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#else
|
||||
#include <ATen/ops/_fused_adamw_native.h>
|
||||
#endif
|
||||
|
||||
namespace at::native {
|
||||
|
||||
void _fused_adamw_kernel_mps_(at::TensorList params,
|
||||
at::TensorList grads,
|
||||
at::TensorList exp_avgs,
|
||||
at::TensorList exp_avg_sqs,
|
||||
at::TensorList max_exp_avg_sqs,
|
||||
at::TensorList state_steps,
|
||||
const double lr,
|
||||
const double beta1,
|
||||
const double beta2,
|
||||
const double weight_decay,
|
||||
const double eps,
|
||||
const bool amsgrad,
|
||||
const bool maximize,
|
||||
const c10::optional<at::Tensor>& grad_scale,
|
||||
const c10::optional<at::Tensor>& found_inf) {
|
||||
if (amsgrad) {
|
||||
TORCH_CHECK(at::native::check_fast_path_restrictions({params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs}),
|
||||
"params, grads, exp_avgs, exp_avg_sqs, and max_exp_avg_sqs must have same dtype, device, and layout");
|
||||
mps::_fused_adamw_amsgrad_mps_impl_(params,
|
||||
grads,
|
||||
exp_avgs,
|
||||
exp_avg_sqs,
|
||||
max_exp_avg_sqs,
|
||||
state_steps,
|
||||
lr,
|
||||
beta1,
|
||||
beta2,
|
||||
weight_decay,
|
||||
eps,
|
||||
maximize,
|
||||
grad_scale,
|
||||
found_inf);
|
||||
} else {
|
||||
TORCH_CHECK(at::native::check_fast_path_restrictions({params, grads, exp_avgs, exp_avg_sqs}),
|
||||
"params, grads, exp_avgs, and exp_avg_sqs must have same dtype, device, and layout");
|
||||
mps::_fused_adamw_mps_impl_(params,
|
||||
grads,
|
||||
exp_avgs,
|
||||
exp_avg_sqs,
|
||||
state_steps,
|
||||
lr,
|
||||
beta1,
|
||||
beta2,
|
||||
weight_decay,
|
||||
eps,
|
||||
maximize,
|
||||
grad_scale,
|
||||
found_inf);
|
||||
}
|
||||
}
|
||||
} // namespace at::native
|
||||
23
aten/src/ATen/native/mps/operations/FusedAdamWKernelImpl.h
Normal file
@ -0,0 +1,23 @@
|
||||
#pragma once
|
||||
#include <ATen/core/Tensor.h>
|
||||
|
||||
namespace at::native {
|
||||
namespace mps {
|
||||
|
||||
void _fused_adamw_mps_impl_(
|
||||
at::TensorList params,
|
||||
at::TensorList grads,
|
||||
at::TensorList exp_avgs,
|
||||
at::TensorList exp_avg_sqs,
|
||||
at::TensorList state_steps,
|
||||
const double lr,
|
||||
const double beta1,
|
||||
const double beta2,
|
||||
const double weight_decay,
|
||||
const double eps,
|
||||
const bool maximize,
|
||||
const c10::optional<at::Tensor>& grad_scale,
|
||||
const c10::optional<at::Tensor>& found_inf
|
||||
);
|
||||
} //namespace mps
|
||||
}// namespace at::native
|
||||
35
aten/src/ATen/native/mps/operations/FusedAdamWKernelImpl.mm
Normal file
@ -0,0 +1,35 @@
|
||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
||||
#include <ATen/native/mps/operations/FusedAdamWKernelImpl.h>
|
||||
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/native/ForeachUtils.h>
|
||||
#include <ATen/native/mps/operations/FusedOptimizerOps.h>
|
||||
#include <ATen/native/mps/operations/MultiTensorApply.h>
|
||||
#include <vector>
|
||||
|
||||
namespace at::native {
|
||||
namespace mps {
|
||||
|
||||
void _fused_adamw_mps_impl_(at::TensorList params,
|
||||
at::TensorList grads,
|
||||
at::TensorList exp_avgs,
|
||||
at::TensorList exp_avg_sqs,
|
||||
at::TensorList state_steps,
|
||||
const double lr,
|
||||
const double beta1,
|
||||
const double beta2,
|
||||
const double weight_decay,
|
||||
const double eps,
|
||||
const bool maximize,
|
||||
const c10::optional<at::Tensor>& grad_scale,
|
||||
const c10::optional<at::Tensor>& found_inf) {
|
||||
std::vector<std::vector<at::Tensor>> tensor_lists{params.vec(), grads.vec(), exp_avgs.vec(), exp_avg_sqs.vec()};
|
||||
|
||||
const std::string kernel_name = "fused_adamw_" + scalarToMetalTypeString(params[0].scalar_type()) + "_" +
|
||||
scalarToMetalTypeString(state_steps[0].scalar_type());
|
||||
|
||||
multi_tensor_apply_for_fused_adam<4, 512>(
|
||||
kernel_name, tensor_lists, state_steps, lr, beta1, beta2, weight_decay, eps, maximize);
|
||||
}
|
||||
} // namespace mps
|
||||
} // namespace at::native
|
||||
274
aten/src/ATen/native/mps/operations/FusedOptimizerOps.h
Normal file
@ -0,0 +1,274 @@
|
||||
#pragma once
|
||||
#include <ATen/native/mps/OperationUtils.h>
|
||||
|
||||
namespace at::native {
|
||||
namespace mps {
|
||||
|
||||
static const char* FUSED_ADAM_OPS = R"METAL(
|
||||
#include <metal_stdlib>
|
||||
|
||||
#define kmaxThreadGroups 32
|
||||
#define kmaxTensors 32
|
||||
#define chunk_size 65536
|
||||
|
||||
constexpr constant uint kParamIdx = 0;
|
||||
constexpr constant uint kGradIdx = kParamIdx + kmaxTensors;
|
||||
constexpr constant uint kExpAvgIdx = kGradIdx + kmaxTensors;
|
||||
constexpr constant uint kExpAvgSqIdx = kExpAvgIdx + kmaxTensors;
|
||||
constexpr constant uint kMaxExpAvgSqIdx = kExpAvgSqIdx + kmaxTensors;
|
||||
constexpr constant uint kStateStepsIdx = kExpAvgSqIdx + kmaxTensors;
|
||||
constexpr constant uint kStateStepsIdxForAmsgrad = kMaxExpAvgSqIdx + kmaxTensors;
|
||||
|
||||
template<typename T, typename state_steps_t>
|
||||
struct AdamArguments {
|
||||
metal::array<device T *, kmaxTensors> params [[ id(kParamIdx) ]];
|
||||
metal::array<device T *, kmaxTensors> grads [[ id(kGradIdx) ]];
|
||||
metal::array<device T *, kmaxTensors> exp_avgs [[ id(kExpAvgIdx) ]];
|
||||
metal::array<device T *, kmaxTensors> exp_avg_sqs [[ id(kExpAvgSqIdx) ]];
|
||||
metal::array<device state_steps_t *, kmaxTensors> state_steps [[ id(kStateStepsIdx) ]];
|
||||
};
|
||||
|
||||
template<typename T, typename state_steps_t>
|
||||
struct AdamAmsgradArguments {
|
||||
metal::array<device T *, kmaxTensors> params [[ id(kParamIdx) ]];
|
||||
metal::array<device T *, kmaxTensors> grads [[ id(kGradIdx) ]];
|
||||
metal::array<device T *, kmaxTensors> exp_avgs [[ id(kExpAvgIdx) ]];
|
||||
metal::array<device T *, kmaxTensors> exp_avg_sqs [[ id(kExpAvgSqIdx) ]];
|
||||
metal::array<device T *, kmaxTensors> max_exp_avg_sqs [[ id(kMaxExpAvgSqIdx) ]];
|
||||
metal::array<device state_steps_t *, kmaxTensors> state_steps [[ id(kStateStepsIdxForAmsgrad) ]];
|
||||
};
|
||||
|
||||
struct MetadataArguments {
|
||||
uint32_t numels[kmaxTensors];
|
||||
uint32_t threadgroup_to_tensor[kmaxThreadGroups];
|
||||
uint32_t threadgroup_to_chunk[kmaxThreadGroups];
|
||||
};
|
||||
|
||||
enum ADAM_MODE : uint8_t {
|
||||
ORIGINAL = 0,
|
||||
ADAMW = 1
|
||||
};
|
||||
|
||||
template <typename T, typename state_steps_t, ADAM_MODE adam_mode>
|
||||
inline void adam_math_amsgrad(
|
||||
device T & param,
|
||||
device T & grad,
|
||||
device T & exp_avg,
|
||||
device T & exp_avg_sq,
|
||||
device T & max_exp_avg_sq,
|
||||
device state_steps_t & state_steps,
|
||||
const float lr,
|
||||
const float beta1,
|
||||
const float beta2,
|
||||
const float weight_decay,
|
||||
const float eps,
|
||||
const uint8_t maximize
|
||||
) {
|
||||
T grad_ = grad;
|
||||
|
||||
if (maximize) {
|
||||
grad = -grad;
|
||||
}
|
||||
|
||||
// Update param, grad, 1st and 2nd order momentum.
|
||||
if (weight_decay != 0) {
|
||||
switch (adam_mode) {
|
||||
case ADAM_MODE::ORIGINAL:
|
||||
grad += param * weight_decay;
|
||||
break;
|
||||
case ADAM_MODE::ADAMW:
|
||||
param -= lr * weight_decay * param;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
exp_avg = beta1 * exp_avg + (1 - beta1) * grad;
|
||||
exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad * grad;
|
||||
const float casted_state_steps = static_cast<float>(state_steps);
|
||||
const T bias_correction1 = 1 - metal::precise::pow(beta1, casted_state_steps);
|
||||
const T step_size = lr / bias_correction1;
|
||||
const T bias_correction2 = 1 - metal::precise::pow(beta2, casted_state_steps);
|
||||
const T bias_correction2_sqrt = metal::precise::sqrt(bias_correction2);
|
||||
max_exp_avg_sq = metal::max(max_exp_avg_sq, exp_avg_sq);
|
||||
|
||||
const T denom = (metal::precise::sqrt(max_exp_avg_sq) / bias_correction2_sqrt) + eps;
|
||||
param -= step_size * exp_avg / denom;
|
||||
grad = grad_;
|
||||
}
|
||||
|
||||
template <typename T, typename state_steps_t, ADAM_MODE adam_mode>
|
||||
inline void adam_math(
|
||||
device T & param,
|
||||
device T & grad,
|
||||
device T & exp_avg,
|
||||
device T & exp_avg_sq,
|
||||
device state_steps_t & state_steps,
|
||||
const float lr,
|
||||
const float beta1,
|
||||
const float beta2,
|
||||
const float weight_decay,
|
||||
const float eps,
|
||||
const uint8_t maximize
|
||||
) {
|
||||
T grad_ = grad;
|
||||
|
||||
if (maximize) {
|
||||
grad = -grad;
|
||||
}
|
||||
|
||||
// Update param, grad, 1st and 2nd order momentum.
|
||||
if (weight_decay != 0) {
|
||||
switch (adam_mode) {
|
||||
case ADAM_MODE::ORIGINAL:
|
||||
grad += param * weight_decay;
|
||||
break;
|
||||
case ADAM_MODE::ADAMW:
|
||||
param -= lr * weight_decay * param;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
exp_avg = beta1 * exp_avg + (1 - beta1) * grad;
|
||||
exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad * grad;
|
||||
const float casted_state_steps = static_cast<float>(state_steps);
|
||||
const T bias_correction1 = 1 - metal::precise::pow(beta1, casted_state_steps);
|
||||
const T step_size = lr / bias_correction1;
|
||||
const T bias_correction2 = 1 - metal::precise::pow(beta2, casted_state_steps);
|
||||
const T bias_correction2_sqrt = metal::precise::sqrt(bias_correction2);
|
||||
const T denom = (metal::precise::sqrt(exp_avg_sq) / bias_correction2_sqrt) + eps;
|
||||
param -= step_size * exp_avg / denom;
|
||||
grad = grad_;
|
||||
}
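// Reference: the helpers above implement the standard Adam update with bias correction,
//   m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
//   v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
//   theta_t = theta_{t-1} - lr * (m_t / (1 - beta1^t)) / (sqrt(v_t / (1 - beta2^t)) + eps)
// where the AMSGrad variant replaces v_t with its running maximum, ADAM_MODE::ORIGINAL adds
// weight decay to the gradient, and ADAM_MODE::ADAMW applies decay directly to the parameter.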
|
||||
|
||||
template <typename T, typename state_steps_t, ADAM_MODE adam_mode>
|
||||
kernel void fused_adam_amsgrad(
|
||||
device AdamAmsgradArguments<T, state_steps_t> & args [[buffer(0)]],
|
||||
constant MetadataArguments & metadata_args [[buffer(1)]],
|
||||
constant float & lr [[buffer(2)]],
|
||||
constant float & beta1 [[buffer(3)]],
|
||||
constant float & beta2 [[buffer(4)]],
|
||||
constant float & weight_decay [[buffer(5)]],
|
||||
constant float & eps [[buffer(6)]],
|
||||
constant uint8_t & maximize [[buffer(7)]],
|
||||
uint tid [[thread_position_in_threadgroup]],
|
||||
uint tgid [[threadgroup_position_in_grid]],
|
||||
uint tptg [[threads_per_threadgroup]]) {
|
||||
|
||||
const uint32_t tensor_loc = metadata_args.threadgroup_to_tensor[tgid];
|
||||
const uint32_t chunk_idx = metadata_args.threadgroup_to_chunk[tgid];
|
||||
const uint32_t chunk_offset = chunk_idx * chunk_size;
|
||||
const uint32_t numel = metadata_args.numels[tensor_loc] - chunk_offset;
|
||||
|
||||
const auto step_count = args.state_steps[tensor_loc];
|
||||
|
||||
// each chunk is a threadgroup
|
||||
auto param = args.params[tensor_loc] + chunk_offset;
|
||||
auto grad = args.grads[tensor_loc] + chunk_offset;
|
||||
auto exp_avg = args.exp_avgs[tensor_loc] + chunk_offset;
|
||||
auto exp_avg_sq = args.exp_avg_sqs[tensor_loc] + chunk_offset;
|
||||
auto max_exp_avg_sq = args.max_exp_avg_sqs[tensor_loc] + chunk_offset;
|
||||
|
||||
for (uint32_t i_start = tid; i_start < numel && i_start < chunk_size; i_start += tptg) {
|
||||
adam_math_amsgrad<T, state_steps_t, adam_mode>(
|
||||
*(param + i_start),
|
||||
*(grad + i_start),
|
||||
*(exp_avg + i_start),
|
||||
*(exp_avg_sq + i_start),
|
||||
*(max_exp_avg_sq + i_start),
|
||||
*step_count,
|
||||
lr,
|
||||
beta1,
|
||||
beta2,
|
||||
weight_decay,
|
||||
eps,
|
||||
maximize
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename state_steps_t, ADAM_MODE adam_mode>
|
||||
kernel void fused_adam(
|
||||
device AdamArguments<T, state_steps_t> & args [[buffer(0)]],
|
||||
constant MetadataArguments & metadata_args [[buffer(1)]],
|
||||
constant float & lr [[buffer(2)]],
|
||||
constant float & beta1 [[buffer(3)]],
|
||||
constant float & beta2 [[buffer(4)]],
|
||||
constant float & weight_decay [[buffer(5)]],
|
||||
constant float & eps [[buffer(6)]],
|
||||
constant uint8_t & maximize [[buffer(7)]],
|
||||
uint tid [[thread_position_in_threadgroup]],
|
||||
uint tgid [[threadgroup_position_in_grid]],
|
||||
uint tptg [[threads_per_threadgroup]]) {
|
||||
|
||||
const uint32_t tensor_loc = metadata_args.threadgroup_to_tensor[tgid];
|
||||
const uint32_t chunk_idx = metadata_args.threadgroup_to_chunk[tgid];
|
||||
const uint32_t chunk_offset = chunk_idx * chunk_size;
|
||||
const uint32_t numel = metadata_args.numels[tensor_loc] - chunk_offset;
|
||||
|
||||
const auto step_count = args.state_steps[tensor_loc];
|
||||
|
||||
// each chunk is a threadgroup
|
||||
auto param = args.params[tensor_loc] + chunk_offset;
|
||||
auto grad = args.grads[tensor_loc] + chunk_offset;
|
||||
auto exp_avg = args.exp_avgs[tensor_loc] + chunk_offset;
|
||||
auto exp_avg_sq = args.exp_avg_sqs[tensor_loc] + chunk_offset;
|
||||
|
||||
for (uint32_t i_start = tid; i_start < numel && i_start < chunk_size; i_start += tptg) {
|
||||
adam_math<T, state_steps_t, adam_mode>(
|
||||
*(param + i_start),
|
||||
*(grad + i_start),
|
||||
*(exp_avg + i_start),
|
||||
*(exp_avg_sq + i_start),
|
||||
*step_count,
|
||||
lr,
|
||||
beta1,
|
||||
beta2,
|
||||
weight_decay,
|
||||
eps,
|
||||
maximize
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#define REGISTER_FUSED_ADAM_OP(DTYPE, STATE_STEPS_DTYPE, ADAM_MODE_DTYPE, HOST_NAME, KERNEL_NAME, ARGUMENTS_STRUCT) \
|
||||
template \
|
||||
[[host_name(#HOST_NAME "_" #DTYPE "_" #STATE_STEPS_DTYPE)]] \
|
||||
kernel void KERNEL_NAME<DTYPE, STATE_STEPS_DTYPE, ADAM_MODE_DTYPE>( \
|
||||
device ARGUMENTS_STRUCT<DTYPE, STATE_STEPS_DTYPE> & args [[buffer(0)]],\
|
||||
constant MetadataArguments & metadata_args [[buffer(1)]],\
|
||||
constant float & lr [[buffer(2)]],\
|
||||
constant float & beta1 [[buffer(3)]],\
|
||||
constant float & beta2 [[buffer(4)]],\
|
||||
constant float & weight_decay [[buffer(5)]],\
|
||||
constant float & eps [[buffer(6)]],\
|
||||
constant uint8_t & maximize [[buffer(7)]],\
|
||||
uint tid [[thread_position_in_threadgroup]],\
|
||||
uint tgid [[threadgroup_position_in_grid]],\
|
||||
uint tptg [[threads_per_threadgroup]])
|
||||
|
||||
REGISTER_FUSED_ADAM_OP(float, float, ADAM_MODE::ORIGINAL, fused_adam, fused_adam, AdamArguments);
|
||||
REGISTER_FUSED_ADAM_OP(float, half, ADAM_MODE::ORIGINAL, fused_adam, fused_adam, AdamArguments);
|
||||
REGISTER_FUSED_ADAM_OP(half, float, ADAM_MODE::ORIGINAL, fused_adam, fused_adam, AdamArguments);
|
||||
REGISTER_FUSED_ADAM_OP(half, half, ADAM_MODE::ORIGINAL, fused_adam, fused_adam, AdamArguments);
|
||||
REGISTER_FUSED_ADAM_OP(float, float, ADAM_MODE::ADAMW, fused_adamw, fused_adam, AdamArguments);
|
||||
REGISTER_FUSED_ADAM_OP(float, half, ADAM_MODE::ADAMW, fused_adamw, fused_adam, AdamArguments);
|
||||
REGISTER_FUSED_ADAM_OP(half, float, ADAM_MODE::ADAMW, fused_adamw, fused_adam, AdamArguments);
|
||||
REGISTER_FUSED_ADAM_OP(half, half, ADAM_MODE::ADAMW, fused_adamw, fused_adam, AdamArguments);
|
||||
REGISTER_FUSED_ADAM_OP(float, float, ADAM_MODE::ORIGINAL, fused_adam_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments);
|
||||
REGISTER_FUSED_ADAM_OP(float, half, ADAM_MODE::ORIGINAL, fused_adam_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments);
|
||||
REGISTER_FUSED_ADAM_OP(half, float, ADAM_MODE::ORIGINAL, fused_adam_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments);
|
||||
REGISTER_FUSED_ADAM_OP(half, half, ADAM_MODE::ORIGINAL, fused_adam_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments);
|
||||
REGISTER_FUSED_ADAM_OP(float, float, ADAM_MODE::ADAMW, fused_adamw_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments);
|
||||
REGISTER_FUSED_ADAM_OP(float, half, ADAM_MODE::ADAMW, fused_adamw_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments);
|
||||
REGISTER_FUSED_ADAM_OP(half, float, ADAM_MODE::ADAMW, fused_adamw_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments);
|
||||
REGISTER_FUSED_ADAM_OP(half, half, ADAM_MODE::ADAMW, fused_adamw_amsgrad, fused_adam_amsgrad, AdamAmsgradArguments);
|
||||
|
||||
)METAL";
|
||||
|
||||
static std::pair<id<MTLComputePipelineState>, id<MTLFunction>> getCPLState(const std::string& fname) {
|
||||
static MetalShaderLibrary lib(FUSED_ADAM_OPS, 0);
|
||||
return std::make_pair(lib.getPipelineStateForFunc(fname), lib.getMTLFunction(fname));
|
||||
}
|
||||
|
||||
} //namespace mps
|
||||
} // namespace at::native
|
||||
@ -17,11 +17,15 @@
|
||||
#include <ATen/ops/addr_native.h>
|
||||
#include <ATen/ops/baddbmm_native.h>
|
||||
#include <ATen/ops/bmm_native.h>
|
||||
#include <ATen/ops/linalg_lu_factor_native.h>
|
||||
#include <ATen/ops/linalg_solve_triangular_native.h>
|
||||
#include <ATen/ops/mm_native.h>
|
||||
#include <ATen/ops/stack.h>
|
||||
#include <ATen/ops/triangular_solve_native.h>
|
||||
#endif
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
namespace at::native {
|
||||
namespace mps {
|
||||
namespace {
|
||||
@ -127,6 +131,116 @@ bool use_metal_mm(const Tensor& self, const Tensor& other, const Tensor& output)
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
static void linalg_lu_factor_out_mps_impl(const Tensor& A, bool pivot, Tensor& LU, Tensor& pivots) {
|
||||
using namespace mps;
|
||||
|
||||
TORCH_CHECK(!c10::isComplexType(A.scalar_type()) && !c10::isComplexType(LU.scalar_type()),
|
||||
"linalg.lu_factor(): MPS doesn't support complex types.");
|
||||
TORCH_CHECK(pivot, "linalg.lu_factor(): MPS doesn't allow pivot == False.");
|
||||
|
||||
Tensor A_t = A;
|
||||
uint64_t aRows = A_t.size(-2);
|
||||
uint64_t aCols = A_t.size(-1);
|
||||
uint64_t aElemSize = A_t.element_size();
|
||||
uint64_t numPivots = std::min(aRows, aCols);
|
||||
std::vector<int64_t> pivot_sizes(A_t.sizes().begin(), A_t.sizes().end() - 2);
|
||||
pivot_sizes.push_back(numPivots);
|
||||
resize_output(pivots, pivot_sizes);
|
||||
|
||||
if (A_t.numel() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
Tensor A_ = A_t.dim() > 3 ? A_t.flatten(0, -3) : A_t;
|
||||
|
||||
uint64_t batchSize = A_.dim() > 2 ? A_.size(0) : 1;
|
||||
std::vector<Tensor> status_tensors;
|
||||
std::vector<Tensor> pivots_list;
|
||||
|
||||
status_tensors.reserve(batchSize);
|
||||
pivots_list.reserve(batchSize);
|
||||
for (C10_UNUSED const auto i : c10::irange(batchSize)) {
|
||||
status_tensors.push_back(at::zeros(1, kInt, c10::nullopt, kMPS, c10::nullopt));
|
||||
pivots_list.push_back(at::zeros(numPivots, kInt, c10::nullopt, kMPS, c10::nullopt));
|
||||
}
|
||||
|
||||
// Since the MPSMatrixDecompositionLU functions in-place if the result matrix completely aliases the source matrix,
|
||||
// We copy LU from A as the new A.
|
||||
resize_output(LU, A_.sizes());
|
||||
if (!LU.is_same(A_)) {
|
||||
A_ = LU.copy_(A_);
|
||||
} else {
|
||||
A_ = LU;
|
||||
}
|
||||
|
||||
TORCH_INTERNAL_ASSERT(A_.is_contiguous())
|
||||
|
||||
id<MTLBuffer> aBuffer = getMTLBufferStorage(A_);
|
||||
|
||||
MPSStream* mpsStream = getCurrentMPSStream();
|
||||
id<MTLDevice> device = MPSDevice::getInstance()->device();
|
||||
|
||||
dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
|
||||
@autoreleasepool {
|
||||
id<MTLCommandBuffer> commandBuffer = mpsStream->commandBuffer();
|
||||
MPSMatrixDecompositionLU* filter = [[[MPSMatrixDecompositionLU alloc] initWithDevice:device
|
||||
rows:aRows
|
||||
columns:aCols] autorelease];
|
||||
|
||||
MPSMatrixDescriptor* sourceMatrixDesc = [MPSMatrixDescriptor matrixDescriptorWithRows:aRows
|
||||
columns:aCols
|
||||
matrices:batchSize
|
||||
rowBytes:aCols * aElemSize
|
||||
matrixBytes:aRows * aCols * aElemSize
|
||||
dataType:getMPSDataType(A_)];
|
||||
MPSMatrixDescriptor* pivotsMatrixDesc = [MPSMatrixDescriptor matrixDescriptorWithRows:1
|
||||
columns:numPivots
|
||||
matrices:1
|
||||
rowBytes:numPivots * sizeof(uint32_t)
|
||||
matrixBytes:numPivots * sizeof(uint32_t)
|
||||
dataType:MPSDataTypeUInt32];
|
||||
|
||||
for (const auto i : c10::irange(batchSize)) {
|
||||
const uint64_t aBatchOffset = i * aRows * aCols;
|
||||
MPSMatrix* sourceMatrix = [[[MPSMatrix alloc] initWithBuffer:aBuffer
|
||||
offset:(A_.storage_offset() + aBatchOffset) * aElemSize
|
||||
descriptor:sourceMatrixDesc] autorelease];
|
||||
MPSMatrix* pivotIndices = [[[MPSMatrix alloc] initWithBuffer:getMTLBufferStorage(pivots_list[i])
|
||||
offset:0
|
||||
descriptor:pivotsMatrixDesc] autorelease];
|
||||
MPSMatrix* solutionMatrix = [[[MPSMatrix alloc] initWithBuffer:aBuffer
|
||||
offset:(A_.storage_offset() + aBatchOffset) * aElemSize
|
||||
descriptor:sourceMatrixDesc] autorelease];
|
||||
id<MTLBuffer> statusBuffer = getMTLBufferStorage(status_tensors[i]);
|
||||
[filter encodeToCommandBuffer:commandBuffer
|
||||
sourceMatrix:sourceMatrix
|
||||
resultMatrix:solutionMatrix
|
||||
pivotIndices:pivotIndices
|
||||
status:statusBuffer];
|
||||
}
|
||||
}
|
||||
});
|
||||
auto stacked_pivots = A_.dim() > 2 ? at::stack(pivots_list) : pivots_list[0];
|
||||
if (A_t.dim() > 3) {
|
||||
resize_output(LU, A_t.sizes());
|
||||
pivots.copy_(stacked_pivots.view(pivot_sizes));
|
||||
} else {
|
||||
pivots.copy_(stacked_pivots);
|
||||
}
|
||||
pivots += 1; // PyTorch's `pivots` is 1-index.
|
||||
|
||||
for (const auto i : c10::irange(status_tensors.size())) {
|
||||
int status = status_tensors[i].item<int>();
|
||||
TORCH_CHECK(
|
||||
status == 0,
|
||||
"lu_factor(): LU factorization failure at the ",
|
||||
i + 1,
|
||||
" sample with status: ",
|
||||
status,
|
||||
". See https://developer.apple.com/documentation/metalperformanceshaders/mpsmatrixdecompositionstatus for details.");
|
||||
}
|
||||
}
|
||||
|
||||
static Tensor& mm_out_mps_impl(const Tensor& self, const Tensor& other, Tensor& output) {
|
||||
using namespace mps;
|
||||
using CachedGraph = MPSBinaryCachedGraph;
|
||||
@ -753,4 +867,16 @@ TORCH_IMPL_FUNC(triangular_solve_mps_out)
|
||||
result.copy_(out);
|
||||
}
|
||||
|
||||
std::tuple<Tensor&, Tensor&> linalg_lu_factor_out_mps(const Tensor& A, bool pivot, Tensor& LU, Tensor& pivots) {
|
||||
mps::linalg_lu_factor_out_mps_impl(A, pivot, LU, pivots);
|
||||
return std::tie(LU, pivots);
|
||||
}
|
||||
|
||||
std::tuple<Tensor, Tensor> linalg_lu_factor_mps(const Tensor& A, bool pivot) {
|
||||
Tensor LU = at::empty({0}, A.options());
|
||||
Tensor pivots = at::empty({0}, A.options().dtype(kInt));
|
||||
mps::linalg_lu_factor_out_mps_impl(A, pivot, LU, pivots);
|
||||
return std::make_tuple(std::move(LU), std::move(pivots));
|
||||
}
|
||||
|
||||
} // namespace at::native
|
||||
|
||||
190
aten/src/ATen/native/mps/operations/MultiTensorApply.h
Normal file
@ -0,0 +1,190 @@
|
||||
#pragma once
|
||||
#include <ATen/core/Tensor.h>
|
||||
#include <ATen/mps/MPSProfiler.h>
|
||||
#include <ATen/native/mps/operations/FusedOptimizerOps.h>
|
||||
|
||||
namespace at::native {
|
||||
namespace mps {
|
||||
|
||||
static constexpr int64_t kChunkSize = 65536;
|
||||
static constexpr int64_t kmaxThreadGroups = 32;
|
||||
static constexpr int64_t kmaxTensors = 32;
|
||||
|
||||
struct MetadataArguments { // the size of this struct must be less than 4 kilobytes (the Metal setBytes limit)
|
||||
uint numels[kmaxTensors];
|
||||
uint threadgroup_to_tensor[kmaxThreadGroups];
|
||||
uint threadgroup_to_chunk[kmaxThreadGroups];
|
||||
};
|
||||
|
||||
template <int depth, uint32_t kThreadGroupSize>
|
||||
static void multi_tensor_apply_for_fused_adam(
|
||||
const std::string& kernel_name,
|
||||
std::vector<std::vector<at::Tensor>>& tensor_lists,
|
||||
at::TensorList state_steps,
|
||||
const double lr,
|
||||
const double beta1,
|
||||
const double beta2,
|
||||
const double weight_decay,
|
||||
const double eps,
|
||||
const bool maximize
|
||||
) {
|
||||
const auto num_tensors = tensor_lists[0].size();
|
||||
|
||||
if (num_tensors == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
TORCH_CHECK(
|
||||
tensor_lists.size() == depth,
|
||||
"Number of tensor lists has to match the depth");
|
||||
for (const auto& d : c10::irange(depth)) {
|
||||
TORCH_CHECK(
|
||||
tensor_lists[d][0].scalar_type() == at::ScalarType::Float || tensor_lists[d][0].scalar_type() == at::ScalarType::Half, "Only float and half are supported");
|
||||
}
|
||||
|
||||
id<MTLDevice> device = MPSDevice::getInstance()->device();
|
||||
MPSStream* mpsStream = getCurrentMPSStream();
|
||||
|
||||
float lr_lv = lr;
|
||||
float beta1_lv = beta1;
|
||||
float beta2_lv = beta2;
|
||||
float weight_decay_lv = weight_decay;
|
||||
float eps_lv = eps;
|
||||
uint8_t maximize_lv = maximize;
|
||||
|
||||
// Remove comment for debugging
|
||||
/*
|
||||
mpsStream->addCompletedHandler(^(id<MTLCommandBuffer> cb) {
|
||||
[cb.logs enumerateObjectsUsingBlock:^(NSString* log, NSUInteger idx, BOOL* stop) {
|
||||
NSLog(@"MPSStream: %@", log);
|
||||
}
|
||||
];
|
||||
});
|
||||
*/
|
||||
|
||||
dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
|
||||
@autoreleasepool {
|
||||
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
|
||||
auto [fusedOptimizerPSO, fusedOptimizerFunc] = getCPLState(kernel_name);
|
||||
|
||||
// this function call is a no-op if MPS Profiler is not enabled
|
||||
getMPSProfiler().beginProfileKernel(fusedOptimizerPSO, kernel_name, {tensor_lists[0]});
|
||||
|
||||
[computeEncoder setComputePipelineState:fusedOptimizerPSO];
|
||||
|
||||
// BufferIndex is the index in the kernel function
|
||||
auto tensorArgumentEncoder = [[fusedOptimizerFunc newArgumentEncoderWithBufferIndex:0] autorelease];
|
||||
id<MTLBuffer> tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength options:0] autorelease];
|
||||
[tensorArgumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0];
|
||||
|
||||
int64_t tensor_loc = 0;
|
||||
int64_t threadgroup_loc = 0;
|
||||
MetadataArguments metadata_arguments;
|
||||
|
||||
for (const auto tensor_index : c10::irange(num_tensors)) {
|
||||
// short-circuit to avoid adding empty tensors to tensorListMeta
|
||||
if (tensor_lists[0][tensor_index].numel() == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (const auto& d : c10::irange(depth)) {
|
||||
[tensorArgumentEncoder setBuffer:getMTLBufferStorage(tensor_lists[d][tensor_index])
|
||||
offset:tensor_lists[d][tensor_index].storage_offset() * tensor_lists[d][tensor_index].element_size()
|
||||
atIndex:d * kmaxTensors + tensor_loc];
|
||||
[computeEncoder useResource:getMTLBufferStorage(tensor_lists[d][tensor_index]) usage:MTLResourceUsageRead | MTLResourceUsageWrite];
|
||||
}
|
||||
[tensorArgumentEncoder setBuffer:getMTLBufferStorage(state_steps[tensor_index])
|
||||
offset:state_steps[tensor_index].storage_offset() * state_steps[tensor_index].element_size()
|
||||
atIndex:depth * kmaxTensors + tensor_loc];
|
||||
[computeEncoder useResource:getMTLBufferStorage(state_steps[tensor_index]) usage:MTLResourceUsageRead];
|
||||
metadata_arguments.numels[tensor_loc] = tensor_lists[0][tensor_index].numel();
|
||||
|
||||
tensor_loc++;
|
||||
|
||||
const auto numel = tensor_lists[0][tensor_index].numel();
|
||||
const auto chunks = numel / kChunkSize + (numel % kChunkSize != 0);
|
||||
TORCH_CHECK(chunks > -1);
|
||||
|
||||
for (const auto& chunk : c10::irange(chunks)) {
|
||||
metadata_arguments.threadgroup_to_tensor[threadgroup_loc] = tensor_loc - 1;
|
||||
metadata_arguments.threadgroup_to_chunk[threadgroup_loc] = chunk;
|
||||
|
||||
threadgroup_loc++;
|
||||
|
||||
const auto tensor_full = tensor_loc == kmaxTensors && chunk == chunks - 1;
|
||||
// Reach the maximum threadgroups per dispatch
|
||||
const auto blocks_full = threadgroup_loc == kmaxThreadGroups;
|
||||
|
||||
if (tensor_full || blocks_full){
|
||||
[computeEncoder setBuffer:tensorArgumentBuffer
|
||||
offset:0
|
||||
atIndex:0];
|
||||
[computeEncoder setBytes:&metadata_arguments
|
||||
length:sizeof(MetadataArguments)
|
||||
atIndex:1];
|
||||
[computeEncoder setBytes:&lr_lv length:sizeof(float) atIndex:2];
|
||||
[computeEncoder setBytes:&beta1_lv length:sizeof(float) atIndex:3];
|
||||
[computeEncoder setBytes:&beta2_lv length:sizeof(float) atIndex:4];
|
||||
[computeEncoder setBytes:&weight_decay_lv length:sizeof(float) atIndex:5];
|
||||
[computeEncoder setBytes:&eps_lv length:sizeof(float) atIndex:6];
|
||||
[computeEncoder setBytes:&maximize_lv length:sizeof(uint8_t) atIndex:7];
|
||||
MTLSize gridSize = MTLSizeMake(threadgroup_loc, 1, 1);
|
||||
uint32_t maxThreadsPerGroup = [fusedOptimizerPSO maxTotalThreadsPerThreadgroup];
|
||||
MTLSize threadGroupSize = MTLSizeMake(std::min(maxThreadsPerGroup, kThreadGroupSize), 1, 1);
|
||||
[computeEncoder dispatchThreadgroups:gridSize threadsPerThreadgroup:threadGroupSize];
|
||||
|
||||
// Reset
|
||||
threadgroup_loc = 0;
|
||||
if (chunk == chunks - 1) {
|
||||
// last chunk
|
||||
tensor_loc = 0;
|
||||
tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength options:0] autorelease];
|
||||
[tensorArgumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0];
|
||||
} else {
|
||||
// reuse the current tensor since the current one isn't done.
|
||||
metadata_arguments.numels[0] = metadata_arguments.numels[tensor_loc - 1];
|
||||
|
||||
tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength options:0] autorelease];
|
||||
[tensorArgumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0];
|
||||
|
||||
for (const auto& d : c10::irange(depth)) {
|
||||
[tensorArgumentEncoder setBuffer:getMTLBufferStorage(tensor_lists[d][tensor_index])
|
||||
offset:tensor_lists[d][tensor_index].storage_offset() * tensor_lists[d][tensor_index].element_size()
|
||||
atIndex:d * kmaxTensors + 0];
|
||||
[computeEncoder useResource:getMTLBufferStorage(tensor_lists[d][tensor_index]) usage:MTLResourceUsageWrite | MTLResourceUsageRead];
|
||||
}
|
||||
[tensorArgumentEncoder setBuffer:getMTLBufferStorage(state_steps[tensor_index])
|
||||
offset:state_steps[tensor_index].storage_offset() * state_steps[tensor_index].element_size()
|
||||
atIndex:depth * kmaxTensors + 0];
|
||||
[computeEncoder useResource:getMTLBufferStorage(state_steps[tensor_index]) usage:MTLResourceUsageRead];
|
||||
|
||||
tensor_loc = 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (threadgroup_loc != 0) {
|
||||
|
||||
[computeEncoder setBuffer:tensorArgumentBuffer offset:0 atIndex:0];
|
||||
[computeEncoder setBytes:&metadata_arguments length:sizeof(MetadataArguments) atIndex:1];
|
||||
[computeEncoder setBytes:&lr_lv length:sizeof(float) atIndex:2];
|
||||
[computeEncoder setBytes:&beta1_lv length:sizeof(float) atIndex:3];
|
||||
[computeEncoder setBytes:&beta2_lv length:sizeof(float) atIndex:4];
|
||||
[computeEncoder setBytes:&weight_decay_lv length:sizeof(float) atIndex:5];
|
||||
[computeEncoder setBytes:&eps_lv length:sizeof(float) atIndex:6];
|
||||
[computeEncoder setBytes:&maximize_lv length:sizeof(uint8_t) atIndex:7];
|
||||
MTLSize gridSize = MTLSizeMake(threadgroup_loc, 1, 1);
|
||||
uint32_t maxThreadsPerGroup = [fusedOptimizerPSO maxTotalThreadsPerThreadgroup];
|
||||
MTLSize threadGroupSize = MTLSizeMake(std::min(maxThreadsPerGroup, kThreadGroupSize), 1, 1);
|
||||
[computeEncoder dispatchThreadgroups:gridSize threadsPerThreadgroup:threadGroupSize];
|
||||
}
|
||||
|
||||
getMPSProfiler().endProfileKernel(fusedOptimizerPSO);
|
||||
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mps
|
||||
} // namespace at::native
|
||||
@ -6185,12 +6185,12 @@
|
||||
CompositeExplicitAutogradNonFunctional: _nested_view_from_buffer_copy
|
||||
autogen: _nested_view_from_buffer_copy.out
|
||||
|
||||
- func: _nested_view_from_jagged(Tensor(a) self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None) -> Tensor(a)
|
||||
- func: _nested_view_from_jagged(Tensor(a) self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1) -> Tensor(a)
|
||||
variants: function
|
||||
device_check: NoCheck
|
||||
dispatch: {}
|
||||
|
||||
- func: _nested_view_from_jagged_copy(Tensor self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None) -> Tensor
|
||||
- func: _nested_view_from_jagged_copy(Tensor self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1) -> Tensor
|
||||
variants: function
|
||||
device_check: NoCheck
|
||||
tags: view_copy
|
||||
@ -6227,16 +6227,6 @@
|
||||
device_check: NoCheck
|
||||
dispatch: {}
|
||||
|
||||
- func: _nested_get_min_seqlen(Tensor self) -> Tensor
|
||||
variants: function
|
||||
device_check: NoCheck
|
||||
dispatch: {}
|
||||
|
||||
- func: _nested_get_max_seqlen(Tensor self) -> Tensor
|
||||
variants: function
|
||||
device_check: NoCheck
|
||||
dispatch: {}
|
||||
|
||||
- func: _nested_get_jagged_dummy(Tensor any) -> Tensor
|
||||
category_override: dummy
|
||||
dispatch: {}
|
||||
@ -13797,10 +13787,16 @@
|
||||
- func: linalg_lu_factor(Tensor A, *, bool pivot=True) -> (Tensor LU, Tensor pivots)
|
||||
python_module: linalg
|
||||
variants: function
|
||||
dispatch:
|
||||
CompositeImplicitAutograd: linalg_lu_factor
|
||||
MPS: linalg_lu_factor_mps
|
||||
|
||||
- func: linalg_lu_factor.out(Tensor A, *, bool pivot=True, Tensor(a!) LU, Tensor(b!) pivots) -> (Tensor(a!) LU, Tensor(b!) pivots)
|
||||
python_module: linalg
|
||||
variants: function
|
||||
dispatch:
|
||||
CompositeImplicitAutograd: linalg_lu_factor_out
|
||||
MPS: linalg_lu_factor_out_mps
|
||||
|
||||
- func: linalg_lu_factor_ex(Tensor A, *, bool pivot=True, bool check_errors=False) -> (Tensor LU, Tensor pivots, Tensor info)
|
||||
python_module: linalg
|
||||
@ -15575,6 +15571,7 @@
|
||||
dispatch:
|
||||
CPU: _fused_adam_kernel_cpu_
|
||||
CUDA: _fused_adam_kernel_cuda_
|
||||
MPS: _fused_adam_kernel_mps_
|
||||
autogen: _fused_adam, _fused_adam.out
|
||||
|
||||
- func: _fused_adam_.tensor_lr(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()
|
||||
@ -15593,6 +15590,7 @@
|
||||
dispatch:
|
||||
CPU: _fused_adamw_kernel_cpu_
|
||||
CUDA: _fused_adamw_kernel_cuda_
|
||||
MPS: _fused_adamw_kernel_mps_
|
||||
autogen: _fused_adamw, _fused_adamw.out
|
||||
|
||||
- func: _fused_adamw_.tensor_lr(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()
|
||||
|
||||
@ -264,7 +264,7 @@ def generate_experiment_configs(calculate_bwd: bool) -> List[ExperimentConfig]:
|
||||
batch_sizes = [2, 8, 16]
|
||||
num_heads = [16]
|
||||
q_kv_seq_lens = [(512, 512), (1024, 1024), (4096, 4096)]
|
||||
head_dims = [64, 128, 256]
|
||||
head_dims = [64, 128]
|
||||
dtypes = [
|
||||
torch.bfloat16,
|
||||
]
|
||||
@ -302,8 +302,6 @@ def main(dynamic: bool, calculate_bwd: bool):
|
||||
results.append(
|
||||
Experiment(config, run_single_experiment(config, dynamic=dynamic))
|
||||
)
|
||||
for config in tqdm(generate_experiment_configs(calculate_bwd)):
|
||||
results.append(Experiment(config, run_single_experiment(config)))
|
||||
|
||||
print_results(results)
|
||||
|
||||
|
||||
@ -4,6 +4,7 @@
|
||||
#include <cstdint>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include <c10/core/Device.h>
|
||||
|
||||
@ -958,6 +958,10 @@ class DeviceCachingAllocator {
|
||||
}
|
||||
}
|
||||
|
||||
void recordAnnotation(const std::shared_ptr<GatheredContext>& name) {
|
||||
record_trace(TraceEntry::USER_DEFINED, 0, 0, nullptr, 0, name);
|
||||
}
|
||||
|
||||
bool isHistoryEnabled() {
|
||||
return record_history;
|
||||
}
|
||||
@ -3026,6 +3030,12 @@ class NativeCachingAllocator : public CUDAAllocator {
|
||||
}
|
||||
}
|
||||
|
||||
void recordAnnotation(const std::shared_ptr<GatheredContext>& name) override {
|
||||
for (auto& allocator : device_allocator) {
|
||||
allocator->recordAnnotation(name);
|
||||
}
|
||||
}
|
||||
|
||||
bool isHistoryEnabled() override {
|
||||
c10::DeviceIndex device = 0;
|
||||
C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
|
||||
|
||||
@ -170,8 +170,9 @@ struct TraceEntry {
|
||||
SEGMENT_UNMAP, // unmap part of a segment (used with expandable segments)
|
||||
SNAPSHOT, // a call to snapshot, used to correlate memory snapshots to trace
|
||||
// events
|
||||
OOM // the allocator threw an OutOfMemoryError (addr_ is the amount of free
|
||||
// bytes reported by cuda)
|
||||
OOM, // the allocator threw an OutOfMemoryError (addr_ is the amount of free
|
||||
// bytes reported by cuda)
|
||||
USER_DEFINED // a call made from user defined API such as record_function
|
||||
};
|
||||
TraceEntry(
|
||||
Action action,
|
||||
@ -289,6 +290,7 @@ class CUDAAllocator : public Allocator {
|
||||
CreateContextFn context_recorder,
|
||||
size_t alloc_trace_max_entries,
|
||||
RecordContext when) = 0;
|
||||
virtual void recordAnnotation(const std::shared_ptr<GatheredContext>& name){};
|
||||
virtual void attachOutOfMemoryObserver(OutOfMemoryObserver observer) = 0;
|
||||
|
||||
// Attached AllocatorTraceTracker callbacks will be called while the
|
||||
@ -428,6 +430,10 @@ inline void recordHistory(
|
||||
enabled, context_recorder, alloc_trace_max_entries, when);
|
||||
}
|
||||
|
||||
inline void recordAnnotation(const std::shared_ptr<GatheredContext>& name) {
|
||||
return get()->recordAnnotation(name);
|
||||
}
|
||||
|
||||
inline bool isHistoryEnabled() {
|
||||
return get()->isHistoryEnabled();
|
||||
}
|
||||
|
||||
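For illustration, a hedged Python-level sketch of how a USER_DEFINED annotation can end up in the allocator trace. The wiring from record_function into recordAnnotation is assumed from the enum comment above, and the region name is made up:

import torch
from torch.profiler import record_function

# Enable allocator history so trace entries (including USER_DEFINED ones) are kept.
torch.cuda.memory._record_memory_history(max_entries=100000)

with record_function("forward_pass"):        # hypothetical region name
    x = torch.randn(1024, 1024, device="cuda")
    y = x @ x

# The snapshot's device_traces should now interleave allocation events with the annotation.
snap = torch.cuda.memory._snapshot()
torch.cuda.memory._dump_snapshot("mem_snapshot.pickle")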
@ -750,6 +750,9 @@ if(BUILD_LIBTORCHLESS)
|
||||
find_library(TORCH_XPU_LIB torch_xpu PATHS $ENV{LIBTORCH_LIB_PATH} NO_DEFAULT_PATH)
|
||||
endif()
|
||||
add_subdirectory(../torch torch)
|
||||
# ---[ Torch python bindings build
|
||||
set(TORCH_PYTHON_COMPILE_OPTIONS ${TORCH_PYTHON_COMPILE_OPTIONS} PARENT_SCOPE)
|
||||
set(TORCH_PYTHON_LINK_FLAGS ${TORCH_PYTHON_LINK_FLAGS} PARENT_SCOPE)
|
||||
else()
|
||||
set(TORCH_LIB torch)
|
||||
set(TORCH_CPU_LIB torch_cpu)
|
||||
@ -1270,12 +1273,10 @@ install(FILES
|
||||
${PROJECT_BINARY_DIR}/TorchConfig.cmake
|
||||
DESTINATION share/cmake/Torch)
|
||||
|
||||
|
||||
# ---[ Torch python bindings build
|
||||
add_subdirectory(../torch torch)
|
||||
set(TORCH_PYTHON_COMPILE_OPTIONS ${TORCH_PYTHON_COMPILE_OPTIONS} PARENT_SCOPE)
|
||||
set(TORCH_PYTHON_LINK_FLAGS ${TORCH_PYTHON_LINK_FLAGS} PARENT_SCOPE)
|
||||
|
||||
# ==========================================================
|
||||
# END formerly-libtorch flags
|
||||
# ==========================================================
|
||||
|
||||
@ -35,3 +35,6 @@ torch.utils.checkpoint
.. autofunction:: checkpoint
.. autofunction:: checkpoint_sequential
.. autofunction:: set_checkpoint_debug_enabled
.. autoclass:: CheckpointPolicy
.. autoclass:: SelectiveCheckpointContext
.. autofunction:: create_selective_checkpoint_contexts
@ -2713,6 +2713,7 @@ coverage_ignore_classes = [
"GuardOnDataDependentSymNode",
"PendingUnbackedSymbolNotFound",
"LoggingShapeGuardPrinter",
"SymExprPrinter",
"RelaxedUnspecConstraint",
"RuntimeAssert",
"ShapeGuardPrinter",
@ -719,3 +719,5 @@ API Reference
|
||||
:members:
|
||||
|
||||
.. automodule:: torch.export.custom_obj
|
||||
|
||||
.. automodule:: torch.export.experimental
|
||||
|
||||
39  docs/source/mps_environment_variables.rst  (new file)
@ -0,0 +1,39 @@
.. _mps_environment_variables:

MPS Environment Variables
==========================

**PyTorch Environment Variables**

.. list-table::
   :header-rows: 1

   * - Variable
     - Description
   * - ``PYTORCH_DEBUG_MPS_ALLOCATOR``
     - If set to ``1``, sets the allocator logging level to verbose.
   * - ``PYTORCH_MPS_HIGH_WATERMARK_RATIO``
     - High watermark ratio for the MPS allocator. By default, it is set to 1.7.
   * - ``PYTORCH_MPS_LOW_WATERMARK_RATIO``
     - Low watermark ratio for the MPS allocator. By default, it is set to 1.4 if the memory is unified and to 1.0 if the memory is discrete.
   * - ``PYTORCH_MPS_PREFER_METAL``
     - If set to ``1``, force using Metal kernels instead of the MPS Graph APIs. For now this is only used for the matmul op.
   * - ``PYTORCH_ENABLE_MPS_FALLBACK``
     - If set to ``1``, fall back to the CPU when MPS does not support an operation.

.. note::

   The **high watermark ratio** is a hard limit for the total allowed allocations:

   - ``0.0``: disables the high watermark limit (may cause system failure if a system-wide OOM occurs)
   - ``1.0``: recommended maximum allocation size (i.e., ``device.recommendedMaxWorkingSetSize``)
   - ``>1.0``: allows limits beyond ``device.recommendedMaxWorkingSetSize``

   e.g., a value of 0.95 means we allocate up to 95% of the recommended maximum
   allocation size; beyond that, allocations fail with an OOM error.

   The **low watermark ratio** is a soft limit that attempts to keep allocations below the low watermark
   level via garbage collection or by committing command buffers more frequently (a.k.a., adaptive commit).
   It takes a value between 0 and ``m_high_watermark_ratio`` (0.0 disables adaptive commit and garbage collection).
   e.g., a value of 0.9 means we 'attempt' to limit allocations to 90% of the recommended maximum
   allocation size.
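A hypothetical usage sketch for the variables documented above (not part of the diff); these must be set before the MPS backend is first initialized:

    import os

    # Illustrative values only; pick settings appropriate for your machine.
    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"         # fall back to CPU for unsupported ops
    os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "1.0"  # cap at recommendedMaxWorkingSetSize

    import torch

    if torch.backends.mps.is_available():
        x = torch.randn(8, 8, device="mps")
        print((x @ x).relu().device)  # mps:0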
@ -164,10 +164,10 @@ horizontally and fused implementations as fusing vertically on top of that.
In general, the performance ordering of the 3 implementations is fused > foreach > for-loop.
So when applicable, we default to foreach over for-loop. Applicable means the foreach
implementation is available, the user has not specified any implementation-specific kwargs
(e.g., fused, foreach, differentiable), and all tensors are native and on CUDA. Note that
while fused should be even faster than foreach, the implementations are newer and we would
like to give them more bake-in time before flipping the switch everywhere. You are welcome
to try them out though!
(e.g., fused, foreach, differentiable), and all tensors are native. Note that while fused
should be even faster than foreach, the implementations are newer and we would like to give
them more bake-in time before flipping the switch everywhere. We summarize the stability status
for each implementation in the second table below; you are welcome to try them out though!

Below is a table showing the available and default implementations of each algorithm:

@ -177,7 +177,7 @@ Below is a table showing the available and default implementations of each algorithm:
    :delim: ;

    :class:`Adadelta`;foreach;yes;no
    :class:`Adagrad`;foreach;yes;no
    :class:`Adagrad`;foreach;yes;yes (cpu only)
    :class:`Adam`;foreach;yes;yes
    :class:`AdamW`;foreach;yes;yes
    :class:`SparseAdam`;for-loop;no;no
@ -188,7 +188,28 @@ Below is a table showing the available and default implementations of each algorithm:
    :class:`RAdam`;foreach;yes;no
    :class:`RMSprop`;foreach;yes;no
    :class:`Rprop`;foreach;yes;no
    :class:`SGD`;foreach;yes;no
    :class:`SGD`;foreach;yes;yes (CPU and CUDA only)

The table below shows the stability status of the fused implementations:

.. csv-table::
    :header: "Algorithm", "CPU", "CUDA", "MPS"
    :widths: 25, 25, 25, 25
    :delim: ;

    :class:`Adadelta`;unsupported;unsupported;unsupported
    :class:`Adagrad`;beta;unsupported;unsupported
    :class:`Adam`;beta;stable;beta
    :class:`AdamW`;beta;stable;beta
    :class:`SparseAdam`;unsupported;unsupported;unsupported
    :class:`Adamax`;unsupported;unsupported;unsupported
    :class:`ASGD`;unsupported;unsupported;unsupported
    :class:`LBFGS`;unsupported;unsupported;unsupported
    :class:`NAdam`;unsupported;unsupported;unsupported
    :class:`RAdam`;unsupported;unsupported;unsupported
    :class:`RMSprop`;unsupported;unsupported;unsupported
    :class:`Rprop`;unsupported;unsupported;unsupported
    :class:`SGD`;beta;beta;unsupported
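To make the tables concrete, the implementation is chosen per optimizer instance via the ``foreach``/``fused`` keyword arguments; a minimal sketch follows (availability follows the tables above, so treat ``fused=True`` on your device as an assumption to verify):

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = torch.nn.Linear(16, 16, device=device)

    # Default: the foreach implementation is picked automatically when applicable.
    opt_default = torch.optim.AdamW(model.parameters(), lr=1e-3)

    # Explicitly opt into the fused implementation (beta/stable per the table above).
    opt_fused = torch.optim.AdamW(model.parameters(), lr=1e-3, fused=True)

    loss = model(torch.randn(4, 16, device=device)).sum()
    loss.backward()
    opt_fused.step()
    opt_fused.zero_grad(set_to_none=True)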

How to adjust learning rate
---------------------------

@ -21,6 +21,7 @@ If you find anything in this documentation that is missing, incorrect, or could

   threading_environment_variables
   cuda_environment_variables
   mps_environment_variables
   debugging_environment_variables
   miscellaneous_environment_variables
   logging

95  setup.py
@ -199,7 +199,6 @@
|
||||
# Builds pytorch as a wheel using libtorch.so from a seperate wheel
|
||||
|
||||
import os
|
||||
import pkgutil
|
||||
import sys
|
||||
|
||||
if sys.platform == "win32" and sys.maxsize.bit_length() == 31:
|
||||
@ -210,19 +209,6 @@ if sys.platform == "win32" and sys.maxsize.bit_length() == 31:
|
||||
|
||||
import platform
|
||||
|
||||
|
||||
def _get_package_path(package_name):
|
||||
loader = pkgutil.find_loader(package_name)
|
||||
if loader:
|
||||
# The package might be a namespace package, so get_data may fail
|
||||
try:
|
||||
file_path = loader.get_filename()
|
||||
return os.path.dirname(file_path)
|
||||
except AttributeError:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
BUILD_LIBTORCH_WHL = os.getenv("BUILD_LIBTORCH_WHL", "0") == "1"
|
||||
BUILD_PYTHON_ONLY = os.getenv("BUILD_PYTHON_ONLY", "0") == "1"
|
||||
|
||||
@ -237,6 +223,7 @@ if sys.version_info < python_min_version:
|
||||
import filecmp
|
||||
import glob
|
||||
import importlib
|
||||
import importlib.util
|
||||
import json
|
||||
import shutil
|
||||
import subprocess
|
||||
@ -253,15 +240,24 @@ from setuptools.dist import Distribution
|
||||
from tools.build_pytorch_libs import build_caffe2
|
||||
from tools.generate_torch_version import get_torch_version
|
||||
from tools.setup_helpers.cmake import CMake
|
||||
from tools.setup_helpers.env import (
|
||||
build_type,
|
||||
IS_DARWIN,
|
||||
IS_LINUX,
|
||||
IS_WINDOWS,
|
||||
LIBTORCH_PKG_NAME,
|
||||
)
|
||||
from tools.setup_helpers.env import build_type, IS_DARWIN, IS_LINUX, IS_WINDOWS
|
||||
from tools.setup_helpers.generate_linker_script import gen_linker_script
|
||||
|
||||
|
||||
def _get_package_path(package_name):
|
||||
spec = importlib.util.find_spec(package_name)
|
||||
if spec:
|
||||
# The package might be a namespace package, so get_data may fail
|
||||
try:
|
||||
loader = spec.loader
|
||||
if loader is not None:
|
||||
file_path = loader.get_filename() # type: ignore[attr-defined]
|
||||
return os.path.dirname(file_path)
|
||||
except AttributeError:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
# set up appropriate env variables
|
||||
if BUILD_LIBTORCH_WHL:
|
||||
# Set up environment variables for ONLY building libtorch.so and not libtorch_python.so
|
||||
@ -271,7 +267,7 @@ if BUILD_LIBTORCH_WHL:
|
||||
|
||||
if BUILD_PYTHON_ONLY:
|
||||
os.environ["BUILD_LIBTORCHLESS"] = "ON"
|
||||
os.environ["LIBTORCH_LIB_PATH"] = f"{_get_package_path(LIBTORCH_PKG_NAME)}/lib"
|
||||
os.environ["LIBTORCH_LIB_PATH"] = f"{_get_package_path('torch')}/lib"
|
||||
|
||||
################################################################################
|
||||
# Parameters parsed from environment
|
||||
@ -347,9 +343,12 @@ cmake_python_include_dir = sysconfig.get_path("include")
|
||||
# Version, create_version_file, and package_name
|
||||
################################################################################
|
||||
|
||||
DEFAULT_PACKAGE_NAME = LIBTORCH_PKG_NAME if BUILD_LIBTORCH_WHL else "torch"
|
||||
package_name = os.getenv("TORCH_PACKAGE_NAME", "torch")
|
||||
LIBTORCH_PKG_NAME = os.getenv("LIBTORCH_PACKAGE_NAME", "libtorch")
|
||||
if BUILD_LIBTORCH_WHL:
|
||||
package_name = LIBTORCH_PKG_NAME
|
||||
|
||||
|
||||
package_name = os.getenv("TORCH_PACKAGE_NAME", DEFAULT_PACKAGE_NAME)
|
||||
package_type = os.getenv("PACKAGE_TYPE", "wheel")
|
||||
version = get_torch_version()
|
||||
report(f"Building wheel {package_name}-{version}")
|
||||
@ -472,7 +471,6 @@ def build_deps():
|
||||
check_submodules()
|
||||
check_pydep("yaml", "pyyaml")
|
||||
build_python = not BUILD_LIBTORCH_WHL
|
||||
|
||||
build_caffe2(
|
||||
version=version,
|
||||
cmake_python_library=cmake_python_library,
|
||||
@ -1125,8 +1123,6 @@ def main():
|
||||
raise RuntimeError(
|
||||
"Conflict: 'BUILD_LIBTORCH_WHL' and 'BUILD_PYTHON_ONLY' can't both be 1. Set one to 0 and rerun."
|
||||
)
|
||||
|
||||
# the list of runtime dependencies required by this built package
|
||||
install_requires = [
|
||||
"filelock",
|
||||
"typing-extensions>=4.8.0",
|
||||
@ -1141,7 +1137,7 @@ def main():
|
||||
install_requires.append("setuptools")
|
||||
|
||||
if BUILD_PYTHON_ONLY:
|
||||
install_requires.append(LIBTORCH_PKG_NAME)
|
||||
install_requires.append(f"{LIBTORCH_PKG_NAME}=={get_torch_version()}")
|
||||
|
||||
use_prioritized_text = str(os.getenv("USE_PRIORITIZED_TEXT_FOR_LD", ""))
|
||||
if (
|
||||
@ -1190,7 +1186,6 @@ def main():
|
||||
entry_points,
|
||||
extra_install_requires,
|
||||
) = configure_extension_build()
|
||||
|
||||
install_requires += extra_install_requires
|
||||
|
||||
extras_require = {
|
||||
@ -1219,6 +1214,7 @@ def main():
|
||||
"utils/data/*.pyi",
|
||||
"utils/data/datapipes/*.pyi",
|
||||
"lib/*.pdb",
|
||||
"lib/*shm*",
|
||||
"lib/torch_shm_manager",
|
||||
"lib/*.h",
|
||||
"include/*.h",
|
||||
@ -1383,15 +1379,15 @@ def main():
|
||||
"utils/model_dump/*.mjs",
|
||||
]
|
||||
|
||||
if BUILD_PYTHON_ONLY:
|
||||
if not BUILD_LIBTORCH_WHL:
|
||||
torch_package_data.extend(
|
||||
[
|
||||
"lib/libtorch_python*",
|
||||
"lib/*shm*",
|
||||
"lib/libtorch_global_deps*",
|
||||
"lib/libtorch_python.so",
|
||||
"lib/libtorch_python.dylib",
|
||||
"lib/libtorch_python.dll",
|
||||
]
|
||||
)
|
||||
else:
|
||||
if not BUILD_PYTHON_ONLY:
|
||||
torch_package_data.extend(
|
||||
[
|
||||
"lib/*.so*",
|
||||
@ -1442,28 +1438,18 @@ def main():
|
||||
"packaged/autograd/*",
|
||||
"packaged/autograd/templates/*",
|
||||
]
|
||||
package_data = {
|
||||
"torch": torch_package_data,
|
||||
}
|
||||
|
||||
if BUILD_LIBTORCH_WHL:
|
||||
modified_packages = []
|
||||
for package in packages:
|
||||
parts = package.split(".")
|
||||
if parts[0] == "torch":
|
||||
modified_packages.append(DEFAULT_PACKAGE_NAME + package[len("torch") :])
|
||||
packages = modified_packages
|
||||
package_dir = {LIBTORCH_PKG_NAME: "torch"}
|
||||
torch_package_dir_name = LIBTORCH_PKG_NAME
|
||||
package_data = {LIBTORCH_PKG_NAME: torch_package_data}
|
||||
extensions = []
|
||||
if not BUILD_LIBTORCH_WHL:
|
||||
package_data["torchgen"] = torchgen_package_data
|
||||
package_data["caffe2"] = [
|
||||
"python/serialized_test/data/operator_test/*.zip",
|
||||
]
|
||||
else:
|
||||
torch_package_dir_name = "torch"
|
||||
package_dir = {}
|
||||
package_data = {
|
||||
"torch": torch_package_data,
|
||||
"torchgen": torchgen_package_data,
|
||||
"caffe2": [
|
||||
"python/serialized_test/data/operator_test/*.zip",
|
||||
],
|
||||
}
|
||||
# no extensions in BUILD_LIBTORCH_WHL mode
|
||||
extensions = []
|
||||
|
||||
setup(
|
||||
name=package_name,
|
||||
@ -1481,7 +1467,6 @@ def main():
|
||||
install_requires=install_requires,
|
||||
extras_require=extras_require,
|
||||
package_data=package_data,
|
||||
package_dir=package_dir,
|
||||
url="https://pytorch.org/",
|
||||
download_url="https://github.com/pytorch/pytorch/tags",
|
||||
author="PyTorch Team",
|
||||
|
||||
@ -1970,6 +1970,7 @@
|
||||
"EqualityConstraint",
|
||||
"GuardOnDataDependentSymNode",
|
||||
"LoggingShapeGuardPrinter",
|
||||
"SymExprPrinter",
|
||||
"RelaxedUnspecConstraint",
|
||||
"RuntimeAssert",
|
||||
"ShapeGuardPrinter",
|
||||
|
||||
@ -43,6 +43,7 @@ from torch.testing._internal.common_fsdp import (
|
||||
FSDPTestMultiThread,
|
||||
MLP,
|
||||
patch_post_backward,
|
||||
patch_reshard,
|
||||
patch_unshard,
|
||||
)
|
||||
from torch.testing._internal.common_utils import run_tests
|
||||
@ -372,7 +373,7 @@ class TestFullyShardCommunication(FSDPTest):
|
||||
)
|
||||
|
||||
|
||||
class TestFullyShardBackwardPrefetch(FSDPTest):
|
||||
class TestFullyShardPrefetch(FSDPTest):
|
||||
@property
|
||||
def world_size(self) -> int:
|
||||
return min(4, torch.cuda.device_count())
|
||||
@ -578,6 +579,193 @@ class TestFullyShardBackwardPrefetch(FSDPTest):
|
||||
self.assertEqual(events, expected_events)
|
||||
events.clear()
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_set_modules_to_forward_prefetch(self):
|
||||
n_layers = 4
|
||||
reshard_after_forward = True
|
||||
checkpoint_impl = "utils"
|
||||
model, _, inp = self._init_transformer(
|
||||
n_layers, reshard_after_forward, checkpoint_impl
|
||||
)
|
||||
|
||||
def set_forward_prefetch(model: Transformer, num_to_prefetch: int) -> None:
|
||||
# Use model-specific knowledge to configure forward prefetching:
|
||||
# each transformer block (layer) prefetches for the next few
|
||||
for i, layer in enumerate(model.layers):
|
||||
if i >= len(model.layers) - num_to_prefetch:
|
||||
break
|
||||
layers_to_prefetch = [
|
||||
model.layers[i + j] for j in range(1, num_to_prefetch + 1)
|
||||
]
|
||||
layer.set_modules_to_forward_prefetch(layers_to_prefetch)
|
||||
|
||||
events: List[EventType] = []
|
||||
unshard_with_record = self._get_unshard_with_record(
|
||||
FSDPParamGroup.unshard, events
|
||||
)
|
||||
reshard_with_record = self._get_reshard_with_record(
|
||||
FSDPParamGroup.reshard, events
|
||||
)
|
||||
post_backward_with_record = self._get_post_backward_with_record(
|
||||
FSDPParamGroup.post_backward, events
|
||||
)
|
||||
expected_backward_events = [
|
||||
# Default backward prefetching
|
||||
("unshard", "layers.3", TrainingState.PRE_BACKWARD),
|
||||
("unshard", "layers.2", TrainingState.PRE_BACKWARD),
|
||||
("reshard", "layers.3", TrainingState.POST_BACKWARD),
|
||||
("post_backward", "layers.3", TrainingState.POST_BACKWARD),
|
||||
("unshard", "layers.1", TrainingState.PRE_BACKWARD),
|
||||
("reshard", "layers.2", TrainingState.POST_BACKWARD),
|
||||
("post_backward", "layers.2", TrainingState.POST_BACKWARD),
|
||||
("unshard", "layers.0", TrainingState.PRE_BACKWARD),
|
||||
("reshard", "layers.1", TrainingState.POST_BACKWARD),
|
||||
("post_backward", "layers.1", TrainingState.POST_BACKWARD),
|
||||
("reshard", "layers.0", TrainingState.POST_BACKWARD),
|
||||
("post_backward", "layers.0", TrainingState.POST_BACKWARD),
|
||||
("reshard", "", TrainingState.POST_BACKWARD),
|
||||
("post_backward", "", TrainingState.POST_BACKWARD),
|
||||
]
|
||||
with patch_unshard(unshard_with_record), patch_reshard(
|
||||
reshard_with_record
|
||||
), patch_post_backward(post_backward_with_record):
|
||||
set_forward_prefetch(model, num_to_prefetch=1)
|
||||
loss = model(inp)
|
||||
expected_forward_events = [
|
||||
("unshard", "", TrainingState.FORWARD),
|
||||
# `layers.i` prefetches `layers.i+1`
|
||||
("unshard", "layers.0", TrainingState.FORWARD),
|
||||
("unshard", "layers.1", TrainingState.FORWARD),
|
||||
("reshard", "layers.0", TrainingState.FORWARD),
|
||||
("unshard", "layers.2", TrainingState.FORWARD),
|
||||
("reshard", "layers.1", TrainingState.FORWARD),
|
||||
("unshard", "layers.3", TrainingState.FORWARD),
|
||||
("reshard", "layers.2", TrainingState.FORWARD),
|
||||
("reshard", "layers.3", TrainingState.FORWARD),
|
||||
]
|
||||
self.assertEqual(events, expected_forward_events)
|
||||
events.clear()
|
||||
loss.sum().backward()
|
||||
self.assertEqual(events, expected_backward_events)
|
||||
events.clear()
|
||||
|
||||
set_forward_prefetch(model, num_to_prefetch=2)
|
||||
loss = model(inp)
|
||||
expected_forward_events = [
|
||||
("unshard", "", TrainingState.FORWARD),
|
||||
# `layers.i` prefetches `layers.i+1` and `layers.i+2`
|
||||
("unshard", "layers.0", TrainingState.FORWARD),
|
||||
("unshard", "layers.1", TrainingState.FORWARD),
|
||||
("unshard", "layers.2", TrainingState.FORWARD),
|
||||
("reshard", "layers.0", TrainingState.FORWARD),
|
||||
("unshard", "layers.3", TrainingState.FORWARD),
|
||||
("reshard", "layers.1", TrainingState.FORWARD),
|
||||
("reshard", "layers.2", TrainingState.FORWARD),
|
||||
("reshard", "layers.3", TrainingState.FORWARD),
|
||||
]
|
||||
self.assertEqual(events, expected_forward_events)
|
||||
events.clear()
|
||||
loss.sum().backward()
|
||||
self.assertEqual(events, expected_backward_events)
|
||||
events.clear()
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_set_modules_to_backward_prefetch(self):
|
||||
n_layers = 4
|
||||
reshard_after_forward = True
|
||||
checkpoint_impl = "utils"
|
||||
model, _, inp = self._init_transformer(
|
||||
n_layers, reshard_after_forward, checkpoint_impl
|
||||
)
|
||||
|
||||
def set_backward_prefetch(model: Transformer, num_to_prefetch: int) -> None:
|
||||
# Use model-specific knowledge to configure backward prefetching:
|
||||
# each transformer block (layer) prefetches for the previous few
|
||||
for i, layer in enumerate(model.layers):
|
||||
if i < num_to_prefetch:
|
||||
continue
|
||||
layers_to_prefetch = [
|
||||
model.layers[i - j] for j in range(1, num_to_prefetch + 1)
|
||||
]
|
||||
layer.set_modules_to_backward_prefetch(layers_to_prefetch)
|
||||
|
||||
events: List[EventType] = []
|
||||
unshard_with_record = self._get_unshard_with_record(
|
||||
FSDPParamGroup.unshard, events
|
||||
)
|
||||
reshard_with_record = self._get_reshard_with_record(
|
||||
FSDPParamGroup.reshard, events
|
||||
)
|
||||
post_backward_with_record = self._get_post_backward_with_record(
|
||||
FSDPParamGroup.post_backward, events
|
||||
)
|
||||
expected_forward_events = [
|
||||
# Default forward prefetching
|
||||
("unshard", "", TrainingState.FORWARD), # root
|
||||
("unshard", "layers.0", TrainingState.FORWARD),
|
||||
("reshard", "layers.0", TrainingState.FORWARD),
|
||||
("unshard", "layers.1", TrainingState.FORWARD),
|
||||
("reshard", "layers.1", TrainingState.FORWARD),
|
||||
("unshard", "layers.2", TrainingState.FORWARD),
|
||||
("reshard", "layers.2", TrainingState.FORWARD),
|
||||
("unshard", "layers.3", TrainingState.FORWARD),
|
||||
("reshard", "layers.3", TrainingState.FORWARD),
|
||||
]
|
||||
with patch_unshard(unshard_with_record), patch_reshard(
|
||||
reshard_with_record
|
||||
), patch_post_backward(post_backward_with_record):
|
||||
set_backward_prefetch(model, num_to_prefetch=1)
|
||||
loss = model(inp)
|
||||
self.assertEqual(events, expected_forward_events)
|
||||
events.clear()
|
||||
loss.sum().backward()
|
||||
expected_backward_events = [
|
||||
# Root prefetches `layers.3` per default
|
||||
("unshard", "layers.3", TrainingState.PRE_BACKWARD),
|
||||
# `layers.i` prefetches for `layers.i-1` (same as default)
|
||||
("unshard", "layers.2", TrainingState.PRE_BACKWARD),
|
||||
("reshard", "layers.3", TrainingState.POST_BACKWARD),
|
||||
("post_backward", "layers.3", TrainingState.POST_BACKWARD),
|
||||
("unshard", "layers.1", TrainingState.PRE_BACKWARD),
|
||||
("reshard", "layers.2", TrainingState.POST_BACKWARD),
|
||||
("post_backward", "layers.2", TrainingState.POST_BACKWARD),
|
||||
("unshard", "layers.0", TrainingState.PRE_BACKWARD),
|
||||
("reshard", "layers.1", TrainingState.POST_BACKWARD),
|
||||
("post_backward", "layers.1", TrainingState.POST_BACKWARD),
|
||||
("reshard", "layers.0", TrainingState.POST_BACKWARD),
|
||||
("post_backward", "layers.0", TrainingState.POST_BACKWARD),
|
||||
("reshard", "", TrainingState.POST_BACKWARD),
|
||||
("post_backward", "", TrainingState.POST_BACKWARD),
|
||||
]
|
||||
self.assertEqual(events, expected_backward_events)
|
||||
events.clear()
|
||||
|
||||
set_backward_prefetch(model, num_to_prefetch=2)
|
||||
loss = model(inp)
|
||||
self.assertEqual(events, expected_forward_events)
|
||||
events.clear()
|
||||
loss.sum().backward()
|
||||
expected_backward_events = [
|
||||
# Root prefetches `layers.3` per default
|
||||
("unshard", "layers.3", TrainingState.PRE_BACKWARD),
|
||||
# `layers.i` prefetches for `layers.i-1` and `layers.i-2`
|
||||
("unshard", "layers.2", TrainingState.PRE_BACKWARD),
|
||||
("unshard", "layers.1", TrainingState.PRE_BACKWARD),
|
||||
("reshard", "layers.3", TrainingState.POST_BACKWARD),
|
||||
("post_backward", "layers.3", TrainingState.POST_BACKWARD),
|
||||
("unshard", "layers.0", TrainingState.PRE_BACKWARD),
|
||||
("reshard", "layers.2", TrainingState.POST_BACKWARD),
|
||||
("post_backward", "layers.2", TrainingState.POST_BACKWARD),
|
||||
("reshard", "layers.1", TrainingState.POST_BACKWARD),
|
||||
("post_backward", "layers.1", TrainingState.POST_BACKWARD),
|
||||
("reshard", "layers.0", TrainingState.POST_BACKWARD),
|
||||
("post_backward", "layers.0", TrainingState.POST_BACKWARD),
|
||||
("reshard", "", TrainingState.POST_BACKWARD),
|
||||
("post_backward", "", TrainingState.POST_BACKWARD),
|
||||
]
|
||||
self.assertEqual(events, expected_backward_events)
|
||||
events.clear()
|
||||
|
||||
def _init_transformer(
|
||||
self,
|
||||
n_layers: int,
|
||||
@ -614,6 +802,21 @@ class TestFullyShardBackwardPrefetch(FSDPTest):
|
||||
|
||||
return unshard_with_record
|
||||
|
||||
def _get_reshard_with_record(
|
||||
self, orig_reshard: Callable, events: List[EventType]
|
||||
) -> Callable:
|
||||
def reshard_with_record(self, *args, **kwargs):
|
||||
nonlocal events
|
||||
if (
|
||||
self._training_state == TrainingState.FORWARD
|
||||
and not self._reshard_after_forward
|
||||
): # skip no-ops
|
||||
return
|
||||
events.append(("reshard", self._module_fqn, self._training_state))
|
||||
return orig_reshard(self, *args, **kwargs)
|
||||
|
||||
return reshard_with_record
|
||||
|
||||
def _get_post_backward_with_record(
|
||||
self, orig_post_backward: Callable, events: List[EventType]
|
||||
) -> Callable:
|
||||
|
||||
@ -1,16 +1,30 @@
|
||||
# Owner(s): ["oncall: distributed"]
|
||||
|
||||
|
||||
import contextlib
|
||||
import itertools
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
import torch._dynamo.testing
|
||||
from torch import nn
|
||||
from torch._dynamo import compiled_autograd
|
||||
|
||||
from torch.distributed._composable.fsdp import fully_shard
|
||||
from torch.distributed._composable.fsdp._fsdp_common import TrainingState
|
||||
from torch.distributed._composable.fsdp._fsdp_init import (
|
||||
_get_managed_modules,
|
||||
_get_managed_states,
|
||||
)
|
||||
from torch.distributed._composable.fsdp._fsdp_param_group import FSDPParamGroup
|
||||
from torch.distributed._tensor import init_device_mesh
|
||||
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
|
||||
from torch.testing._internal.common_fsdp import FSDPTest, MLP
|
||||
from torch.testing._internal.common_utils import run_tests
|
||||
from torch.testing._internal.distributed._tensor.common_dtensor import (
|
||||
ModelArgs,
|
||||
Transformer,
|
||||
)
|
||||
from torch.utils._triton import has_triton
|
||||
|
||||
|
||||
@ -64,6 +78,10 @@ class TestFullyShardCompileCompute(FSDPTest):
|
||||
|
||||
|
||||
class TestFullyShardCompile(FSDPTest):
|
||||
@property
|
||||
def world_size(self) -> int:
|
||||
return min(2, torch.cuda.device_count())
|
||||
|
||||
def test_dynamo_trace_use_training_state(self):
|
||||
torch._dynamo.reset()
|
||||
# Construct a dummy FSDPParamGroup, since we just want to test the `use_training_state` ctx manager.
|
||||
@ -100,6 +118,174 @@ class TestFullyShardCompile(FSDPTest):
|
||||
self.assertEqual(cnt.op_count, 1)
|
||||
self.assertEqual(len(cnt.graphs), 1)
|
||||
|
||||
@torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)
|
||||
@torch._functorch.config.patch(recompute_views=True)
|
||||
def _test_traceable_fsdp(
|
||||
self, model_init_fn, input_creation_fn, backend, fullgraph
|
||||
):
|
||||
n_iter = 10
|
||||
|
||||
def compiler_fn(compiled_autograd_backend):
|
||||
def _fn(gm):
|
||||
# fullgraph=True because graph-break in Compiled Autograd BWD graph is not supported by Traceable FSDP2 yet
|
||||
# (main difficulty comes from queue_callback not working well when BWD has graph break).
|
||||
return torch.compile(
|
||||
gm, backend=compiled_autograd_backend, fullgraph=True
|
||||
)
|
||||
|
||||
return _fn
|
||||
|
||||
def run_all_iters(model, optim, compiled_autograd_backend=None):
|
||||
torch.manual_seed(42)
|
||||
losses = []
|
||||
for i in range(n_iter):
|
||||
optim.zero_grad(set_to_none=True)
|
||||
inp = input_creation_fn()
|
||||
if compiled_autograd_backend is not None:
|
||||
maybe_compiled_autograd_ctx = compiled_autograd.enable(
|
||||
compiler_fn(compiled_autograd_backend)
|
||||
)
|
||||
else:
|
||||
maybe_compiled_autograd_ctx = contextlib.nullcontext()
|
||||
with maybe_compiled_autograd_ctx:
|
||||
out = model(inp)
|
||||
loss = out.sum()
|
||||
losses.append(loss.item())
|
||||
loss.backward()
|
||||
optim.step()
|
||||
torch.cuda.synchronize()
|
||||
return losses
|
||||
|
||||
def test_compiled():
|
||||
model, optim = model_init_fn()
|
||||
# FSDP2 does lazy init using 1st run, so run it once to init using eager mode
|
||||
run_all_iters(model, optim, 1)
|
||||
|
||||
model_compiled = torch.compile(model, backend=backend, fullgraph=True)
|
||||
res = run_all_iters(
|
||||
model_compiled, optim, compiled_autograd_backend=backend
|
||||
)
|
||||
optim.zero_grad(set_to_none=True)
|
||||
return res
|
||||
|
||||
def test_eager():
|
||||
model, optim = model_init_fn()
|
||||
# FSDP2 does lazy init using 1st run, so run it once to init using eager mode
|
||||
run_all_iters(model, optim, 1)
|
||||
|
||||
res = run_all_iters(model, optim)
|
||||
optim.zero_grad(set_to_none=True)
|
||||
return res
|
||||
|
||||
losses_compiled = test_compiled()
|
||||
losses_eager = test_eager()
|
||||
for loss_compiled, loss_eager in zip(losses_compiled, losses_eager):
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
torch.tensor(loss_compiled), torch.tensor(loss_eager), rtol=1e-3
|
||||
),
|
||||
f"{loss_compiled} vs {loss_eager}",
|
||||
)
|
||||
|
||||
def _create_simple_mlp_factory_fns(self):
|
||||
hidden_dim = 16
|
||||
|
||||
def model_init_fn():
|
||||
torch.manual_seed(0)
|
||||
fsdp_config = {}
|
||||
model = nn.Sequential(
|
||||
nn.Linear(hidden_dim, hidden_dim, device="cuda"),
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_dim, hidden_dim, device="cuda"),
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_dim, hidden_dim, device="cuda"),
|
||||
)
|
||||
fully_shard(model, reshard_after_forward=True, **fsdp_config)
|
||||
optim = torch.optim.SGD(model.parameters(), lr=1e-6)
|
||||
return model, optim
|
||||
|
||||
def input_creation_fn():
|
||||
torch.manual_seed(0)
|
||||
inp = torch.randn((2, hidden_dim), device="cuda", requires_grad=False)
|
||||
return inp
|
||||
|
||||
return model_init_fn, input_creation_fn
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_simple_mlp_fullgraph_backend_eager(self):
|
||||
self._test_traceable_fsdp(
|
||||
*self._create_simple_mlp_factory_fns(), "eager", fullgraph=True
|
||||
)
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_simple_mlp_fullgraph_backend_aot_eager(self):
|
||||
self._test_traceable_fsdp(
|
||||
*self._create_simple_mlp_factory_fns(), "aot_eager", fullgraph=True
|
||||
)
|
||||
|
||||
@unittest.expectedFailure
|
||||
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_simple_mlp_fullgraph_backend_inductor(self):
|
||||
self._test_traceable_fsdp(
|
||||
*self._create_simple_mlp_factory_fns(), "inductor", fullgraph=True
|
||||
)
|
||||
|
||||
def _create_transformer_factory_fns(self):
|
||||
hidden_dim = 16
|
||||
|
||||
def model_init_fn():
|
||||
torch.manual_seed(0)
|
||||
fsdp_config = {}
|
||||
mesh = init_device_mesh("cuda", (self.world_size,))
|
||||
model_args = ModelArgs(
|
||||
dim=hidden_dim,
|
||||
n_layers=2,
|
||||
n_heads=1,
|
||||
vocab_size=1024,
|
||||
)
|
||||
model = Transformer(model_args)
|
||||
for layer_id, mod in enumerate(model.layers):
|
||||
fully_shard(mod, mesh=mesh, reshard_after_forward=True, **fsdp_config)
|
||||
model.layers[layer_id] = mod
|
||||
model = fully_shard(
|
||||
model, mesh=mesh, reshard_after_forward=True, **fsdp_config
|
||||
)
|
||||
optim = torch.optim.SGD(model.parameters(), lr=1e-6)
|
||||
return model, optim
|
||||
|
||||
def input_creation_fn():
|
||||
torch.manual_seed(0)
|
||||
inp = torch.zeros(
|
||||
(2, hidden_dim),
|
||||
device="cuda",
|
||||
requires_grad=False,
|
||||
dtype=torch.long,
|
||||
)
|
||||
return inp
|
||||
|
||||
return model_init_fn, input_creation_fn
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_transformer_fullgraph_backend_eager(self):
|
||||
self._test_traceable_fsdp(
|
||||
*self._create_transformer_factory_fns(), "eager", fullgraph=True
|
||||
)
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_transformer_fullgraph_backend_aot_eager(self):
|
||||
self._test_traceable_fsdp(
|
||||
*self._create_transformer_factory_fns(), "aot_eager", fullgraph=True
|
||||
)
|
||||
|
||||
@unittest.expectedFailure
|
||||
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_transformer_fullgraph_backend_inductor(self):
|
||||
self._test_traceable_fsdp(
|
||||
*self._create_transformer_factory_fns(), "inductor", fullgraph=True
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
# Owner(s): ["oncall: distributed"]
|
||||
|
||||
import functools
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
@ -7,6 +8,7 @@ import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
|
||||
from torch.distributed._composable.fsdp import fully_shard
|
||||
from torch.distributed._tensor.experimental import implicit_replication
|
||||
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
|
||||
from torch.testing._internal.common_fsdp import (
|
||||
FSDPTest,
|
||||
@ -23,15 +25,6 @@ class TestFullyShardOverlap(FSDPTest):
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_fully_shard_training_overlap(self):
|
||||
class LinearWithSleep(nn.Module):
|
||||
def __init__(self, dim: int, sleep_ms: int):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.randn((dim, dim)))
|
||||
self.sleep_ms = sleep_ms
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return nn.functional.relu(Matmul.apply(x, self.weight, self.sleep_ms))
|
||||
|
||||
torch.manual_seed(42)
|
||||
|
||||
# Use non-trivial comm. time but still shorter than compute time
|
||||
@ -44,7 +37,7 @@ class TestFullyShardOverlap(FSDPTest):
|
||||
fully_shard(model, reshard_after_forward=True)
|
||||
|
||||
orig_all_gather_into_tensor = dist.all_gather_into_tensor
|
||||
orig_reduce_scatter = dist.reduce_scatter_tensor
|
||||
orig_reduce_scatter_tensor = dist.reduce_scatter_tensor
|
||||
comm_stream = torch.cuda.Stream()
|
||||
|
||||
def delay_collective():
|
||||
@ -61,7 +54,7 @@ class TestFullyShardOverlap(FSDPTest):
|
||||
|
||||
def delayed_reduce_scatter(*args, **kwargs):
|
||||
delay_collective()
|
||||
return orig_reduce_scatter(*args, **kwargs)
|
||||
return orig_reduce_scatter_tensor(*args, **kwargs)
|
||||
|
||||
inp = torch.randn((2, dim), device="cuda")
|
||||
loss = model(inp).sum() # warmup CUDA and allocator
|
||||
@ -92,6 +85,63 @@ class TestFullyShardOverlap(FSDPTest):
|
||||
)
|
||||
self.assertLessEqual(fwd_bwd_time, expected_fwd_time + expected_bwd_time)
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_fully_shard_post_optim_event_overlap(self):
|
||||
torch.manual_seed(42)
|
||||
|
||||
# Use non-trivial comm. time but still shorter than compute time
|
||||
dim, compute_sleep_ms, comm_sleep_ms = (4, 25, 10)
|
||||
# Define the model to have a high-compute linear followed by a
|
||||
# low-compute linear, where only the low-compute linear uses FSDP
|
||||
model = nn.Sequential(
|
||||
LinearWithSleep(dim, compute_sleep_ms), nn.Linear(dim, dim)
|
||||
).cuda()
|
||||
fully_shard(model[1], reshard_after_forward=False)
|
||||
optim = torch.optim.AdamW(model.parameters(), lr=1e-2)
|
||||
|
||||
orig_all_gather_into_tensor = dist.all_gather_into_tensor
|
||||
|
||||
def delayed_all_gather(*args, **kwargs):
|
||||
torch.cuda._sleep(int(comm_sleep_ms * get_cycles_per_ms()))
|
||||
return orig_all_gather_into_tensor(*args, **kwargs)
|
||||
|
||||
inp = torch.randn((2, dim), device="cuda")
|
||||
|
||||
def run_train_steps(num_iters: int, use_post_optim_event: bool):
|
||||
for _ in range(num_iters):
|
||||
optim.zero_grad()
|
||||
with patch_all_gather(delayed_all_gather):
|
||||
loss = model(inp).sum()
|
||||
loss.backward()
|
||||
with implicit_replication():
|
||||
optim.step()
|
||||
if use_post_optim_event:
|
||||
post_optim_event = torch.cuda.current_stream().record_event()
|
||||
model[1].set_post_optim_event(post_optim_event)
|
||||
|
||||
run_train_steps(1, False) # warmup CUDA and allocator
|
||||
num_iters = 5
|
||||
baseline_time = self._time_fn(
|
||||
functools.partial(run_train_steps, num_iters, False)
|
||||
)
|
||||
test_time = self._time_fn(functools.partial(run_train_steps, num_iters, True))
|
||||
|
||||
buffer_ms = 4 # CPU delays and copies
|
||||
# Baseline: FSDP all-gather is exposed since the FSDP module waits for
|
||||
# the current stream and hence the high-compute linear
|
||||
self.assertLessEqual(
|
||||
baseline_time,
|
||||
num_iters * (3 * compute_sleep_ms + comm_sleep_ms + buffer_ms),
|
||||
)
|
||||
# Test: FSDP all-gather is overlapped with the high-compute linear
|
||||
# since the FSDP module only waits for the post-optim event (except on
|
||||
# the 1st iteration when no event has been recorded)
|
||||
expected_test_time = (
|
||||
num_iters * (3 * compute_sleep_ms + buffer_ms) + comm_sleep_ms
|
||||
)
|
||||
self.assertLessEqual(test_time, expected_test_time)
|
||||
self.assertGreater(baseline_time, expected_test_time)
|
||||
|
||||
def _time_fn(self, fn: Callable):
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
@ -123,5 +173,15 @@ class Matmul(torch.autograd.Function):
|
||||
return grad_input, grad_weight, None
|
||||
|
||||
|
||||
class LinearWithSleep(nn.Module):
|
||||
def __init__(self, dim: int, sleep_ms: int):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.randn((dim, dim)))
|
||||
self.sleep_ms = sleep_ms
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return nn.functional.relu(Matmul.apply(x, self.weight, self.sleep_ms))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
import contextlib
|
||||
import copy
|
||||
import functools
|
||||
import itertools
|
||||
import unittest
|
||||
from typing import Iterable, List, Tuple, Type, Union
|
||||
|
||||
@ -337,7 +338,6 @@ class TestFullyShard1DTrainingCore(FSDPTest):
|
||||
return
|
||||
assert device_type in ("cuda", "cpu"), f"{device_type}"
|
||||
torch.manual_seed(42)
|
||||
lin_dim = 32
|
||||
vocab_size = 1024
|
||||
model_args = ModelArgs(
|
||||
n_layers=3,
|
||||
@ -494,6 +494,85 @@ class TestFullyShard1DTrainingCore(FSDPTest):
|
||||
_optim.step()
|
||||
self.assertEqual(losses[0], losses[1])
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_explicit_prefetching(self):
|
||||
torch.manual_seed(42)
|
||||
model_args = ModelArgs(n_layers=8, dropout_p=0.0)
|
||||
model = Transformer(model_args)
|
||||
ref_model = replicate(copy.deepcopy(model).cuda())
|
||||
ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2)
|
||||
for layer in itertools.chain(model.layers, [model]):
|
||||
fully_shard(layer)
|
||||
optim = torch.optim.AdamW(model.parameters(), lr=1e-2)
|
||||
|
||||
num_to_forward_prefetch = num_to_backward_prefetch = 2
|
||||
for i, layer in enumerate(model.layers):
|
||||
if i >= len(model.layers) - num_to_forward_prefetch:
|
||||
break
|
||||
layers_to_prefetch = [
|
||||
model.layers[i + j] for j in range(1, num_to_forward_prefetch + 1)
|
||||
]
|
||||
layer.set_modules_to_forward_prefetch(layers_to_prefetch)
|
||||
for i, layer in enumerate(model.layers):
|
||||
if i < num_to_backward_prefetch:
|
||||
continue
|
||||
layers_to_prefetch = [
|
||||
model.layers[i - j] for j in range(1, num_to_backward_prefetch + 1)
|
||||
]
|
||||
layer.set_modules_to_backward_prefetch(layers_to_prefetch)
|
||||
|
||||
torch.manual_seed(42 + self.rank)
|
||||
inp = torch.randint(0, model_args.vocab_size, (2, 8), device="cuda")
|
||||
for iter_idx in range(10):
|
||||
losses: List[torch.Tensor] = []
|
||||
for _model, _optim in ((ref_model, ref_optim), (model, optim)):
|
||||
_optim.zero_grad()
|
||||
losses.append(_model(inp).sum())
|
||||
losses[-1].backward()
|
||||
_optim.step()
|
||||
self.assertEqual(losses[0], losses[1])
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_post_optim_event(self):
|
||||
torch.manual_seed(42)
|
||||
model_args = ModelArgs(dropout_p=0.0)
|
||||
model = Transformer(model_args)
|
||||
ref_model = replicate(copy.deepcopy(model).cuda())
|
||||
ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2)
|
||||
for layer in itertools.chain(model.layers, [model]):
|
||||
fully_shard(layer)
|
||||
optim = torch.optim.AdamW(model.parameters(), lr=1e-2)
|
||||
|
||||
def step_post_hook(
|
||||
fsdp_module: FSDPModule, opt: torch.optim.Optimizer, args, kwargs
|
||||
) -> None:
|
||||
post_optim_event = torch.cuda.current_stream().record_event()
|
||||
fsdp_module.set_post_optim_event(post_optim_event)
|
||||
|
||||
optim.register_step_post_hook(functools.partial(step_post_hook, model))
|
||||
|
||||
torch.manual_seed(42 + self.rank)
|
||||
inp = torch.randint(0, model_args.vocab_size, (2, 8), device="cuda")
|
||||
# Track all losses and check for equality at the end to avoid a CPU
|
||||
# sync point after each iteration
|
||||
ref_losses: List[torch.Tensor] = []
|
||||
losses: List[torch.Tensor] = []
|
||||
for iter_idx in range(10):
|
||||
ref_optim.zero_grad()
|
||||
ref_losses.append(ref_model(inp).sum())
|
||||
ref_losses[-1].backward()
|
||||
ref_optim.step()
|
||||
for iter_idx in range(10):
|
||||
optim.zero_grad()
|
||||
losses.append(model(inp).sum())
|
||||
losses[-1].backward()
|
||||
optim.step()
|
||||
# Sleep after the optimizer step to allow CPU to run ahead into the
|
||||
# next iteration's forward, exercising the post-optim stream sync
|
||||
torch.cuda._sleep(int(25 * get_cycles_per_ms()))
|
||||
for ref_loss, loss in zip(ref_losses, losses):
|
||||
self.assertEqual(ref_loss, loss)
|
||||
|
||||
|
||||
class TestFullyShard1DTrainingCompose(FSDPTest):
|
||||
@property
|
||||
|
||||
@ -279,12 +279,16 @@ class ReplicateTest(MultiProcessTestCase):
|
||||
self.assertEqual(counters["inductor"]["ddp_buckets"], 3)
|
||||
return code
|
||||
|
||||
def test_bucketing_coalesced_op(self):
|
||||
torch._inductor.config._fuse_ddp_communication_passes = [
|
||||
@torch._inductor.config.patch(
|
||||
_fuse_ddp_communication_passes=[
|
||||
"fuse_ddp_with_coalesced_op",
|
||||
"schedule_comm_wait",
|
||||
]
|
||||
|
||||
)
|
||||
# todo: This pass mucks things up since Inductor thinks its inference
|
||||
# and can apply this. Should turn off these passes in compiled autograd
|
||||
@torch._inductor.config.patch(reorder_for_locality=False)
|
||||
def test_bucketing_coalesced_op(self):
|
||||
# Gradient is None
|
||||
code = self._test_bucketing()
|
||||
self.assertEqual(counters["inductor"]["ddp_buckets"], 3)
|
||||
@ -311,12 +315,16 @@ class ReplicateTest(MultiProcessTestCase):
|
||||
|
||||
fc.run(code)
|
||||
|
||||
def test_bucketing_concat_op(self):
|
||||
torch._inductor.config._fuse_ddp_communication_passes = [
|
||||
@torch._inductor.config.patch(
|
||||
_fuse_ddp_communication_passes=[
|
||||
"fuse_ddp_with_concat_op",
|
||||
"schedule_comm_wait",
|
||||
]
|
||||
|
||||
)
|
||||
# todo: This pass mucks things up since Inductor thinks its inference
|
||||
# and can apply this. Should turn off these passes in compiled autograd
|
||||
@torch._inductor.config.patch(reorder_for_locality=False)
|
||||
def test_bucketing_concat_op(self):
|
||||
# Gradient is None
|
||||
code = self._test_bucketing()
|
||||
self.assertEqual(counters["inductor"]["ddp_buckets"], 3)
|
||||
|
||||
@ -116,6 +116,9 @@ class TestCommMode(TestCase):
|
||||
|
||||
@requires_nccl()
|
||||
def test_comm_mode_with_c10d(self):
|
||||
if not torch.cuda.is_available():
|
||||
return
|
||||
|
||||
world_pg = self.world_pg
|
||||
|
||||
inp = torch.rand(2, 8, 16).cuda()
|
||||
|
||||
@ -33,7 +33,11 @@ from torch.distributed.checkpoint.state_dict import (
|
||||
set_optimizer_state_dict,
|
||||
StateDictOptions,
|
||||
)
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType
|
||||
from torch.distributed.fsdp import (
|
||||
FullyShardedDataParallel as FSDP,
|
||||
ShardingStrategy,
|
||||
StateDictType,
|
||||
)
|
||||
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
|
||||
from torch.distributed.optim import _apply_optimizer_in_backward
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
@ -70,7 +74,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
|
||||
|
||||
@property
|
||||
def world_size(self) -> int:
|
||||
return 2
|
||||
return min(4, torch.cuda.device_count())
|
||||
|
||||
def _test_save_load(
|
||||
self,
|
||||
@ -567,55 +571,71 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
|
||||
set_model_state_dict(ddp_model, get_model_state_dict(ddp_model))
|
||||
self.assertEqual(model.state_dict(), get_model_state_dict(ddp_model))
|
||||
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_broadcast_from_rank0(self) -> None:
|
||||
def inner_test(wrapper):
|
||||
model = CompositeParamModel(device=torch.device("cuda"))
|
||||
optim = torch.optim.Adam(model.parameters())
|
||||
fsdp_model = wrapper(copy.deepcopy(model))
|
||||
fsdp_optim = torch.optim.Adam(fsdp_model.parameters())
|
||||
def _test_broadcast_from_rank0(self, wrapper) -> None:
|
||||
model = CompositeParamModel(device=torch.device("cuda"))
|
||||
optim = torch.optim.Adam(model.parameters())
|
||||
fsdp_model = wrapper(copy.deepcopy(model))
|
||||
fsdp_optim = torch.optim.Adam(fsdp_model.parameters())
|
||||
|
||||
batch = torch.rand(8, 100, device="cuda")
|
||||
model(batch).sum().backward()
|
||||
optim.step()
|
||||
states, optim_states = get_state_dict(model, optim)
|
||||
batch = torch.rand(8, 100, device="cuda")
|
||||
model(batch).sum().backward()
|
||||
optim.step()
|
||||
states, optim_states = get_state_dict(model, optim)
|
||||
|
||||
fsdp_model(batch).sum().backward()
|
||||
fsdp_optim.step()
|
||||
fsdp_model(batch).sum().backward()
|
||||
fsdp_optim.step()
|
||||
|
||||
def check(equal):
|
||||
fsdp_states = get_model_state_dict(
|
||||
fsdp_model,
|
||||
options=StateDictOptions(full_state_dict=True),
|
||||
)
|
||||
fsdp_optim_states = get_optimizer_state_dict(
|
||||
fsdp_model,
|
||||
fsdp_optim,
|
||||
options=StateDictOptions(full_state_dict=True),
|
||||
)
|
||||
if equal:
|
||||
self.assertEqual(states, fsdp_states)
|
||||
self.assertEqual(optim_states, fsdp_optim_states)
|
||||
else:
|
||||
self.assertNotEqual(states, fsdp_states)
|
||||
self.assertNotEqual(optim_states, fsdp_optim_states)
|
||||
|
||||
check(equal=True)
|
||||
fsdp_model(batch).sum().backward()
|
||||
fsdp_optim.step()
|
||||
check(equal=False)
|
||||
|
||||
# Drop the states to simulate loading from rank0
|
||||
if dist.get_rank() > 0:
|
||||
load_states = {}
|
||||
load_states2 = {}
|
||||
load_optim_states = {}
|
||||
def check(equal):
|
||||
fsdp_states = get_model_state_dict(
|
||||
fsdp_model,
|
||||
options=StateDictOptions(full_state_dict=True),
|
||||
)
|
||||
fsdp_optim_states = get_optimizer_state_dict(
|
||||
fsdp_model,
|
||||
fsdp_optim,
|
||||
options=StateDictOptions(full_state_dict=True),
|
||||
)
|
||||
if equal:
|
||||
self.assertEqual(states, fsdp_states)
|
||||
self.assertEqual(optim_states, fsdp_optim_states)
|
||||
else:
|
||||
load_states = copy.deepcopy(states)
|
||||
load_states2 = copy.deepcopy(states)
|
||||
load_optim_states = copy.deepcopy(optim_states)
|
||||
self.assertNotEqual(states, fsdp_states)
|
||||
self.assertNotEqual(optim_states, fsdp_optim_states)
|
||||
|
||||
check(equal=True)
|
||||
fsdp_model(batch).sum().backward()
|
||||
fsdp_optim.step()
|
||||
check(equal=False)
|
||||
|
||||
# Drop the states to simulate loading from rank0
|
||||
if dist.get_rank() > 0:
|
||||
load_states = {}
|
||||
load_states2 = {}
|
||||
load_optim_states = {}
|
||||
else:
|
||||
load_states = copy.deepcopy(states)
|
||||
load_states2 = copy.deepcopy(states)
|
||||
load_optim_states = copy.deepcopy(optim_states)
|
||||
|
||||
set_model_state_dict(
|
||||
fsdp_model,
|
||||
model_state_dict=load_states,
|
||||
options=StateDictOptions(broadcast_from_rank0=True, full_state_dict=True),
|
||||
)
|
||||
set_optimizer_state_dict(
|
||||
fsdp_model,
|
||||
fsdp_optim,
|
||||
optim_state_dict=load_optim_states,
|
||||
options=StateDictOptions(broadcast_from_rank0=True, full_state_dict=True),
|
||||
)
|
||||
|
||||
check(equal=True)
|
||||
# Verify the `strict` flag.
|
||||
load_states = load_states2
|
||||
if load_states:
|
||||
key = next(iter(load_states.keys()))
|
||||
load_states.pop(key)
|
||||
with self.assertRaisesRegex(RuntimeError, "Missing key"):
|
||||
set_model_state_dict(
|
||||
fsdp_model,
|
||||
model_state_dict=load_states,
|
||||
@ -623,30 +643,10 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
|
||||
broadcast_from_rank0=True, full_state_dict=True
|
||||
),
|
||||
)
|
||||
set_optimizer_state_dict(
|
||||
fsdp_model,
|
||||
fsdp_optim,
|
||||
optim_state_dict=load_optim_states,
|
||||
options=StateDictOptions(
|
||||
broadcast_from_rank0=True, full_state_dict=True
|
||||
),
|
||||
)
|
||||
|
||||
check(equal=True)
|
||||
# Verify the `strict` flag.
|
||||
load_states = load_states2
|
||||
if load_states:
|
||||
key = next(iter(load_states.keys()))
|
||||
load_states.pop(key)
|
||||
with self.assertRaisesRegex(RuntimeError, "Missing key"):
|
||||
set_model_state_dict(
|
||||
fsdp_model,
|
||||
model_state_dict=load_states,
|
||||
options=StateDictOptions(
|
||||
broadcast_from_rank0=True, full_state_dict=True
|
||||
),
|
||||
)
|
||||
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_broadcast_from_rank0(self) -> None:
|
||||
device_mesh = init_device_mesh("cuda", (self.world_size,))
|
||||
self.run_subtests(
|
||||
{
|
||||
@ -655,7 +655,24 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
|
||||
functools.partial(FSDP, device_mesh=device_mesh),
|
||||
]
|
||||
},
|
||||
inner_test,
|
||||
self._test_broadcast_from_rank0,
|
||||
)
|
||||
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_broadcast_from_rank0_hsdp(self) -> None:
|
||||
device_mesh = init_device_mesh("cuda", (2, self.world_size // 2))
|
||||
self.run_subtests(
|
||||
{
|
||||
"wrapper": [
|
||||
functools.partial(
|
||||
FSDP,
|
||||
device_mesh=device_mesh,
|
||||
sharding_strategy=ShardingStrategy.HYBRID_SHARD,
|
||||
),
|
||||
]
|
||||
},
|
||||
self._test_broadcast_from_rank0,
|
||||
)
|
||||
|
||||
@with_comms
|
||||
@ -851,6 +868,33 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
|
||||
):
|
||||
get_model_state_dict(model)
|
||||
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_shared_weight(self):
|
||||
class TiedEmbeddingModel(nn.Module):
|
||||
def __init__(self, vocab_size, embedding_dim):
|
||||
super().__init__()
|
||||
self.embedding = nn.Embedding(vocab_size, embedding_dim)
|
||||
self.decoder = nn.Linear(embedding_dim, vocab_size)
|
||||
self.decoder.weight = self.embedding.weight # Tying weights
|
||||
|
||||
def forward(self, input):
|
||||
input = (input * 10).to(torch.int)
|
||||
embedded = self.embedding(input)
|
||||
output = self.decoder(embedded)
|
||||
return output
|
||||
|
||||
def init_model_optim():
|
||||
device_mesh = init_device_mesh("cuda", (self.world_size,))
|
||||
orig_model = TiedEmbeddingModel(10000, 300).to(torch.device("cuda"))
|
||||
orig_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-3)
|
||||
copy_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-3)
|
||||
dist_model = FSDP(copy.deepcopy(orig_model), device_mesh=device_mesh)
|
||||
dist_optim = torch.optim.AdamW(dist_model.parameters(), lr=1e-3)
|
||||
return orig_model, orig_optim, copy_optim, dist_model, dist_optim
|
||||
|
||||
self._test_save_load(init_model_optim)
|
||||
|
||||
|
||||
class TestNoComm(MultiProcessTestCase):
|
||||
def setUp(self) -> None:
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
# Owner(s): ["oncall: distributed"]
|
||||
|
||||
import sys
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import torch
|
||||
|
||||
@ -123,5 +124,13 @@ class TestMedatadaIndex(TestCase):
|
||||
find_state_dict_object(state_dict, MetadataIndex("st", [1]))
|
||||
|
||||
|
||||
class TestTensorProperties(TestCase):
|
||||
def test_create_from_tensor_correct_device(self):
|
||||
t = torch.randn([10, 2], device="cpu")
|
||||
t.is_pinned = MagicMock(return_value=True)
|
||||
TensorProperties.create_from_tensor(t)
|
||||
t.is_pinned.assert_called_with(device=torch.device("cpu"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
||||
@ -92,6 +92,38 @@ def get_model(
|
||||
return m, inputs, outputs
|
||||
|
||||
|
||||
class MutatingModel(nn.Module):
|
||||
def __init__(self, in_feat=10, hidden_feat=5000, out_feat=5, ctx_manager=None):
|
||||
super().__init__()
|
||||
self.ctx_manager = ctx_manager
|
||||
self.net = nn.Sequential(
|
||||
*[nn.Linear(in_feat, hidden_feat), nn.ReLU()]
|
||||
+ [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()]
|
||||
+ [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()]
|
||||
+ [nn.Linear(hidden_feat, out_feat), nn.ReLU()]
|
||||
)
|
||||
self.state = 1
|
||||
|
||||
def forward(self, inputs):
|
||||
self.state = 2
|
||||
return self.net(inputs) * self.state
|
||||
|
||||
|
||||
def get_mutating_model(
|
||||
device, bsz=20, in_feat=10, hidden_feat=5000, out_feat=5, ctx_manager=None
|
||||
):
|
||||
m = MutatingModel(
|
||||
in_feat=in_feat,
|
||||
hidden_feat=hidden_feat,
|
||||
out_feat=out_feat,
|
||||
ctx_manager=ctx_manager,
|
||||
).to(device)
|
||||
m.apply(init_weights)
|
||||
inputs = torch.rand(bsz, in_feat).to(device)
|
||||
outputs = m(inputs)
|
||||
return m, inputs, outputs
|
||||
|
||||
|
||||
class ToyInnerModel(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@ -484,6 +516,26 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
outputs = fsdp_m(inputs)
|
||||
self.assertTrue(same(correct_outputs, outputs))
|
||||
|
||||
@skip_if_lt_x_gpu(1)
|
||||
def test_fsdp_setattr(self):
|
||||
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
|
||||
# Test with basic FSDP wrapping (outer wrap around whole model)
|
||||
m, inputs, correct_outputs = get_mutating_model(f"cuda:{self.rank}")
|
||||
fsdp_m = FSDP(m, use_orig_params=True)
|
||||
prof = torch._dynamo.utils.CompileProfiler()
|
||||
fsdp_m = torch.compile(fsdp_m, backend=prof, fullgraph=False)
|
||||
outputs = fsdp_m(inputs)
|
||||
self.assertTrue(same(correct_outputs, outputs))
|
||||
FileCheck().check("Torchdynamo Profiler Report").check(
|
||||
"Graph Breaks"
|
||||
).check_not(
|
||||
"setattr(FSDPManagedNNModuleVariable(MutatingModel), state, ...)"
|
||||
).check_not(
|
||||
"setattr(FSDPManagedNNModuleVariable(FullyShardedDataParallel), _is_root, ...)"
|
||||
).run(
|
||||
prof.report()
|
||||
)
|
||||
|
||||
@skip_if_lt_x_gpu(1)
|
||||
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
|
||||
def test_fsdp_inductor(self):
|
||||
|
||||
@ -60,8 +60,6 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
|
||||
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
|
||||
@skip_if_lt_x_gpu(2)
|
||||
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
|
||||
@patch.object(torch._inductor.config, "compile_threads", 1)
|
||||
def test_broadcast_inductor(self):
|
||||
"""
|
||||
Testing if broadcast works correctly when using inductor
|
||||
@ -94,8 +92,6 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
|
||||
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
|
||||
@skip_if_lt_x_gpu(2)
|
||||
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
|
||||
@patch.object(torch._inductor.config, "compile_threads", 1)
|
||||
def test_allreduce_inductor(self):
|
||||
"""
|
||||
This matmul/cat/allreduce is a pattern we aim to optimize.
|
||||
@ -129,8 +125,6 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
|
||||
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
|
||||
@skip_if_lt_x_gpu(2)
|
||||
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
|
||||
@patch.object(torch._inductor.config, "compile_threads", 1)
|
||||
def test_allreduce_inductor_cudagraph_trees(self):
|
||||
"""
|
||||
Tests whether cudagraph trees support all_reduce from nccl
|
||||
@ -177,8 +171,6 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
|
||||
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
|
||||
@skip_if_lt_x_gpu(2)
|
||||
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
|
||||
@patch.object(torch._inductor.config, "compile_threads", 1)
|
||||
def test_eager_allreduce_inductor_wait(self):
|
||||
def eager_func(a, b, c, d, *, tag, ranks, group_size):
|
||||
x = torch.matmul(a, b)
|
||||
@ -218,8 +210,6 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
|
||||
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
|
||||
@skip_if_lt_x_gpu(2)
|
||||
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
|
||||
@patch.object(torch._inductor.config, "compile_threads", 1)
|
||||
def test_inductor_allreduce_eager_wait(self):
|
||||
def inductor_func(a, b, c, d, *, tag, ranks, group_size):
|
||||
x = torch.matmul(a, b)
|
||||
@ -256,8 +246,6 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@patch.object(torch._inductor.config, "allow_buffer_reuse", True)
|
||||
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
|
||||
@patch.object(torch._inductor.config, "compile_threads", 1)
|
||||
def test_allreduce_input_buffer_reuse(self):
|
||||
def func(a, *, tag, ranks, group_size):
|
||||
ar = _functional_collectives.all_reduce(a, "sum", ranks, tag)
|
||||
@ -275,8 +263,6 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
|
||||
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
|
||||
@skip_if_lt_x_gpu(2)
|
||||
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
|
||||
@patch.object(torch._inductor.config, "compile_threads", 1)
|
||||
def test_permute_tensor(self):
|
||||
def func(tensor, src_dst_pairs, *, tag, ranks, group_size):
|
||||
return _functional_collectives.permute_tensor(
|
||||
@ -304,8 +290,6 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@patch.object(torch._inductor.config, "allow_buffer_reuse", True)
|
||||
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
|
||||
@patch.object(torch._inductor.config, "compile_threads", 1)
|
||||
def test_allgather_output_buffer_reuse(self):
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
@ -329,8 +313,6 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
|
||||
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
|
||||
@skip_if_lt_x_gpu(2)
|
||||
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
|
||||
@patch.object(torch._inductor.config, "compile_threads", 1)
|
||||
def test_allgather_contiguous_input(self):
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
@ -355,8 +337,6 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
|
||||
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
|
||||
@skip_if_lt_x_gpu(2)
|
||||
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
|
||||
@patch.object(torch._inductor.config, "compile_threads", 1)
|
||||
def test_allgather_into_tensor_inductor(self):
|
||||
"""
|
||||
This is matmul/cat/allreduce is a pattern we aim to optimize.
|
||||
@ -388,8 +368,6 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
|
||||
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
|
||||
@skip_if_lt_x_gpu(2)
|
||||
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
|
||||
@patch.object(torch._inductor.config, "compile_threads", 1)
|
||||
def test_reduce_scatter_tensor_inductor(self):
|
||||
def example(a, b, *, tag, ranks, group_size):
|
||||
c = torch.matmul(a, b)
|
||||
@ -418,8 +396,6 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
|
||||
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
|
||||
@patch.object(torch._inductor.config, "compile_threads", 1)
|
||||
def test_all_to_all_single_inductor(self):
|
||||
def example(
|
||||
inp,
|
||||
@ -488,8 +464,6 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
|
||||
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
|
||||
@skip_if_lt_x_gpu(2)
|
||||
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
|
||||
@patch.object(torch._inductor.config, "compile_threads", 1)
|
||||
def test_all_to_all_single_inductor_split_sizes_none(self):
|
||||
def example(inp, *, tag, ranks, group_size):
|
||||
a2a = torch.ops.c10d_functional.all_to_all_single(
|
||||
|
||||
@@ -19,7 +19,11 @@ from torch._higher_order_ops.wrap import tag_activation_checkpoint
from torch.testing._internal.common_utils import IS_WINDOWS, skipIfRocm
from torch.testing._internal.inductor_utils import HAS_CUDA
from torch.testing._internal.two_tensor import TwoTensor
from torch.utils.checkpoint import _pt2_selective_checkpoint_context_fn_gen, checkpoint
from torch.utils.checkpoint import (
checkpoint,
CheckpointPolicy,
create_selective_checkpoint_contexts,
)

requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")
requires_distributed = functools.partial(
@@ -105,8 +109,11 @@ def op_count(gm):


def _get_custom_policy(no_recompute_list=None):
def _custom_policy(mode, func, *args, **kwargs):
return func in no_recompute_list
def _custom_policy(ctx, func, *args, **kwargs):
if func in no_recompute_list:
return CheckpointPolicy.MUST_SAVE
else:
return CheckpointPolicy.PREFER_RECOMPUTE

return _custom_policy

@@ -530,7 +537,7 @@ class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase):
no_recompute_list = [
torch.ops.aten.mm.default,
]
return _pt2_selective_checkpoint_context_fn_gen(
return create_selective_checkpoint_contexts(
_get_custom_policy(no_recompute_list=no_recompute_list)
)

@@ -580,7 +587,7 @@ class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase):
no_recompute_list = [
torch.ops.aten.mm.default,
]
return _pt2_selective_checkpoint_context_fn_gen(
return create_selective_checkpoint_contexts(
_get_custom_policy(no_recompute_list=no_recompute_list)
)

@@ -650,7 +657,7 @@ class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase):

def selective_checkpointing_context_fn():
meta = {}
return _pt2_selective_checkpoint_context_fn_gen(_get_custom_policy(meta))
return create_selective_checkpoint_contexts(_get_custom_policy(meta))

def gn(x, y):
return torch.sigmoid(
@@ -698,7 +705,7 @@ class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase):
)
def test_compile_selective_checkpoint_partial_ctx_fn(self):
def selective_checkpointing_context_fn(no_recompute_list):
return _pt2_selective_checkpoint_context_fn_gen(
return create_selective_checkpoint_contexts(
_get_custom_policy(no_recompute_list=no_recompute_list)
)

@@ -751,7 +758,7 @@ class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase):
torch.ops.aten.mm.default,
torch.ops.aten.sigmoid.default,
]
return _pt2_selective_checkpoint_context_fn_gen(
return create_selective_checkpoint_contexts(
_get_custom_policy(no_recompute_list=no_recompute_list),
)

@@ -803,7 +810,7 @@ class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase):
torch.ops.aten.mm.default,
torch.ops.aten.sigmoid.default,
]
return _pt2_selective_checkpoint_context_fn_gen(
return create_selective_checkpoint_contexts(
_get_custom_policy(no_recompute_list=no_recompute_list)
)

@@ -854,7 +861,7 @@ class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase):
no_recompute_list = [
torch.ops.aten.sigmoid.default,
]
return _pt2_selective_checkpoint_context_fn_gen(
return create_selective_checkpoint_contexts(
_get_custom_policy(no_recompute_list=no_recompute_list)
)


@@ -2746,26 +2746,6 @@ class FuncTorchHigherOrderOpTests(torch._dynamo.test_case.TestCase):
wrapped_gm = backend.graphs[graph_idx]
return wrapped_gm

def test_hessian_graph_break(self):
counters.clear()

def wrapper_fn(x):
return torch.func.hessian(torch.sin)(x)

x = torch.randn(4, 3)
expected = wrapper_fn(x)
got = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x)
self.assertEqual(expected, got)
self.assertEqual(len(counters["graph_break"]), 2)
self.assertEqual(
{
"'skip function disable in file _dynamo/decorators.py'": 1,
"call torch._dynamo.disable() wrapped function <function jacfwd.<locals>.wrapper_fn at 0xN>": 1,
},
{munge_exc(k): v for k, v in counters["graph_break"].items()},
)

@unittest.expectedFailure
def test_hessian(self):
counters.clear()

@@ -2900,7 +2880,6 @@ class GraphModule(torch.nn.Module):
""",
)

@unittest.expectedFailure
def test_hessian_argnums(self):
counters.clear()

@@ -3046,7 +3025,6 @@ class GraphModule(torch.nn.Module):
""" return (unflatten,)""",
)

@unittest.expectedFailure
def test_hessian_disable_capture(self):
counters.clear()

@@ -3073,26 +3051,6 @@ class GraphModule(torch.nn.Module):
)
self.assertEqual(actual, expected)

def test_jacrev_graph_break(self):
counters.clear()

def wrapper_fn(x):
return torch.func.jacrev(torch.sin)(x)

x = torch.randn(4, 3)
expected = wrapper_fn(x)
got = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x)
self.assertEqual(expected, got)
self.assertEqual(len(counters["graph_break"]), 2)
self.assertEqual(
{
"'skip function disable in file _dynamo/decorators.py'": 1,
"call torch._dynamo.disable() wrapped function <function jacrev.<locals>.wrapper_fn at 0xN>": 1,
},
{munge_exc(k): v for k, v in counters["graph_break"].items()},
)

@unittest.expectedFailure
def test_jacrev(self):
counters.clear()

@@ -3169,7 +3127,6 @@ class GraphModule(torch.nn.Module):
""",
)

@unittest.expectedFailure
def test_jacrev_two_tensors_argnums(self):
counters.clear()

@@ -3252,7 +3209,6 @@ class GraphModule(torch.nn.Module):
""",
)

@unittest.expectedFailure
def test_jacrev_has_aux(self):
counters.clear()

@@ -3337,7 +3293,6 @@ class GraphModule(torch.nn.Module):
""",
)

@unittest.expectedFailure
def test_jacrev_disable_capture(self):
counters.clear()

@@ -4284,26 +4239,6 @@ class GraphModule(torch.nn.Module):
self.assertEqual(len(counters["graph_break"]), 0)
self.assertEqual(actual, expected)

def test_jacfwd_graph_break(self):
counters.clear()

def wrapper_fn(x):
return torch.func.jacfwd(torch.sin)(x)

x = torch.randn(4, 3)
expected = wrapper_fn(x)
got = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x)
self.assertEqual(expected, got)
self.assertEqual(len(counters["graph_break"]), 2)
self.assertEqual(
{
"'skip function disable in file _dynamo/decorators.py'": 1,
"call torch._dynamo.disable() wrapped function <function jacfwd.<locals>.wrapper_fn at 0xN>": 1,
},
{munge_exc(k): v for k, v in counters["graph_break"].items()},
)

@unittest.expectedFailure
def test_jacfwd(self):
counters.clear()

@@ -4387,7 +4322,6 @@ class GraphModule(torch.nn.Module):
""",
)

@unittest.expectedFailure
def test_jacfwd_two_tensors_argnums(self):
counters.clear()

@@ -4477,7 +4411,6 @@ class GraphModule(torch.nn.Module):
""",
)

@unittest.expectedFailure
def test_jacfwd_has_aux(self):
counters.clear()

@@ -4572,7 +4505,6 @@ class GraphModule(torch.nn.Module):
""",
)

@unittest.expectedFailure
def test_jacfwd_randomness(self):
counters.clear()

@@ -4676,7 +4608,6 @@ class GraphModule(torch.nn.Module):
""",
)

@unittest.expectedFailure
def test_jacfwd_disable_capture(self):
counters.clear()


@@ -47,7 +47,6 @@ from torch._dynamo.testing import (
same,
skipIfNotPy311,
unsupported,
xfailIfPy312,
)
from torch._dynamo.utils import CompileProfiler, counters, ifdynstaticdefault
from torch._inductor.utils import run_and_get_code
@@ -8289,6 +8288,72 @@ def ___make_guard_fn():
x = torch.zeros(100, dtype=torch.int64)
f(x)

def test_out_variant_custom_op(self):
with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
lib.define(
"split_with_sizes_copy(Tensor all_gather_output, SymInt[] all_gather_input_split_sizes, int dim=0, *, Tensor(a!)[] out) -> ()"
)

@torch.library.impl(lib, "split_with_sizes_copy", "Meta")
@torch.library.impl(lib, "split_with_sizes_copy", "CPU")
def split_with_sizes_copy(
all_gather_output: torch.Tensor,
all_gather_input_split_sizes: typing.List[int],
dim: int,
out: typing.List[torch.Tensor],
) -> None:
torch.split_with_sizes_copy(
all_gather_output, all_gather_input_split_sizes, dim=dim, out=out
)

@torch.compile(backend="eager", fullgraph=True)
def f1(all_gather_output, all_gather_input_split_sizes, dim, out):
return torch.ops.mylib.split_with_sizes_copy(
all_gather_output, all_gather_input_split_sizes, dim, out=out
)

all_gather_output = torch.randn(2, 272)
all_gather_input_split_sizes = [128, 8, 128, 8]
dim = 1
out = [
torch.empty(2, 128),
torch.empty(2, 8),
torch.empty(2, 128),
torch.empty(2, 8),
]
f1(all_gather_output, all_gather_input_split_sizes, dim, out)

with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
lib.define(
"chunk_cat(Tensor[] tensors, int dim, int num_chunks, *, Tensor(a!) out) -> ()"
)

@torch.library.impl(lib, "chunk_cat", "Meta")
@torch.library.impl(lib, "chunk_cat", "CPU")
def chunk_cat(
tensors: typing.List[torch.Tensor],
dim: int,
num_chunks: int,
out: torch.Tensor,
) -> None:
torch._chunk_cat(tensors, dim, num_chunks, out=out)

@torch.compile(backend="eager", fullgraph=True)
def f2(tensors, dim, num_chunks, out):
return torch.ops.mylib.chunk_cat(tensors, dim, num_chunks, out=out)

x = torch.zeros(100, dtype=torch.int64)
tensors = [
torch.randn(16, 16),
torch.randn(16),
torch.randn(16, 16),
torch.randn(16),
]
dim = 0
num_chunks = 2
out = torch.empty(2, 272)
f2(tensors, dim, num_chunks, out)

@torch._dynamo.config.patch(capture_scalar_outputs=True)
def test_runtime_assert_replacement(self):
@torch.compile(backend="aot_eager")
@@ -9946,10 +10011,6 @@ fn
lambda mod: mod,
)

# The following 2 tests fail due to https://github.com/python/cpython/issues/118013.
# Tracked by https://github.com/pytorch/pytorch/issues/124302.
# The xfails can be removed once Python 3.12 is updated on CI.
@xfailIfPy312
def test_outside_linear_module_free(self):
# Compared to test_linear_module_free, the linear
# layer is not the code object that is directly compiled.
@@ -9984,7 +10045,6 @@ fn
gc.collect()
self.assertTrue(cleared)

@xfailIfPy312
def test_parameter_free(self):
def model_inp_ctr():
param = torch.nn.Parameter(torch.randn(100, 100))

@@ -4781,6 +4781,9 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
res = opt_fn(x_weak, y)
self.assertEqual(ref, res)

@torch._functorch.config.patch(
recompute_views=True,
)
def test_storage_resize_forward_full_graph(self):
class TestModule(torch.nn.Module):
def __init__(self):
@@ -4839,8 +4842,7 @@ def forward(self, primals_1, primals_2):
_foreach_copy = torch.ops.aten._foreach_copy.default([primals_1], [primals_2]); primals_1 = primals_2 = None
getitem = _foreach_copy[0]; _foreach_copy = None
mm = torch.ops.aten.mm.default(getitem, getitem)
t_1 = torch.ops.aten.t.default(getitem); getitem = None
return [mm, t_1]""",
return [mm, getitem]""",
)
self.assertEqual(out_ref, out_test)


@@ -334,6 +334,41 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
res = fn(input)
self.assertIsInstance(res, BadNewTorchFunction)

def test_no_torch_function_recompiles(self):
class NJT:
def __repr__(self):
return f"NJT(shape={self.shape})"

def __init__(self, values, offsets):
self._values = values
self._offsets = offsets

def sin(self):
return torch.sin(self)

@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
if func == torch.sin:
self = args[0]
return NJT(func(self._values), self._offsets)
raise AssertionError("should not get here")

values1 = torch.randn(10, 3, 4, requires_grad=True)
values2 = torch.randn(10, 3, 4, requires_grad=True)
offsets = torch.tensor([0, 3, 10])
njt1 = NJT(values1, offsets)
njt2 = NJT(values2, offsets)

@torch.compile(backend="eager", fullgraph=True)
def f(x):
return torch.sin(x)

with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
f(njt1)
f(njt2)

def test_base_torch_function_tracing(self):
def fn(x):
return torch.add(x, 1)
@@ -1616,15 +1651,15 @@ Eq(s10, s8)""",
guard_str,
"""\
Eq(s3 - 1, s0)
Eq(zf1, zf6)""",
Eq(zf1, zf4)""",
)
else:
self.assertExpectedInline(
guard_str,
"""\
Eq(s4 - 1, s1)
Eq(s12 - 1, s7)
Eq(s11, s9)""",
Eq(s10 - 1, s5)
Eq(s9, s7)""",
)
return gm


@@ -446,8 +446,6 @@ aten::_nested_from_padded_and_nested_example
aten::_nested_from_padded_and_nested_example.out
aten::_nested_get_jagged_dummy
aten::_nested_get_lengths
aten::_nested_get_max_seqlen
aten::_nested_get_min_seqlen
aten::_nested_get_offsets
aten::_nested_get_ragged_idx
aten::_nested_get_values

@@ -111,13 +111,102 @@ class TestConverter(TestCase):

def test_aten_len(self):
class Module(torch.nn.Module):
def forward(self, x):
def forward(self, x: torch.Tensor):
length = len(x)
return torch.ones(length)

# aten::len.Tensor
inp = (torch.ones(2, 3),)
self._check_equal_ts_ep_converter(Module(), inp)

class Module(torch.nn.Module):
def forward(self, x: List[int]):
length = len(x)
return torch.ones(length)

# aten::len.t
inp = ([1, 2, 3],)
self._check_equal_ts_ep_converter(Module(), inp, ["script"])

class Module(torch.nn.Module):
def forward(self, x: Dict[int, str]):
length = len(x)
return torch.ones(length)

# aten::len.Dict_int
inp = ({1: "a", 2: "b", 3: "c"},)
self._check_equal_ts_ep_converter(Module(), inp, ["script"])

class Module(torch.nn.Module):
def forward(self, x: Dict[bool, str]):
length = len(x)
return torch.ones(length)

# aten::len.Dict_bool
inp = ({True: "a", False: "b"},)
self._check_equal_ts_ep_converter(Module(), inp, ["script"])

class Module(torch.nn.Module):
def forward(self, x: Dict[float, str]):
length = len(x)
return torch.ones(length)

# aten::len.Dict_float
inp = ({1.2: "a", 3.4: "b"},)
self._check_equal_ts_ep_converter(Module(), inp, ["script"])

class Module(torch.nn.Module):
def forward(self, x: Dict[torch.Tensor, str]):
length = len(x)
return torch.ones(length)

# aten::len.Dict_Tensor
inp = ({torch.zeros(2, 3): "a", torch.ones(2, 3): "b"},)
self._check_equal_ts_ep_converter(Module(), inp, ["script"])

# aten::len.str and aten::len.Dict_str are not supported
# since torch._C._jit_flatten does not support str
# inp = ("abcdefg",)
# self._check_equal_ts_ep_converter(Module(), inp)
# inp = ({"a": 1, "b": 2},)
# self._check_equal_ts_ep_converter(Module(), inp)

def test_prim_min(self):
class Module(torch.nn.Module):
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
x_len = len(x)
y_len = len(y)

# prim::min.int
len_int = min(x_len, y_len)

# prim::min.float
len_float = int(min(x_len * 2.0, y_len * 2.0))

# prim::min.self_int
len_self_int = min([x_len, y_len])

# prim::min.self_float
len_self_float = int(min([x_len * 2.0, y_len * 2.0]))

# prim::min.float_int
len_float_int = int(min(x_len * 2.0, y_len))

# prim::min.int_float
len_int_float = int(min(x_len, y_len * 2.0))

return torch.ones(
len_int
+ len_float
+ len_self_int
+ len_self_float
+ len_float_int
+ len_int_float
)

inp = (torch.randn(10, 2), torch.randn(5))
self._check_equal_ts_ep_converter(Module(), inp)

def test_aten___getitem___list(self):
class Module(torch.nn.Module):
def forward(self, x):
@@ -659,6 +748,21 @@ class TestConverter(TestCase):
# inp = (torch.randn([2, 3, 4]),)
# self._check_equal_ts_ep_converter(func6, inp)

def test_prim_tolist(self):
class Module(torch.nn.Module):
def forward(self, x: torch.Tensor) -> List[int]:
return x.tolist()

inp = (torch.tensor([1, 2, 3]),)
self._check_equal_ts_ep_converter(Module(), inp, ["script"])

class Module(torch.nn.Module):
def forward(self, x: torch.Tensor) -> List[List[int]]:
return x.tolist()

inp = (torch.tensor([[1, 2, 3], [4, 5, 6]]),)
self._check_equal_ts_ep_converter(Module(), inp, ["script"])


if __name__ == "__main__":
run_tests()

@@ -11,6 +11,7 @@ from torch._export.wrappers import _mark_strict_experimental

from torch._functorch.aot_autograd import aot_export_module
from torch.export._trace import _convert_ts_to_export_experimental
from torch.export.experimental import _export_forward_backward

from torch.testing import FileCheck

@@ -194,6 +195,76 @@ def forward(self, arg0_1, arg1_1):
MDict, ({"0": torch.randn(4), "1": torch.randn(4)},)
)

def test_joint_basic(self) -> None:
class Module(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(3, 3)
self.loss = torch.nn.CrossEntropyLoss()

def forward(self, x):
return self.loss(
self.linear(x).softmax(dim=0), torch.tensor([1.0, 0.0, 0.0])
)

m = Module()
example_inputs = (torch.randn(3),)
m(*example_inputs)
ep = torch.export._trace._export(m, example_inputs, pre_dispatch=True)
joint_ep = _export_forward_backward(ep)
print(joint_ep)

"""
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: "f32[3, 3]", arg1_1: "f32[3]", arg2_1: "f32[3]", arg3_1: "f32[3]"):
# No stacktrace found for following nodes
view: "f32[1, 3]" = torch.ops.aten.view.default(arg3_1, [1, 3]); arg3_1 = None
t: "f32[3, 3]" = torch.ops.aten.t.default(arg0_1); arg0_1 = None
addmm: "f32[1, 3]" = torch.ops.aten.addmm.default(arg1_1, view, t); arg1_1 = t = None
view_1: "f32[3]" = torch.ops.aten.view.default(addmm, [3]); addmm = None
_softmax: "f32[3]" = torch.ops.aten._softmax.default(view_1, 0, False); view_1 = None
detach_1: "f32[3]" = torch.ops.aten.detach.default(_softmax)
clone: "f32[3]" = torch.ops.aten.clone.default(arg2_1); arg2_1 = None
detach_5: "f32[3]" = torch.ops.aten.detach.default(clone); clone = None
_log_softmax: "f32[3]" = torch.ops.aten._log_softmax.default(_softmax, 0, False); _softmax = None
detach_12: "f32[3]" = torch.ops.aten.detach.default(_log_softmax)
mul: "f32[3]" = torch.ops.aten.mul.Tensor(_log_softmax, detach_5); _log_softmax = None
sum_1: "f32[]" = torch.ops.aten.sum.default(mul); mul = None
neg: "f32[]" = torch.ops.aten.neg.default(sum_1); sum_1 = None
div: "f32[]" = torch.ops.aten.div.Scalar(neg, 1); neg = None
ones_like: "f32[]" = torch.ops.aten.ones_like.default(div, pin_memory = False, memory_format = torch.preserve_format)
div_1: "f32[]" = torch.ops.aten.div.Scalar(ones_like, 1); ones_like = None
neg_1: "f32[]" = torch.ops.aten.neg.default(div_1); div_1 = None
expand: "f32[3]" = torch.ops.aten.expand.default(neg_1, [3]); neg_1 = None
mul_1: "f32[3]" = torch.ops.aten.mul.Tensor(expand, detach_5); expand = detach_5 = None
_log_softmax_backward_data: "f32[3]" = torch.ops.aten._log_softmax_backward_data.default(mul_1, detach_12, 0, torch.float32); mul_1 = detach_12 = None
_softmax_backward_data: "f32[3]" = torch.ops.aten._softmax_backward_data.default(_log_softmax_backward_data, detach_1, 0, torch.float32); _log_softmax_backward_data = detach_1 = None
view_2: "f32[1, 3]" = torch.ops.aten.view.default(_softmax_backward_data, [1, 3]); _softmax_backward_data = None
t_1: "f32[3, 1]" = torch.ops.aten.t.default(view_2)
mm: "f32[3, 3]" = torch.ops.aten.mm.default(t_1, view); t_1 = view = None
t_2: "f32[3, 3]" = torch.ops.aten.t.default(mm); mm = None
sum_2: "f32[1, 3]" = torch.ops.aten.sum.dim_IntList(view_2, [0], True); view_2 = None
view_3: "f32[3]" = torch.ops.aten.view.default(sum_2, [3]); sum_2 = None
t_3: "f32[3, 3]" = torch.ops.aten.t.default(t_2); t_2 = None
return (div, t_3, view_3)

Graph signature: ExportGraphSignature(
input_specs=[
InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg0_1'), target='linear.weight', persistent=None),
InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg1_1'), target='linear.bias', persistent=None),
InputSpec(kind=<InputKind.CONSTANT_TENSOR: 4>, arg=TensorArgument(name='arg2_1'), target='lifted_tensor_0', persistent=None),
InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg3_1'), target=None, persistent=None)
],
output_specs=[
OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='div'), target=None),
OutputSpec(kind=<OutputKind.GRADIENT_TO_PARAMETER: 4>, arg=TensorArgument(name='t_3'), target='linear.weight'),
OutputSpec(kind=<OutputKind.GRADIENT_TO_PARAMETER: 4>, arg=TensorArgument(name='view_3'), target='linear.bias')
]
)
Range constraints: {}
"""


if __name__ == "__main__":
run_tests()

@@ -77,6 +77,7 @@ from torch.testing._internal.common_utils import (
subtest,
TEST_WITH_TORCHDYNAMO,
TestCase,
xfailIfTorchDynamo,
)

from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
@@ -2341,8 +2342,7 @@ class TestJac(VmapTearDownMixin, TestCase):
self.assertEqual(actual, expected)

# https://github.com/pytorch/pytorch/issues/127036
# it won't fail as jacrev/jacfwd were not inlined (see #128255)
# @xfailIfTorchDynamo
@xfailIfTorchDynamo
@parametrize("_preallocate_and_copy", (True, False))
def test_chunk_jacrev_chunksize_one(self, device, _preallocate_and_copy):
# With chunk_size=1, we shouldn't `vmap` and hence not be limited

@@ -1767,6 +1767,33 @@ TORCH_LIBRARY(test_autograd_cpp_node_data_dependent, m) {
out = compiled_fn(activations)
self.assertTrue(len(activations) == 0)

def test_callback_graph_break_throws_error(self):
called = [0]

def callback_final():
called[0] += 1

class MyFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
return input

@staticmethod
@torch.autograd.function.once_differentiable
def backward(ctx, grad):
torch.autograd.Variable._execution_engine.queue_callback(callback_final)
torch._dynamo.graph_break()
return grad

a = torch.rand((3, 3), requires_grad=True)
with self.assertRaisesRegex(
AssertionError,
"only supported when Compiled Autograd is enabled with fullgraph=True",
):
with compiled_autograd.enable(make_compiler_fn(fullgraph=False)):
b = MyFunc.apply(a)
b.sum().backward()

@unittest.skipIf(not HAS_CUDA, "requires cuda")
def test_cudagraphs_cpu_division(self):
from torch._dynamo.testing import reduce_to_scalar_loss
@@ -2177,7 +2204,6 @@ known_failing_tests = {
"test_autograd_multiple_views_python", # torch._dynamo.exc.Unsupported: call_function args: TensorVariable(
"test_autograd_node_isinstance", # torch._dynamo.exc.Unsupported: 'inline in skipfiles: TestCase.assertIsInstance
"test_autograd_simple_views_python", # torch._dynamo.exc.TorchRuntimeError: Failed running call_function
"test_callback_adds_callback", # torch._dynamo.exc.Unsupported: call_method UserDefinedObjectVariable
"test_callback_propagates_errors_from_device_thread", # AssertionError: "blah" does not match "call_method
"test_custom_autograd_no_early_free", # torch.autograd.gradcheck.GradcheckError: While computing batched gradients
"test_custom_function_cycle", # torch._dynamo.exc.Unsupported: call_function UserDefinedClassVariable() [] {}

@@ -235,6 +235,7 @@ if RUN_CPU:
BaseTest("test_int_div", "", test_cpu_repro.CPUReproTests()),
BaseTest("test_linear1"),
BaseTest("test_linear2"),
BaseTest("test_polar"),
BaseTest(
"test_linear_binary",
"",
@@ -255,7 +256,8 @@ if RUN_CPU:
BaseTest("test_multihead_attention", "cpu", test_cpu_repro.CPUReproTests()),
BaseTest(
"test_multi_threading",
code_string_count={"py::gil_scoped_release release;": 1},
# Two threads compile, so we expect the output code to be printed twice.
code_string_count={"py::gil_scoped_release release;": 2},
),
BaseTest("test_profiler_mark_wrapper_call"),
BaseTest(

@@ -1920,6 +1920,8 @@ class CPUReproTests(TestCase):
FileCheck().check(_target_code_check).run(code)
if _target_code_check_not:
FileCheck().check_not(_target_code_check_not).run(code)
# Verify that the output isn't empty
FileCheck().check("Output code:").run(code)

self.assertEqual(
_fn(*_inps),
@@ -1933,10 +1935,16 @@ class CPUReproTests(TestCase):
_internal_check(fn, inps, "aten.scatter_reduce_")

if "ATen parallel backend: OpenMP" in torch.__config__.parallel_info():
# Fix https://github.com/pytorch/pytorch/issues/118518
# which fails to change thread number with native thread pool
with set_num_threads(1):
_internal_check(fn, inps, _target_code_check_not="aten.scatter_reduce_")
# When running with a single thread, we expect the aten.scatter will go
# into the cpp backend codegen instead of a fallback to aten.scatter_reduce_.
# Avoid the inductor cache so we don't serve an entry compiled above.
with config.patch(
{"fx_graph_cache": False, "fx_graph_remote_cache": False}
):
_internal_check(
fn, inps, _target_code_check_not="aten.scatter_reduce_"
)

with config.patch({"cpp.dynamic_threads": True}), set_num_threads(1):
_internal_check(fn, inps, "aten.scatter_reduce_")

@@ -442,7 +442,15 @@ class TestPatternMatcher(TestCase):
.sub(8),
)

args_list = [
def check_uint4x2_mixed_mm(args, expect_mixed_mm):
torch._dynamo.reset()
counters.clear()
ref = fn(*args)
test, (code,) = run_and_get_code(torch.compile(fn), *args)
torch.testing.assert_close(ref, test)
self.assertEqual("uint4x2_mixed_mm" in code, expect_mixed_mm)

args_expect_mixed_mm = [
(
torch.randn(8, 8, device="cuda"),
torch.randint(0, 255, (4, 8), dtype=torch.uint8, device="cuda"),
@@ -454,6 +462,13 @@ class TestPatternMatcher(TestCase):
.contiguous()
.t(),
),
]

for args in args_expect_mixed_mm:
check_uint4x2_mixed_mm(args, True)

# mixed mm is only enabled when casting from a lower-bitwidth dtype to a higher one
args_expect_no_mixed_mm = [
(
torch.randn(8, 8, device="cuda"),
torch.randint(0, 255, (4, 8), dtype=torch.int32, device="cuda"),
@@ -464,13 +479,8 @@ class TestPatternMatcher(TestCase):
),
]

for args in args_list:
torch._dynamo.reset()
counters.clear()
ref = fn(*args)
test, (code,) = run_and_get_code(torch.compile(fn), *args)
torch.testing.assert_close(ref, test)
self.assertTrue("uint4x2_mixed_mm" in code)
for args in args_expect_no_mixed_mm:
check_uint4x2_mixed_mm(args, False)

@unittest.skipIf(not SM80OrLater, "need sm_80")
@inductor_config.patch(use_mixed_mm=True)

@@ -158,10 +158,10 @@ class DynamoProfilerTests(torch._inductor.test_case.TestCase):

hooks_called = {"enter": False, "exit": False}

def launch_enter_hook(*args):
def launch_enter_hook(lazy_dict):
hooks_called["enter"] = True

def launch_exit_hook(*args):
def launch_exit_hook(lazy_dict):
hooks_called["exit"] = True

CompiledKernel.launch_enter_hook = launch_enter_hook

@@ -28,7 +28,6 @@ import numpy as np
import torch

import torch._dynamo.config as dynamo_config
import torch._inductor.aoti_eager
import torch.nn as nn
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo.debug_utils import aot_graph_input_parser
@@ -40,16 +39,14 @@ from torch._dynamo.testing import (
skipIfPy312,
)
from torch._dynamo.utils import ifdynstaticdefault
from torch._inductor.aoti_eager import (
aoti_compile_with_persistent_cache,
aoti_eager_cache_dir,
load_aoti_eager_cache,
)
from torch._inductor.codegen.common import DataTypePropagation, OptimizationContext
from torch._inductor.fx_passes import pad_mm
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._inductor.utils import (
add_scheduler_init_hook,
aoti_compile_with_persistent_cache,
aoti_eager_cache_dir,
load_aoti_eager_cache,
run_and_get_code,
run_and_get_cpp_code,
run_and_get_triton_code,
@@ -772,7 +769,7 @@ class CommonTemplate:
)

@skipCUDAIf(not SM80OrLater, "Requires sm80")
def test_aoti_eager_support_out(self):
def test_eager_aoti_support_out(self):
ns = "aten"
op_name = "clamp"
dispatch_key = "CPU"
@@ -824,44 +821,7 @@ class CommonTemplate:
self.assertEqual(ref_out_tensor1, res_out_tensor1)

@skipCUDAIf(not SM80OrLater, "Requires sm80")
def test_aoti_eager_support_str(self):
ns = "aten"
op_name = "div"
dispatch_key = "CPU"
device = "cpu"
if self.device.lower() == "cuda":
dispatch_key = "CUDA"
device = "cuda"

a = torch.randn(128, dtype=torch.float, device=device)
b = torch.randn(128, dtype=torch.float, device=device)
rounding_mode_list = ["trunc", "floor"]
with _scoped_library("aten", "IMPL") as torch_compile_op_lib_impl:
# Get ref result from eager
ref_value_list = []
for rounding_mode in rounding_mode_list:
ref_value = getattr(torch.ops.aten, op_name)(
a, b, rounding_mode=rounding_mode
)
ref_value_list.append(ref_value)

register_ops_with_aoti_compile(
ns, [op_name], dispatch_key, torch_compile_op_lib_impl
)

# Invoke the pre-compiled kernel and get result.
res_value_list = []
for rounding_mode in rounding_mode_list:
res_value = getattr(torch.ops.aten, op_name)(
a, b, rounding_mode=rounding_mode
)
res_value_list.append(res_value)

for ref_value, res_value in zip(ref_value_list, res_value_list):
self.assertEqual(ref_value, res_value)

@skipCUDAIf(not SM80OrLater, "Requires sm80")
def test_aoti_eager_cache_hit(self):
def test_eager_aoti_cache_hit(self):
ns = "aten"
op_name = "abs"
dispatch_key = "CPU"
@@ -886,7 +846,7 @@ class CommonTemplate:

# Patch the aoti_compile_with_persistent_cache as None to ensure no new kernel is generated
with mock.patch(
"torch._inductor.aoti_eager.aoti_compile_with_persistent_cache", None
"torch._inductor.utils.aoti_compile_with_persistent_cache", None
):
with _scoped_library("aten", "IMPL") as torch_compile_op_lib_impl:
# Get ref result from eager
@@ -902,7 +862,7 @@ class CommonTemplate:
self.assertEqual(ref_value, res_value)

@skipCUDAIf(not SM80OrLater, "Requires sm80")
def test_aoti_eager_with_persistent_cache(self):
def test_eager_aoti_with_persistent_cache(self):
def fn(a):
return torch.abs(a)

@@ -946,7 +906,7 @@ class CommonTemplate:
self.assertTrue(kernel_lib_path in kernel_libs_abs_path)

@skipCUDAIf(not SM80OrLater, "Requires sm80")
def test_aoti_eager_with_scalar(self):
def test_eager_aoti_with_scalar(self):
namespace_name = "aten"
op_name = "add"
op_overload_name = "Tensor"
@@ -982,18 +942,18 @@ class CommonTemplate:
self.assertTrue(isinstance(op_info, dict))
self.assertTrue("meta_info" in op_info)
self.assertTrue(len(op_info["meta_info"]) == 3)
# Scalar Tensor
self.assertTrue("scalar_value" not in op_info["meta_info"][0])
self.assertTrue(op_info["meta_info"][0]["sizes"] == [])
self.assertTrue(op_info["meta_info"][0]["strides"] == [])
# Scalar Tensor
self.assertTrue("scalar_value" not in op_info["meta_info"][1])
self.assertTrue("scalar_value" not in op_info["meta_info"][0])
self.assertTrue(op_info["meta_info"][1]["sizes"] == [])
self.assertTrue(op_info["meta_info"][1]["strides"] == [])
# Scalar Tensor
self.assertTrue("scalar_value" not in op_info["meta_info"][1])
self.assertTrue(op_info["meta_info"][2]["sizes"] == [])
self.assertTrue(op_info["meta_info"][2]["strides"] == [])
# Scalar
self.assertTrue("scalar_value" in op_info["meta_info"][2])
self.assertTrue("sizes" not in op_info["meta_info"][2])
self.assertTrue("strides" not in op_info["meta_info"][2])

with _scoped_library("aten", "IMPL") as torch_compile_op_lib_impl:
a = torch.randn(128, device=device)
@@ -1016,7 +976,7 @@ class CommonTemplate:
self.assertEqual(ref_values, res_values)

@skipCUDAIf(not SM80OrLater, "Requires sm80")
def test_aoti_eager_override_registration(self):
def test_eager_aoti_override_registration(self):
namespace_name = "aten"
dispatch_key = "CPU"
device = torch.device("cpu")
@@ -4697,6 +4657,16 @@ class CommonTemplate:

self.common(fn, (x,))

def test_polar(self):
def fn(dist, angle):
return torch.polar(dist, angle)

inp = (
torch.tensor([1, 2], dtype=torch.float64),
torch.tensor([np.pi / 2, 5 * np.pi / 4], dtype=torch.float64),
)
self.common(fn, (*inp,))

def test_cauchy(self):
def fn(x, y):
return torch.sum(1 / (torch.unsqueeze(x, -1) - y))
@@ -10167,7 +10137,8 @@ class CommonTemplate:
self.assertEqual(rot.grad, rot_e.grad)
self.assertEqual(trans.grad, trans_e.grad)

@config.patch({"fx_graph_cache": False})
# If we serve from the cache, the init hook isn't called
@config.patch({"fx_graph_cache": False, "fx_graph_remote_cache": False})
def test_inner_fn_str_and_stride(self):
def f(x):
x = x + 1

@@ -237,6 +237,7 @@ test_failures = {
"test_pointwise_hermite_polynomial_he_dynamic_shapes": TestFailure(("cuda", "xpu")),
"test_pointwise_laguerre_polynomial_l_dynamic_shapes": TestFailure(("cuda", "xpu")),
"test_pointwise_legendre_polynomial_p_dynamic_shapes": TestFailure(("cuda", "xpu")),
"test_polar_dynamic_shapes": TestFailure(("cpu", "cuda"), is_skip=True),
"test_randn_generator_dynamic_shapes": TestFailure(("cpu",)),
"test_randn_like_empty_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
"test_single_elem_dynamic_shapes": TestFailure(("cpu",)),

@@ -411,7 +411,6 @@ inductor_one_sample = {
"_segment_reduce.lengths": {f16},
"_segment_reduce.offsets": {f16},
"addmv": {f16},
"argsort": {b8, f16, f32, f64, i32, i64},
"as_strided.partial_views": {f16},
"corrcoef": {f16},
"diff": {f16},
@@ -426,11 +425,7 @@ inductor_one_sample = {
"logspace": {f16},
"logspace.tensor_overload": {f16, f32, f64, i32, i64},
"masked_logsumexp": {i64},
"max.binary": {b8},
"max_pool2d_with_indices_backward": {f16, f32, f64},
"maximum": {b8},
"min.binary": {b8},
"minimum": {b8},
"new_empty_strided": {f16},
"nn.functional.adaptive_avg_pool3d": {f16},
"nn.functional.adaptive_max_pool1d": {f16, f32},

@@ -14,6 +14,7 @@ from torch._inductor.utils import run_and_get_code
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
skipIfXpu,
)
from torch.testing._internal.inductor_utils import (
GPU_TYPE,
@@ -214,6 +215,7 @@ class TritonBlockPointerTest(InductorTestCase):
# Expect 3 block pointers: 2 inputs one output
self.run_and_compare(foo, x, y, expected_num_block_pointers=3)

@skipIfXpu
@parametrize(
"view_size,num_block_pointers,num_triton_kernels",
[

@@ -911,51 +911,6 @@ class TestTracer(JitTestCase):
self.assertEqual(len(list(g.inputs())), 2)
FileCheck().check("mul").check("add").run(str(g))

def test_trace_c10_ops(self):
try:
_ = torch.ops._caffe2.GenerateProposals
except AttributeError:
self.skipTest("Skip the test since c2 ops are not registered.")

class MyModel(torch.nn.Module):
def forward(self, scores, bbox_deltas, im_info, anchors):
a, b = torch.ops._caffe2.GenerateProposals(
(scores),
(bbox_deltas),
(im_info),
(anchors),
2.0,
6000,
300,
0.7,
16,
True,
-90,
90,
1.0,
True,
)
return a, b

model = MyModel()
A = 4
H = 10
W = 8
img_count = 3
scores = torch.ones(img_count, A, H, W, dtype=torch.float32)
bbox_deltas = torch.linspace(
0, 10, steps=img_count * 4 * A * H * W, dtype=torch.float32
)
bbox_deltas = bbox_deltas.view(img_count, 4 * A, H, W)
im_info = torch.ones(img_count, 3, dtype=torch.float32)
anchors = torch.ones(A, 4, dtype=torch.float32)
inputs = (scores, bbox_deltas, im_info, anchors)
traced_model = torch.jit.trace(model, inputs)
self.assertEqual(traced_model(*inputs), model(*inputs))
self.assertExportImportModule(
traced_model, (scores, bbox_deltas, im_info, anchors)
)

def run_ge_tests(self, optimize, use_cuda):
with enable_profiling_mode_for_profiling_tests():
with torch.jit.optimized_execution(optimize):

@@ -340,8 +340,8 @@ def xfail(error_message: str, reason: Optional[str] = None):


# skips tests for opset_versions listed in unsupported_opset_versions.
# if the caffe2 test cannot be run for a specific version, add this wrapper
# (for example, an op was modified but the change is not supported in caffe2)
# if the PyTorch test cannot be run for a specific version, add this wrapper
# (for example, an op was modified but the change is not supported in PyTorch)
def skipIfUnsupportedOpsetVersion(unsupported_opset_versions):
def skip_dec(func):
@functools.wraps(func)

@@ -873,33 +873,6 @@ class TestOperators(common_utils.TestCase):
x = torch.randn(2, 3, 4, requires_grad=True)
self.assertONNX(lambda x: torch.cumsum(x, dim=1), x, opset_version=11)

# Github Issue: https://github.com/pytorch/pytorch/issues/71095
# def test_c2_op(self):
#     class MyModel(torch.nn.Module):
#         def __init__(self):
#             super().__init__()
#
#         def forward(self, scores, bbox_deltas, im_info, anchors):
#             a, b = torch.ops._caffe2.GenerateProposals(
#                 (scores), (bbox_deltas), (im_info), (anchors),
#                 2.0, 6000, 300, 0.7, 16, True, -90, 90, 1.0, True,
#             )
#             return a, b
#
#     model = MyModel()
#     A = 4
#     H = 10
#     W = 8
#     img_count = 3
#     scores = torch.ones(img_count, A, H, W, dtype=torch.float32)
#     bbox_deltas = torch.linspace(0, 10, steps=img_count * 4 * A * H * W,
#                                  dtype=torch.float32)
#     bbox_deltas = bbox_deltas.view(img_count, 4 * A, H, W)
#     im_info = torch.ones(img_count, 3, dtype=torch.float32)
#     anchors = torch.ones(A, 4, dtype=torch.float32)
#     inputs = (scores, bbox_deltas, im_info, anchors)
#     self.assertONNX(model, inputs, custom_opsets={"org.pytorch._caffe2": 0})

def test_dict(self):
class MyModel(torch.nn.Module):
def forward(self, x_in):

@@ -1358,6 +1358,8 @@ class TestUtilityFuns(_BaseTestCase):
iter = graph.nodes()
self.assertEqual(next(iter).kind(), "custom_namespace::custom_op")

# gelu is exported as onnx::Gelu for opset >= 20
@skipIfUnsupportedMaxOpsetVersion(19)
def test_custom_opsets_gelu(self):
self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "::gelu", 9)

@@ -1382,6 +1384,8 @@ class TestUtilityFuns(_BaseTestCase):
self.assertEqual(graph.opset_import[1].domain, "com.microsoft")
self.assertEqual(graph.opset_import[1].version, 1)

# gelu is exported as onnx::Gelu for opset >= 20
@skipIfUnsupportedMaxOpsetVersion(19)
def test_register_aten_custom_op_symbolic(self):
self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "aten::gelu", 9)


Some files were not shown because too many files have changed in this diff.