Compare commits


1 Commit

Author SHA1 Message Date
fe3ee2e446 Revert "[RELAND] Always build USE_DISTRIBUTED (#160449) and Make distributed modules importable even when backend not built (#159889) (#162594)"
This reverts commit 09cb34c1dce8fe1b880bbf3115d8ddad3401d871.
2025-09-24 18:18:00 -07:00
296 changed files with 2719 additions and 6623 deletions

View File

@ -69,8 +69,7 @@ RUN bash ./install_cuda.sh 13.0
ENV DESIRED_CUDA=13.0
FROM ${ROCM_IMAGE} as rocm
ARG PYTORCH_ROCM_ARCH
ENV PYTORCH_ROCM_ARCH ${PYTORCH_ROCM_ARCH}
ENV PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201"
ADD ./common/install_mkl.sh install_mkl.sh
RUN bash ./install_mkl.sh && rm install_mkl.sh
ENV MKLROOT /opt/intel

View File

@ -36,12 +36,6 @@ case ${DOCKER_TAG_PREFIX} in
;;
rocm*)
BASE_TARGET=rocm
PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201"
# add gfx950 conditionally starting in ROCm 7.0
if [[ "$ROCM_VERSION" == *"7.0"* ]]; then
PYTORCH_ROCM_ARCH="${PYTORCH_ROCM_ARCH};gfx950"
fi
EXTRA_BUILD_ARGS="${EXTRA_BUILD_ARGS} --build-arg PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}"
;;
*)
echo "ERROR: Unknown docker tag ${DOCKER_TAG_PREFIX}"

View File

@ -40,16 +40,12 @@ case ${DOCKER_TAG_PREFIX} in
;;
rocm*)
# we want the patch version of 6.4 instead
if [[ "$GPU_ARCH_VERSION" == *"6.4"* ]]; then
if [[ $(ver $GPU_ARCH_VERSION) -eq $(ver 6.4) ]]; then
GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.2"
fi
BASE_TARGET=rocm
GPU_IMAGE=rocm/dev-ubuntu-22.04:${GPU_ARCH_VERSION}-complete
PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201"
# add gfx950 conditionally starting in ROCm 7.0
if [[ "$GPU_ARCH_VERSION" == *"7.0"* ]]; then
PYTORCH_ROCM_ARCH="${PYTORCH_ROCM_ARCH};gfx950"
fi
DOCKER_GPU_BUILD_ARG="--build-arg PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH} --build-arg ROCM_VERSION=${GPU_ARCH_VERSION}"
;;
*)
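
Note: the `ver` helper used in the rocm* case above is defined elsewhere in these CI scripts and is not part of this diff. A minimal sketch of such a helper is shown below (an assumption, not the verbatim definition); it zero-pads each dotted component so version strings can be compared numerically with -eq/-ge:

# Sketch only: pad each dotted component so versions like 6.4, 6.4.2, 7.0 compare as integers.
ver() {
  printf "%3d%03d%03d%03d" $(echo "$1" | tr '.' ' ')
}

# Matches the usage in the hunk above: only an exact "6.4" gets the ".2" patch suffix.
if [[ $(ver "$GPU_ARCH_VERSION") -eq $(ver 6.4) ]]; then
  GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.2"
fi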

View File

@ -82,7 +82,7 @@ case ${image} in
;;
manylinux2_28-builder:rocm*)
# we want the patch version of 6.4 instead
if [[ "$GPU_ARCH_VERSION" == *"6.4"* ]]; then
if [[ $(ver $GPU_ARCH_VERSION) -eq $(ver 6.4) ]]; then
GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.2"
fi
TARGET=rocm_final
@ -90,10 +90,6 @@ case ${image} in
DEVTOOLSET_VERSION="11"
GPU_IMAGE=rocm/dev-almalinux-8:${GPU_ARCH_VERSION}-complete
PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201"
# add gfx950 conditionally starting in ROCm 7.0
if [[ "$GPU_ARCH_VERSION" == *"7.0"* ]]; then
PYTORCH_ROCM_ARCH="${PYTORCH_ROCM_ARCH};gfx950"
fi
DOCKER_GPU_BUILD_ARG="--build-arg ROCM_VERSION=${GPU_ARCH_VERSION} --build-arg PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH} --build-arg DEVTOOLSET_VERSION=${DEVTOOLSET_VERSION}"
;;
manylinux2_28-builder:xpu)

View File

@ -1,11 +1,11 @@
SHELL=/usr/bin/env bash
DOCKER_CMD ?= docker
DESIRED_ROCM ?= 7.0
DESIRED_ROCM ?= 6.4
DESIRED_ROCM_SHORT = $(subst .,,$(DESIRED_ROCM))
PACKAGE_NAME = magma-rocm
# inherit this from underlying docker image, do not pass this env var to docker
#PYTORCH_ROCM_ARCH ?= gfx900;gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201
#PYTORCH_ROCM_ARCH ?= gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201
DOCKER_RUN = set -eou pipefail; ${DOCKER_CMD} run --rm -i \
-v $(shell git rev-parse --show-toplevel)/.ci:/builder \
@ -16,7 +16,6 @@ DOCKER_RUN = set -eou pipefail; ${DOCKER_CMD} run --rm -i \
magma-rocm/build_magma.sh
.PHONY: all
all: magma-rocm70
all: magma-rocm64
all: magma-rocm63
@ -25,11 +24,6 @@ clean:
$(RM) -r magma-*
$(RM) -r output
.PHONY: magma-rocm70
magma-rocm70: DESIRED_ROCM := 7.0
magma-rocm70:
$(DOCKER_RUN)
.PHONY: magma-rocm64
magma-rocm64: DESIRED_ROCM := 6.4
magma-rocm64:
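
A usage sketch for the Makefile above (assumptions: run from the directory containing it, with Docker available); each magma-rocmXY target pins DESIRED_ROCM and runs magma-rocm/build_magma.sh inside a container via DOCKER_RUN:

# Build MAGMA for ROCm 6.4 only; DOCKER_CMD can be overridden (e.g. to podman).
make magma-rocm64
make DOCKER_CMD=podman magma-rocm64

# Or build every ROCm variant wired into the `all` target.
make all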

View File

@ -6,8 +6,8 @@ set -eou pipefail
# The script expects DESIRED_CUDA and PACKAGE_NAME to be set
ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
# https://github.com/icl-utk-edu/magma/pull/65
MAGMA_VERSION=d6e4117bc88e73f06d26c6c2e14f064e8fc3d1ec
# Version 2.7.2 + ROCm related updates
MAGMA_VERSION=a1625ff4d9bc362906bd01f805dbbe12612953f6
# Folders for the build
PACKAGE_FILES=${ROOT_DIR}/magma-rocm/package_files # metadata
@ -20,7 +20,7 @@ mkdir -p ${PACKAGE_DIR} ${PACKAGE_OUTPUT}/linux-64 ${PACKAGE_BUILD} ${PACKAGE_RE
# Fetch magma sources and verify checksum
pushd ${PACKAGE_DIR}
git clone https://github.com/jeffdaily/magma
git clone https://bitbucket.org/icl/magma.git
pushd magma
git checkout ${MAGMA_VERSION}
popd

View File

@ -58,7 +58,7 @@ time python tools/setup_helpers/generate_code.py \
# Build the docs
pushd docs/cpp
time make VERBOSE=1 html
time make VERBOSE=1 html -j
popd
popd

View File

@ -22,9 +22,6 @@ self-hosted-runner:
- linux.arm64.m7g.4xlarge
- linux.arm64.m7g.4xlarge.ephemeral
- linux.arm64.r7g.12xlarge.memory
- linux.aws.h100
- linux.aws.h100.4
- linux.aws.h100.8
- linux.4xlarge.nvidia.gpu
- linux.8xlarge.nvidia.gpu
- linux.16xlarge.nvidia.gpu

View File

@ -1 +1 @@
9fe4c2bdb9859c14ad7f7479e1db7e01083bada3
1983609239caaab24ab1ed2bfa2aa92e8c76c1b1

View File

@ -67,7 +67,7 @@ jobs:
# an OOM issue when running the job, so this upgrades the runner from 4xlarge
# to the next available tier of 12xlarge. So much memory just to generate cpp
# doc
runner: ${{ inputs.runner_prefix }}linux.12xlarge.memory
runner: ${{ inputs.runner_prefix }}linux.12xlarge
# TODO: Nightly cpp docs take longer and longer to finish (more than 3h now)
# Let's try to figure out how this can be improved
timeout-minutes: 360

View File

@ -36,7 +36,7 @@ jobs:
runs-on: linux.9xlarge.ephemeral
strategy:
matrix:
tag: ["cuda12.6", "cuda12.8", "cuda12.9", "cuda13.0", "rocm6.3", "rocm6.4", "rocm7.0", "cpu"]
tag: ["cuda12.6", "cuda12.8", "cuda12.9", "cuda13.0", "rocm6.3", "rocm6.4", "cpu"]
steps:
- name: Build docker image
uses: pytorch/pytorch/.github/actions/binary-docker-build@main

View File

@ -54,7 +54,6 @@ jobs:
{ tag: "cuda12.6" },
{ tag: "rocm6.3" },
{ tag: "rocm6.4" },
{ tag: "rocm7.0" },
{ tag: "cpu" },
]
steps:

View File

@ -34,7 +34,7 @@ jobs:
id-token: write
strategy:
matrix:
rocm_version: ["70", "64", "63"]
rocm_version: ["64", "63"]
steps:
- name: Checkout PyTorch
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

View File

@ -54,7 +54,6 @@ jobs:
{ name: "manylinuxaarch64-builder", tag: "cuda12.6", runner: "linux.arm64.2xlarge.ephemeral" },
{ name: "manylinux2_28-builder", tag: "rocm6.3", runner: "linux.9xlarge.ephemeral" },
{ name: "manylinux2_28-builder", tag: "rocm6.4", runner: "linux.9xlarge.ephemeral" },
{ name: "manylinux2_28-builder", tag: "rocm7.0", runner: "linux.9xlarge.ephemeral" },
{ name: "manylinux2_28-builder", tag: "cpu", runner: "linux.9xlarge.ephemeral" },
{ name: "manylinux2_28_aarch64-builder", tag: "cpu-aarch64", runner: "linux.arm64.2xlarge.ephemeral" },
{ name: "manylinuxcxx11-abi-builder", tag: "cpu-cxx11-abi", runner: "linux.9xlarge.ephemeral" },

View File

@ -35,7 +35,6 @@ jobs:
contents: write
outputs:
pt_release_name: ${{ steps.release_name.outputs.pt_release_name }}
pt_pep517_release_name: ${{ steps.release_name.outputs.pt_pep517_release_name }}
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
@ -54,12 +53,8 @@ jobs:
tag_or_branch="${tag_or_branch#refs/heads/}"
# replace directory separators with _ in branch name
tag_or_branch="${tag_or_branch//\//_}"
torch_version="$(python -c 'from tools.generate_torch_version import get_torch_version; print(get_torch_version())')"
{
echo "PT_RELEASE_NAME=pytorch-$tag_or_branch";
echo "PT_RELEASE_FILE=pytorch-$tag_or_branch.tar.gz";
echo "PT_PEP517_RELEASE_FILE=torch-${torch_version}.tar.gz";
} >> "$GITHUB_ENV"
echo "PT_RELEASE_NAME=pytorch-$tag_or_branch" >> "$GITHUB_ENV"
echo "PT_RELEASE_FILE=pytorch-$tag_or_branch.tar.gz" >> "$GITHUB_ENV"
- name: Checkout optional submodules
run: python3 tools/optional_submodules.py
- name: Copy docs requirements for inclusion
@ -69,47 +64,30 @@ jobs:
cp .ci/docker/requirements-docs.txt docs/requirements.txt
- name: Create source distribution
run: |
# Create new folder with specified name so extracting the archive yields that
rm -rf "/tmp/$PT_RELEASE_NAME"
cp -r "$PWD" "/tmp/$PT_RELEASE_NAME"
mv "/tmp/$PT_RELEASE_NAME" .
# Cleanup
rm -rf "$PT_RELEASE_NAME"/{.circleci,.ci}
find "$PT_RELEASE_NAME" -name '.git*' -exec rm -rv {} \; || true
# Create archive
tar -czf "$PT_RELEASE_FILE" "$PT_RELEASE_NAME"
echo "Created source archive $PT_RELEASE_FILE with content: $(ls -a "$PT_RELEASE_NAME")"
- name: Create PEP 517 compatible source distribution
run: |
pip install build==1.2.2.post1 || exit 1
python -m build --sdist || exit 1
cd dist || exit 1
# Create new folder with specified name so extracting the archive yields that
rm -rf "/tmp/$PT_RELEASE_NAME"
cp -r "$PWD" "/tmp/$PT_RELEASE_NAME"
mv "/tmp/$PT_RELEASE_NAME" .
# Cleanup
rm -rf "$PT_RELEASE_NAME"/{.circleci,.ci}
find "$PT_RELEASE_NAME" -name '.git*' -exec rm -rv {} \; || true
# Create archive
tar -czf "$PT_RELEASE_FILE" "$PT_RELEASE_NAME"
echo "Created source archive $PT_RELEASE_FILE with content: $(ls -a "$PT_RELEASE_NAME")"
- name: Upload source distribution for release
if: ${{ github.event_name == 'release' }}
uses: softprops/action-gh-release@da05d552573ad5aba039eaac05058a918a7bf631 # v2.2.2
with:
files: |
${{ env.PT_RELEASE_FILE }}
${{ env.PT_PEP517_RELEASE_FILE }}
- name: Upload source distribution to GHA artifacts # for release tags
files: ${{env.PT_RELEASE_FILE}}
- name: Upload source distribution to GHA artifacts for release tags
if: ${{ github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v') && contains(github.ref, 'rc') }}
uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0
with:
name: ${{ env.PT_RELEASE_FILE }}
path: ${{ env.PT_RELEASE_FILE }}
- name: Upload PEP 517 source distribution to GHA artifacts # for release tags
if: ${{ github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v') && contains(github.ref, 'rc') }}
uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0
with:
name: ${{ env.PT_PEP517_RELEASE_FILE }}
path: dist/${{ env.PT_PEP517_RELEASE_FILE }}
- name: Set output
id: release_name
run: |
{
echo "pt_release_name=${{ env.PT_RELEASE_FILE }}";
echo "pt_pep517_release_name=${{ env.PT_PEP517_RELEASE_FILE }}";
} >> "${GITHUB_OUTPUT}"
run: echo "pt_release_name=${{ env.PT_RELEASE_NAME }}.tar.gz" >> "${GITHUB_OUTPUT}"
upload_source_code_to_s3:
if: ${{ github.repository == 'pytorch/pytorch' && github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v') && contains(github.ref, 'rc') }}
@ -125,9 +103,6 @@ jobs:
- uses: actions/download-artifact@65a9edc5881444af0b9093a5e628f2fe47ea3b2e # v4.1.7
with:
name: ${{ needs.release.outputs.pt_release_name }}
- uses: actions/download-artifact@65a9edc5881444af0b9093a5e628f2fe47ea3b2e # v4.1.7
with:
name: ${{ needs.release.outputs.pt_pep517_release_name }}
- name: Configure AWS credentials(PyTorch account)
uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0
with:
@ -138,9 +113,7 @@ jobs:
s3-bucket: pytorch
s3-prefix: source_code/test
if-no-files-found: warn
path: |
${{ needs.release.outputs.pt_release_name }}
${{ needs.release.outputs.pt_pep517_release_name }}
path: ${{ needs.release.outputs.pt_release_name }}
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name }}

View File

@ -127,8 +127,6 @@ jobs:
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
# More memory is needed to build with asan
runner: linux.2xlarge.memory
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-py3.10-clang18-asan
docker-image-name: ci-image:pytorch-linux-jammy-py3-clang18-asan

View File

@ -140,8 +140,6 @@ jobs:
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
# More memory is needed to build with asan
runner: linux.2xlarge.memory
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-py3.10-clang18-asan
docker-image-name: ci-image:pytorch-linux-jammy-py3-clang18-asan

.gitignore vendored
View File

@ -82,7 +82,6 @@ torch/return_types.pyi
torch/nn/functional.pyi
torch/utils/data/datapipes/datapipe.pyi
torch/csrc/autograd/generated/*
torch/csrc/functionalization/generated/*
torch/csrc/lazy/generated/*.[!m]*
torch_compile_debug/
# Listed manually because some files in this directory are not generated

View File

@ -1453,7 +1453,7 @@ init_command = [
'--dry-run={{DRYRUN}}',
'usort==1.0.8.post1',
'isort==6.0.1',
'ruff==0.13.1', # sync with RUFF
'ruff==0.12.9', # sync with RUFF
]
is_formatter = true
@ -1587,7 +1587,7 @@ init_command = [
'python3',
'tools/linter/adapters/pip_init.py',
'--dry-run={{DRYRUN}}',
'ruff==0.13.1', # sync with PYFMT
'ruff==0.12.9', # sync with PYFMT
]
is_formatter = true

View File

@ -91,8 +91,6 @@ generated_cpu_cpp = [
"aten/src/ATen/NativeMetaFunctions.h",
"aten/src/ATen/RegistrationDeclarations.h",
"aten/src/ATen/VmapGeneratedPlumbing.h",
"aten/src/ATen/ViewMetaClasses.h",
"aten/src/ATen/ViewMetaClasses.cpp",
"aten/src/ATen/core/aten_interned_strings.h",
"aten/src/ATen/core/enum_tag.h",
"aten/src/ATen/core/TensorBody.h",
@ -1077,7 +1075,6 @@ test_suite(
"aten/src/ATen/templates/LazyNonNativeIr.h",
"aten/src/ATen/templates/RegisterDispatchKey.cpp",
"aten/src/ATen/templates/RegisterDispatchDefinitions.ini",
"aten/src/ATen/templates/ViewMetaClassesPythonBinding.cpp",
"aten/src/ATen/native/native_functions.yaml",
"aten/src/ATen/native/tags.yaml",
"aten/src/ATen/native/ts_native_functions.yaml",

View File

@ -1,61 +1,20 @@
# Reference: https://setuptools.pypa.io/en/latest/userguide/miscellaneous.html
# Include individual top-level files
include CITATION.cff
include CODEOWNERS
include Dockerfile
include LICENSE
include MANIFEST.in
include Makefile
include NOTICE
include .bc-linter.yml
include .clang-format .clang-tidy
include .cmakelintrc
include .coveragerc
include .dockerignore
include .editorconfig
include .flake8
include .gdbinit
include .lintrunner.toml
include .lldbinit
include codex_setup.sh
include docker.Makefile
include pyrefly.toml
include ubsan.supp
# Include bazel and BUCK related files
include BUILD.bazel BUCK.oss
include WORKSPACE
include *.bzl
include .bazelignore .bazelrc .bazelversion
# Include general configuration files
include *.ini
# Include important top-level information
include *.md
# Include technical text files at the moment, comprises
# version.txt, CMakeLists.txt, requirements.txt
include *.txt
# Include ctags configuration
include .ctags.d/*.ctags
# Include subfolders completely
graft .devcontainer
graft .vscode
# Include source files in SDist
include CMakeLists.txt
include *.bzl *.bazel .bazel* BUILD *.BUILD BUILD.* WORKSPACE
include BUCK BUCK.*
include requirements*.txt
include version.txt
include [Mm]akefile *.[Mm]akefile [Mm]akefile.*
include [Dd]ockerfile *.[Dd]ockerfile [Dd]ockerfile.* .dockerignore
graft android
graft aten
graft benchmarks
graft binaries
graft c10
graft caffe2
graft cmake
graft docs
graft functorch
graft ios
graft mypy_plugins
graft scripts
graft test
graft third_party
graft tools
graft torch
@ -63,37 +22,29 @@ graft torchgen
# FIXME: torch-xla build during codegen will fail if include this file in wheel
exclude torchgen/BUILD.bazel
# The following exclusions omit parts from third-party dependencies that
# contain invalid symlinks[1] and that are not needed for pytorch, such as
# bindings for unused languages
prune third_party/flatbuffers/java
prune third_party/flatbuffers/kotlin
prune third_party/ittapi/rust
prune third_party/nccl/pkg/debian
prune third_party/opentelemetry-cpp/third_party/prometheus-cpp/cmake/project-import-*
# The following document is also an invalid symlink[1] and superfluous
exclude third_party/flatbuffers/docs/source/CONTRIBUTING.md
# Omit autogenerated code
prune torchgen/packaged
# Omit caches, compiled, and scm related content
prune */__pycache__
prune **/.github
prune **/.gitlab
global-exclude *.o *.obj *.so *.dylib *.a *.pxd *.dll *.lib
global-exclude *.py[cod] *.swp *~
global-exclude .git .git-blame-ignore-revs .gitattributes .gitignore .gitmodules
global-exclude .gitlab-ci.yml
# Misc files and directories in SDist
include *.md
include CITATION.cff
include LICENSE NOTICE
include mypy*.ini
graft benchmarks
graft docs
graft mypy_plugins
graft scripts
# Misc files needed for custom setuptools command
include .gitignore
include .gitmodules
# [1] Invalid symlinks for the purposes of Python source distributions are,
# according to the source distribution format[2] links pointing outside the
# destination directory or links with a `..` component, which is those of
# concern here.
# Include test suites in SDist
graft test
include pytest.ini
include .coveragerc
# [2] https://packaging.python.org/en/latest/specifications/source-distribution-format/#source-distribution-archive-features
# Prune generated/compiled files
prune torchgen/packaged
prune */__pycache__
global-exclude *.o *.obj *.so *.a *.dylib *.pxd *.dll *.lib *.py[cod]
prune */.git
global-exclude .git *~ *.swp

View File

@ -468,7 +468,7 @@ inline Tensor _sum_to(
// if we assume no reduction due to unbacked we ensure that at runtime.
TORCH_MAYBE_SYM_CHECK(
sym_eq(shape[i - leading_dims], sizes[i]),
"non-reduction path was assumed due to unbacked symbols expected those two sizes to be the same:",
"non-reduction path was assumed due to unabcked symbols expected those two sizes to be the same:",
shape[i - leading_dims],
", ",
sizes[i])

View File

@ -9,6 +9,11 @@
namespace at::functionalization {
ViewMeta ViewMeta::to_out_idx(int64_t out_idx) {
if (out_idx == this->out_index) return *this;
return ViewMeta(forward_fn, reverse_fn, has_symbolic_inputs, is_multi_output, is_as_strided, out_idx);
}
// Note [Functionalization: Alias Removal Part 2]
// See Note [Functionalization: Alias Removal] for more details.
// This function applies a single update from one of the views to the StorageImpl.
@ -37,12 +42,12 @@ namespace at::functionalization {
static const Tensor apply_update(const FunctionalStorageImpl::Update& update, const Tensor& base) {
at::Tensor t = update.new_val;
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t));
if (update.view_metas.empty()) { return t; }
if (update.view_metas.empty()) return t;
std::vector<at::Tensor> tmp_values({base});
tmp_values.reserve(update.view_metas.size());
for (size_t i = 0; i < update.view_metas.size() - 1; ++i) {
at::Tensor next_view = update.view_metas[i]->forward(tmp_values.back());
at::Tensor next_view = update.view_metas[i].forward_fn(tmp_values.back(), update.view_metas[i].out_index);
// NB: We only actually need tmp_values for ops like select/slice/diagonal/squeeze/as_strided
// All of these ops require additional information to recover the sizes of the original tensor.
// If need to, we could probably apply this optimization and only bother computing tmp_values
@ -50,8 +55,9 @@ static const Tensor apply_update(const FunctionalStorageImpl::Update& update, co
tmp_values.push_back(std::move(next_view));
}
for(int64_t i = static_cast<int64_t>(update.view_metas.size()) - 1; i >= 0; --i) {
int64_t out_idx = update.view_metas[i].out_index;
// Each view inverse is implemented in ViewInverses.cpp.
t = update.view_metas[i]->reverse(tmp_values[i], t);
t = update.view_metas[i].reverse_fn(tmp_values[i], t, out_idx);
}
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t));
return t;
@ -105,13 +111,13 @@ FunctionalStorageImpl::FunctionalStorageImpl(const Tensor& base)
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(base_));
}
void FunctionalStorageImpl::add_update(const Tensor& updated_val, const std::vector<std::shared_ptr<ViewMeta>>& metas) {
void FunctionalStorageImpl::add_update(const Tensor& updated_val, const std::vector<ViewMeta>& metas) {
TORCH_CHECK(!frozen_, "cannot mutate tensors with frozen storage");
if (metas.size() > 1) {
for (size_t i = 1; i < metas.size(); ++i) {
// Skipping this check for XLA. Would be good to add it back, but it is failing XLA CI
TORCH_CHECK(updated_val.device().type() == c10::DeviceType::XLA || !metas[i]->is_as_strided,
TORCH_CHECK(updated_val.device().type() == c10::DeviceType::XLA || !metas[i].is_as_strided,
"During torch.compile, encountered a mutation on a view chain of length ", metas.size(), ", where view ", i,
" was an as_strided() call. as_strided() is non-compositional, and therefore is not possible to functionalize properly today,"
"so this behavior is banned in compile. As a workaround, you can either remove the mutation from the model code, or you "

View File

@ -8,89 +8,44 @@ namespace at::functionalization {
// See Note [Functionalization Pass In Core]
enum class InverseReturnMode {
/// Specifies that functional inverses should always return a view.
AlwaysView,
/// Specifies that functional inverses should always return a non-view / copy.
NeverView,
/// Specifies that functional inverses should return a view unless a (copying)
/// scatter
/// inverse exists, in which case that will be used instead.
/// This avoids as_strided() calls that can be difficult for subclasses to
/// handle.
ViewOrScatterInverse,
};
#define FUNCTIONALIZATION_VIEWMETA_NAME(TYPE) \
static const char* name() { \
return #TYPE; \
}
#define FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE(...) \
using SerializableTuple = std::tuple<__VA_ARGS__>
// ViewMeta is a class used by the functionalization pass to navigate between
// a base tensor and a view tensor.
// For example, if I call `b = a.view1(...)`
// the functionalization pass will generate and store a ViewMeta specialization
// for `view1` operation on b that looks like:
// the functionalization pass will generate and store a ViewMeta on b that looks
// like:
//
// struct TORCH_API view1_ViewMeta : public ViewMeta {
// FUNCTIONALIZATION_VIEWMETA_NAME(view1_ViewMeta);
// FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE(
// bool /* reapply_views */,
// const std::vector<int64_t>&);
//
// view1_ViewMeta(const SerializableTuple& tpl)
// : view1_ViewMeta(std::get<0>(tpl), std::get<1>(tpl)) {}
//
// view1_ViewMeta(bool reapply_views, const std::vector<int64_t>& size)
// : ViewMeta(/*has_symbolic_inputs=*/false),
// reapply_views(reapply_views),
// size(size) {}
//
// Tensor forward(const Tensor& base) override {
// return base.view1(...);
// ViewMeta(
// [<captures>](const Tensor& base, int64_t mutated_view_idx) {
// return base.view1(...);
// },
// [<captures>](const at::Tensor& base, const at::Tensor& mutated_view,
// int64_t mutated_view_idx) -> at::Tensor {
// return at::functionalization::impl::view1_inverse(base, mutated_view,
// ...);
// }
//
// Tensor reverse(const Tensor& base, const Tensor& mutated_view) override {
// return at::functionalization::impl::view1_inverse(base, mutated_view,
// ...);
// }
// The forward_fn lambda describes how to replay view1 on a tensor.
//
// SerializableTuple to_serializable_tuple() {
// return std::make_tuple(reapply_views, size);
// }
//
// bool reapply_views;
// std::vector<int64_t> size;
// };
//
// The forward function describes how to replay view1 on a tensor.
//
// The reverse function describes how, given a tensor that is already a view,
// The reverse_fn lambda describes how, given a tensor that is already a view,
// how to get the corresponding base tensor. See Note [Functionalization Pass:
// View Inverses] for details.
//
// `SerializedTuple` is a typedef that defines an `std::tuple<...>` type
// representing the `ViewMeta` instance state. Methods that take in/return such
// a type are used for supporting pickle serialization.
struct ViewMeta {
ViewMeta(
std::function<Tensor(const Tensor&, int64_t)> forward,
std::function<Tensor(const Tensor&, const Tensor&, int64_t)> reverse,
bool has_symbolic_inputs,
bool is_multi_output = false,
bool is_as_strided = false,
int64_t out_idx = 0)
: out_index(out_idx),
: forward_fn(std::move(forward)),
reverse_fn(std::move(reverse)),
out_index(out_idx),
is_multi_output(is_multi_output),
is_as_strided(is_as_strided),
has_symbolic_inputs(has_symbolic_inputs) {}
virtual ~ViewMeta() = default;
virtual Tensor forward(const Tensor& base) = 0;
virtual Tensor reverse(const Tensor& base, const Tensor& mutated_view) = 0;
std::function<Tensor(const Tensor&, int64_t)> forward_fn;
std::function<Tensor(const Tensor&, const Tensor&, int64_t)> reverse_fn;
// See Note [out_idx in ViewMeta]
int64_t out_index;
@ -102,17 +57,10 @@ struct ViewMeta {
// Tells us if this view operation has any symbolic inputs
bool has_symbolic_inputs;
// Returns a new ViewMeta with the same forward/reverse
// Returns a copy of the current ViewMeta, if out_idx matches the current
// out_index. Otherwise, returns a new ViewMeta with the same forward/reverse
// functions, but a new out index.
//
// This method should be implemented by those `ViewMeta` that have more than
// one output.
virtual std::shared_ptr<ViewMeta> to_out_index(int64_t out_index) {
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"ViewMeta::to_out_index not implemented. ",
"Likely because there's only one output.");
}
ViewMeta to_out_idx(int64_t out_idx);
};
// FunctionalStorageImpl is a subclass of StorageImpl used by the
@ -145,14 +93,14 @@ struct TORCH_API FunctionalStorageImpl : public c10::StorageImpl {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
const at::Tensor new_val;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
const std::vector<std::shared_ptr<ViewMeta>> view_metas;
const std::vector<ViewMeta> view_metas;
};
explicit FunctionalStorageImpl(const Tensor& value);
void add_update(
const Tensor& updated_val,
const std::vector<std::shared_ptr<ViewMeta>>& view_metas);
const std::vector<ViewMeta>& view_metas);
bool apply_updates();
const Tensor& base() {
return base_;

View File

@ -129,19 +129,17 @@ void FunctionalTensorWrapper::freeze_storage() const {
// - view_value: The output tensor that we need to wrap.
// - base: The "base" of the view that `view_value` was generated from.
// See Note [Functionalization: Alias Removal Part 2] for more details on the mutation replay logic.
FunctionalTensorWrapper::FunctionalTensorWrapper(
const Tensor& view_value,
const FunctionalTensorWrapper* base,
const std::shared_ptr<functionalization::ViewMeta>& meta)
: c10::TensorImpl(
c10::DispatchKeySet(DispatchKey::Functionalize),
view_value.dtype(),
base->storage().data_ptr().device()),
value_(view_value),
is_multi_output_view_(
base->is_multi_output_view_ || meta->is_multi_output),
was_storage_changed_(base->was_storage_changed_),
is_symbolic_(base->is_symbolic_) {
FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& view_value, const FunctionalTensorWrapper* base, const functionalization::ViewMeta& meta)
: c10::TensorImpl(
c10::DispatchKeySet(DispatchKey::Functionalize),
view_value.dtype(),
base->storage().data_ptr().device()
),
value_(view_value),
is_multi_output_view_(base->is_multi_output_view_ || meta.is_multi_output),
was_storage_changed_(base->was_storage_changed_),
is_symbolic_(base->is_symbolic_)
{
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(value_));
TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
set_constructor_metadata();
@ -150,10 +148,11 @@ FunctionalTensorWrapper::FunctionalTensorWrapper(
view_metas_ = base->view_metas_; // copy
}
view_metas_.push_back(meta);
maybe_mark_symbolic(meta.get());
maybe_mark_symbolic(meta);
storage_ = base->storage_; // alias this tensor's storage with the base tensor's
}
functionalization::FunctionalStorageImpl* FunctionalTensorWrapper::functional_storage_impl() const {
return static_cast<functionalization::FunctionalStorageImpl*>(storage_.unsafeGetStorageImpl());
}
@ -177,18 +176,18 @@ bool FunctionalTensorWrapper::is_up_to_date() const {
}
// See Note [Functionalization Pass - Inplace View Ops]
void FunctionalTensorWrapper::mutate_view_meta(const std::shared_ptr<at::functionalization::ViewMeta>& meta) {
void FunctionalTensorWrapper::mutate_view_meta(const at::functionalization::ViewMeta& meta) {
view_metas_.push_back(meta);
// Manually track the fact that this tensor received a metadata mutation!
has_metadata_mutation_ = true;
// Mark this tensor as being symbolic if there are any symbolic inputs used by the view operation.
maybe_mark_symbolic(meta.get());
maybe_mark_symbolic(meta);
// Note [Functionalization Pass - Inplace View Ops]
// So, these ops are special - they're mutation AND view ops. They get special codegen.
// An example is transpose_, e.g. `a.transpose_()`
// Calling transpose_() should ensure that a gets an alias, and append the new ViewMeta to a's current list of ViewMetas.
at::AutoDispatchSkipFunctionalize guard;
value_ = meta->forward(value_);
value_ = meta.forward_fn(value_, meta.out_index);
TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
}
@ -369,8 +368,15 @@ void FunctionalTensorWrapper::sync_() {
regenerate_from_base();
}
const std::vector<std::shared_ptr<functionalization::ViewMeta>>& FunctionalTensorWrapper::view_metas() const {
return view_metas_;
Tensor FunctionalTensorWrapper::apply_view_metas(const Tensor& base) {
auto t = base;
// Reapply views to get the viewed tensor from the base in alias_
for (auto& view_meta: view_metas_) {
t = view_meta.forward_fn(t, view_meta.out_index);
}
return t;
}
void FunctionalTensorWrapper::regenerate_from_base() {
@ -379,7 +385,7 @@ void FunctionalTensorWrapper::regenerate_from_base() {
auto t = storage_impl->base();
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t));
t = at::functionalization::impl::apply_view_meta_sequence(t, view_metas_);
t = apply_view_metas(t);
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t));
replace_(t, /*from_lazy_regenerate=*/true);
@ -721,11 +727,11 @@ bool isFunctionalTensor(const std::optional<Tensor>& t) {
}
bool isFunctionalTensor(const c10::List<::std::optional<Tensor>>& t_list) {
if (t_list.empty()) { return false; }
if (t_list.empty()) return false;
auto functional_count = 0;
for (const auto i : c10::irange(t_list.size())) {
auto const & e= t_list[i];
if (!e.has_value() || !e->defined()) { continue; }
if (!e.has_value() || !e->defined()) continue;
if (isFunctionalTensor(e)) {
++functional_count;
}
@ -735,10 +741,10 @@ bool isFunctionalTensor(const c10::List<::std::optional<Tensor>>& t_list) {
template <typename T>
static bool isFunctionalTensorIListRef(c10::IListRef<T> list) {
if (list.size() == 0) { return false; }
if (list.size() == 0) return false;
auto functional_count = 0;
for (const auto& tensor : list) {
if (!tensor.defined()) { continue; }
if (!tensor.defined()) continue;
if (isFunctionalTensor(tensor)) {
++functional_count;
}
@ -756,28 +762,20 @@ void freeze_functional_tensor(const Tensor& tensor) {
functional_base_impl->freeze_storage();
}
Tensor create_functional_tensor_with_view_meta(
const at::Tensor& view_to_wrap,
const at::Tensor& base,
const std::shared_ptr<functionalization::ViewMeta>& meta,
int64_t out_idx) {
Tensor create_functional_tensor_with_view_meta(const at::Tensor& view_to_wrap, const at::Tensor& base, functionalization::ViewMeta meta, int64_t out_idx) {
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(view_to_wrap));
TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(base));
auto functional_base_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(base);
auto meta_ = meta;
if (out_idx != 0) {
// Note [out_idx in ViewMeta]
// When a view op outputs multiple tensors, each output needs its own separate ViewMeta.
// Each ViewMeta also tracks the index of the particular output tensor, which is needed in the reverse function.
meta_ = meta->to_out_index(out_idx);
meta = meta.to_out_idx(out_idx);
}
return at::detail::make_tensor<FunctionalTensorWrapper>(view_to_wrap, functional_base_impl, meta_);
return at::detail::make_tensor<FunctionalTensorWrapper>(view_to_wrap, functional_base_impl, meta);
}
std::vector<Tensor> create_functional_tensor_with_view_meta(
ITensorListRef view_to_wrap,
const at::Tensor& base,
const std::shared_ptr<functionalization::ViewMeta>& meta) {
std::vector<Tensor> create_functional_tensor_with_view_meta(ITensorListRef view_to_wrap, const at::Tensor& base, const functionalization::ViewMeta& meta) {
std::vector<Tensor> outputs(view_to_wrap.size());
int64_t i = 0;
for (const auto& tensor : view_to_wrap) {
@ -787,22 +785,12 @@ std::vector<Tensor> create_functional_tensor_with_view_meta(
return outputs;
}
void mutate_view_meta(const at::Tensor& self, const std::shared_ptr<functionalization::ViewMeta>& meta) {
void mutate_view_meta(const at::Tensor& self, const functionalization::ViewMeta& meta) {
TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(self));
auto self_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(self);
self_impl->mutate_view_meta(meta);
}
Tensor apply_view_meta_sequence(
const Tensor& base,
const std::vector<std::shared_ptr<functionalization::ViewMeta>>& sequence) {
Tensor r = base;
for (auto& vm : sequence) {
r = vm->forward(r);
}
return r;
}
// Note [Propagating strides in the functionalization pass]
// In order to properly compute stride information, the functionalization pass
// calls each {view} reference implementations with meta tensors.
@ -896,7 +884,7 @@ void functionalize_op_helper(const c10::OperatorHandle& op, torch::jit::Stack* s
const auto& ivalue = returns[idx];
if (ivalue.isTensor()) {
const auto& t = ivalue.toTensor();
if (!t.defined()) { continue; }
if (!t.defined()) continue;
at::functionalization::impl::sync(t);
auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(t));
(*stack)[returns_begin + idx] = t_new;

View File

@ -56,7 +56,7 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
explicit FunctionalTensorWrapper(
const Tensor& view_value,
const FunctionalTensorWrapper* base,
const std::shared_ptr<functionalization::ViewMeta>& meta);
const functionalization::ViewMeta& meta);
// Get the underlying, actual tensor, that doesn't know anything about
// functionalization.
@ -99,17 +99,17 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
->are_all_mutations_under_no_grad_or_inference_mode();
}
void maybe_mark_symbolic(functionalization::ViewMeta* meta) {
is_symbolic_ = is_symbolic_ | meta->has_symbolic_inputs;
void maybe_mark_symbolic(const functionalization::ViewMeta& meta) {
is_symbolic_ = is_symbolic_ | meta.has_symbolic_inputs;
}
bool is_symbolic() const {
return is_symbolic_;
}
// Retrieves the ViewMeta sequence of this tensor.
const std::vector<std::shared_ptr<functionalization::ViewMeta>>& view_metas()
const;
// Runs the forward_fn of every ViewMeta collected in the current instance
// to some other base.
Tensor apply_view_metas(const Tensor& base);
// Sync's the underlying tensor with its alias, if it's out of date. This
// involves two steps: 1) Apply any pending updates/mutations to the alias 2)
@ -146,8 +146,7 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
// from the base tensor. This method is used by inplace-view ops like
// transpose_. It appends a ViewMeta to the existing stack, and refreshes the
// tensor by replaying the views off of the alias.
void mutate_view_meta(
const std::shared_ptr<at::functionalization::ViewMeta>& meta);
void mutate_view_meta(const at::functionalization::ViewMeta& meta);
// Custom implementation of self.set_(src)
void set__impl(const FunctionalTensorWrapper* other);
@ -286,7 +285,7 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
bool is_symbolic_ = false;
size_t generation_ = 0;
std::vector<std::shared_ptr<at::functionalization::ViewMeta>> view_metas_;
std::vector<at::functionalization::ViewMeta> view_metas_;
protected:
static void copy_tensor_metadata(
@ -378,20 +377,16 @@ TORCH_API void propagate_xla_data_direct(
Tensor create_functional_tensor_with_view_meta(
const Tensor& view_to_wrap,
const Tensor& base,
const std::shared_ptr<functionalization::ViewMeta>& meta,
functionalization::ViewMeta meta,
int64_t out_idx = 0);
std::vector<Tensor> create_functional_tensor_with_view_meta(
ITensorListRef view_to_wrap,
const Tensor& base,
const std::shared_ptr<functionalization::ViewMeta>& meta);
const functionalization::ViewMeta& meta);
void mutate_view_meta(
const Tensor& self,
const std::shared_ptr<functionalization::ViewMeta>& meta);
TORCH_API Tensor apply_view_meta_sequence(
const Tensor& base,
const std::vector<std::shared_ptr<functionalization::ViewMeta>>& sequence);
const functionalization::ViewMeta& meta);
void set_sizes_strides_offset(const Tensor& out, const Tensor& meta_out);
void set_sizes_strides_offset(

View File

@ -1,5 +1,3 @@
#include <ATen/FunctionalizeFallbackKernel.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/LegacyTypeDispatch.h>
#include <ATen/EmptyTensor.h>
@ -9,6 +7,7 @@
#include <torch/library.h>
#include <c10/util/irange.h>
#include <c10/util/strides.h>
#include <ATen/EmptyTensor.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/ATen.h>
@ -29,31 +28,6 @@
#include <utility>
#endif
namespace at::functionalization {
Tensor resize__ViewMeta::forward(const Tensor& base) {
if (reapply_views) {
return base.as_strided(size, c10::contiguous_strides(size));
} else {
return at::as_strided_copy(base, size, c10::contiguous_strides(size));
}
}
Tensor resize__ViewMeta::reverse(const Tensor& base, const Tensor& mutated_view) {
return base.as_strided_scatter(
mutated_view, size, c10::contiguous_strides(size));
}
Tensor _unsafe_view_ViewMeta::forward(const Tensor& base) {
return at::_unsafe_view_symint(base, size);
}
Tensor _unsafe_view_ViewMeta::reverse(const Tensor& base, const Tensor& mutated_view) {
return at::_unsafe_view_symint(mutated_view, base.sym_sizes());
}
} // namespace at::functionalization
namespace {
void functionalizeFallback(const c10::OperatorHandle& op, c10::DispatchKeySet dispatchKeySet [[maybe_unused]], torch::jit::Stack* stack) {
const auto& schema = op.schema();
@ -132,9 +106,7 @@ namespace {
const auto& ivalue = returns[idx];
if (ivalue.isTensor() && should_wrap_outputs) {
const auto& t = ivalue.toTensor();
if (!t.defined()) {
continue;
}
if (!t.defined()) continue;
auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(t));
(*stack)[returns_begin + idx] = t_new;
} else if (ivalue.isTensorList() && should_wrap_outputs) {
@ -197,8 +169,19 @@ static const at::Tensor & resize__functionalization(c10::DispatchKeySet dispatch
// The output of resizing is equivalent to taking a slice of a larger tensor.
// We have to emulate this "slicing" with an as_strided call.
auto reapply_views = at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
auto view_meta = std::make_shared<at::functionalization::resize__ViewMeta>(
reapply_views, size.vec());
at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
[reapply_views = reapply_views, size = size.vec()](const at::Tensor & base, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor {
if (reapply_views) {
return base.as_strided(size, c10::contiguous_strides(size));
} else {
return at::as_strided_copy(base, size, c10::contiguous_strides(size));
}
},
[size = size.vec()](const at::Tensor & base, const at::Tensor & mutated_view, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor {
return base.as_strided_scatter(mutated_view, size, c10::contiguous_strides(size));
},
/*has_symbolic_inputs=*/false
);
at::functionalization::impl::mutate_view_meta(self, view_meta);
return self;
}
@ -317,11 +300,17 @@ static at::Tensor _unsafe_view_functionalize(const at::Tensor & self, at::SymInt
tmp_output = at::_unsafe_view_symint(self_, size);
}
bool has_symbolic_inputs = std::any_of(
size.begin(), size.end(), [=](auto& s) { return s.is_symbolic(); });
auto view_meta =
std::make_shared<at::functionalization::_unsafe_view_ViewMeta>(
has_symbolic_inputs, size.vec());
bool has_symbolic_inputs = std::any_of(size.begin(), size.end(), [=](auto& s) { return s.is_symbolic(); });
at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
[size = size.vec()](const at::Tensor & base, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor {
return at::_unsafe_view_symint(base, size);
},
[size = size.vec()](const at::Tensor & base, const at::Tensor & mutated_view, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor {
return at::_unsafe_view_symint(mutated_view, base.sym_sizes());
},
/*has_symbolic_inputs=*/has_symbolic_inputs
);
auto out = at::functionalization::impl::create_functional_tensor_with_view_meta(tmp_output, self, std::move(view_meta));
// See Note [Propagating strides in the functionalization pass]

View File

@ -1,58 +0,0 @@
#pragma once
#include <ATen/FunctionalStorageImpl.h>
namespace at::functionalization {
// `ViewMeta` implementation for `resize_` operation.
struct TORCH_API resize__ViewMeta : public ViewMeta {
FUNCTIONALIZATION_VIEWMETA_NAME(resize__ViewMeta)
FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE(
bool /* reapply_views */,
const std::vector<int64_t>&);
resize__ViewMeta(const SerializableTuple& tpl)
: resize__ViewMeta(std::get<0>(tpl), std::get<1>(tpl)) {}
resize__ViewMeta(bool reapply_views, const std::vector<int64_t>& size)
: ViewMeta(/*has_symbolic_inputs=*/false),
reapply_views(reapply_views),
size(size) {}
Tensor forward(const Tensor& base) override;
Tensor reverse(const Tensor& base, const Tensor& mutated_view) override;
SerializableTuple to_serializable_tuple() {
return std::make_tuple(reapply_views, size);
}
bool reapply_views;
std::vector<int64_t> size;
};
// `ViewMeta` implementation for `_unsafe_view` operation.
struct TORCH_API _unsafe_view_ViewMeta : public ViewMeta {
FUNCTIONALIZATION_VIEWMETA_NAME(_unsafe_view_ViewMeta)
FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE(
bool /* has_symbolic_inputs */,
const std::vector<c10::SymInt>&);
_unsafe_view_ViewMeta(const SerializableTuple& tpl)
: _unsafe_view_ViewMeta(std::get<0>(tpl), std::get<1>(tpl)) {}
_unsafe_view_ViewMeta(
bool has_symbolic_inputs,
const std::vector<c10::SymInt>& size)
: ViewMeta(has_symbolic_inputs), size(size) {}
Tensor forward(const Tensor& base) override;
Tensor reverse(const Tensor& base, const Tensor& mutated_view) override;
SerializableTuple to_serializable_tuple() {
return std::make_tuple(has_symbolic_inputs, size);
}
std::vector<c10::SymInt> size;
};
} // namespace at::functionalization

View File

@ -1,22 +1,32 @@
#include <ATen/core/PythonOpRegistrationTrampoline.h>
#include <c10/core/impl/PyInterpreterHooks.h>
// TODO: delete this
namespace at::impl {
c10::impl::PyInterpreter* PythonOpRegistrationTrampoline::interpreter_ = nullptr;
// The strategy is that all python interpreters attempt to register themselves
// as the main interpreter, but only one wins. Only that interpreter is
// allowed to interact with the C++ dispatcher. Furthermore, when we execute
// logic on that interpreter, we do so hermetically, never setting pyobj field
// on Tensor.
std::atomic<c10::impl::PyInterpreter*>
PythonOpRegistrationTrampoline::interpreter_{nullptr};
c10::impl::PyInterpreter* PythonOpRegistrationTrampoline::getInterpreter() {
return c10::impl::getGlobalPyInterpreter();
return PythonOpRegistrationTrampoline::interpreter_.load();
}
bool PythonOpRegistrationTrampoline::registerInterpreter(
c10::impl::PyInterpreter* interp) {
if (interpreter_ != nullptr) {
c10::impl::PyInterpreter* expected = nullptr;
interpreter_.compare_exchange_strong(expected, interp);
if (expected != nullptr) {
// This is the second (or later) Python interpreter, which means we need
// non-trivial hermetic PyObject TLS
c10::impl::HermeticPyObjectTLS::init_state();
return false;
} else {
return true;
}
interpreter_ = interp;
return true;
}
} // namespace at::impl

View File

@ -2,21 +2,19 @@
#include <ATen/core/dispatch/Dispatcher.h>
// TODO: We can get rid of this
// TODO: this can probably live in c10
namespace at::impl {
// Manages the single Python interpreter instance for PyTorch.
class TORCH_API PythonOpRegistrationTrampoline final {
static c10::impl::PyInterpreter* interpreter_;
static std::atomic<c10::impl::PyInterpreter*> interpreter_;
public:
// Register the Python interpreter. Returns true on first registration,
// false if an interpreter was already registered.
// Returns true if you successfully registered yourself (that means
// you are in the hot seat for doing the operator registrations!)
static bool registerInterpreter(c10::impl::PyInterpreter*);
// Returns the registered interpreter via the global PyInterpreter hooks.
// Returns nullptr if no interpreter has been registered yet.
static c10::impl::PyInterpreter* getInterpreter();
};

View File

@ -149,105 +149,5 @@ static inline void pack_vnni4(
#endif
}
// This is a helper function for transpose_pack_vnni4
// Transform a [4, 16] block (with incontiguous output)
// Src:
// a1 a2 a3 a4 a5 a6 a7 a8 a9 a10 a11 a12 a13 a14 a15 a16
// b1 b2 b3 b4 b5 b6 b7 b8 b9 b10 b11 b12 b13 b14 b15 b16
// c1 c2 c3 c4 c5 c6 c7 c8 c9 c10 c11 c12 c13 c14 c15 c16
// d1 d2 d3 d4 d5 d6 d7 d8 d9 d10 d11 d12 d13 d14 d15 d16
// Dst:
// a1 a2 a3 a4 b1 b2 b3 b4 c1 c2 c3 c4 d1 d2 d3 d4
// a5 a6 a7 a8 b5 b6 b7 b8 c5 c6 c7 c8 d5 d6 d7 d8
// a9 a10 a11 a12 b9 b10 b11 b12 c9 c10 c11 c12 d9 d10 d11 d12
// a13 a14 a15 a16 b13 b14 b15 b16 c13 c14 c15 c16 d13 d14 d15 d16
template <typename scalar_t, typename = std::enable_if_t<sizeof(scalar_t) == 1>>
static inline void transpose_vnni4_pad_4x16_block(
const scalar_t* src,
scalar_t* dst,
int64_t ld_src,
int64_t ld_dst,
int krem = 4) {
#if defined(CPU_CAPABILITY_AVX512)
__m128i r[4];
for (int i = 0; i < krem; ++i) {
r[i] = _mm_loadu_si128(reinterpret_cast<const __m128i*>(src + i * ld_src));
}
for (int i = krem; i < 4; ++i) {
r[i] = _mm_setzero_si128();
}
// Transpose 4x16 bytes using unpack and shuffle
__m128i t0 = _mm_unpacklo_epi32(r[0], r[1]);
__m128i t1 = _mm_unpackhi_epi32(r[0], r[1]);
__m128i t2 = _mm_unpacklo_epi32(r[2], r[3]);
__m128i t3 = _mm_unpackhi_epi32(r[2], r[3]);
__m128i r0 = _mm_unpacklo_epi64(t0, t2);
__m128i r1 = _mm_unpackhi_epi64(t0, t2);
__m128i r2 = _mm_unpacklo_epi64(t1, t3);
__m128i r3 = _mm_unpackhi_epi64(t1, t3);
// Store output
if (krem == 4) {
// normal case
_mm_storeu_si128(reinterpret_cast<__m128i*>(dst), r0);
_mm_storeu_si128(reinterpret_cast<__m128i*>(dst + ld_dst), r1);
_mm_storeu_si128(reinterpret_cast<__m128i*>(dst + ld_dst * 2), r2);
_mm_storeu_si128(reinterpret_cast<__m128i*>(dst + ld_dst * 3), r3);
} else {
// masked case
__mmask16 mask = (1ULL << (krem * 4)) - 1;
_mm_mask_storeu_epi8(dst, mask, r0);
_mm_mask_storeu_epi8(reinterpret_cast<__m128i*>(dst + ld_dst), mask, r1);
_mm_mask_storeu_epi8(
reinterpret_cast<__m128i*>(dst + ld_dst * 2), mask, r2);
_mm_mask_storeu_epi8(
reinterpret_cast<__m128i*>(dst + ld_dst * 3), mask, r3);
}
#else
TORCH_CHECK(
false,
"transpose_vnni4_pad_4x16_block is only supported when AVX-512 is supported")
#endif
}
// Do the transpose packing fusion with VNNI4
// Reorder [K, N] → [N/4, K, 4] (VNNI4-style layout for bit8)
template <typename scalar_t, typename = std::enable_if_t<sizeof(scalar_t) == 1>>
static inline void transpose_pack_vnni4(
const scalar_t* src,
scalar_t* dst,
int64_t ld_src,
int64_t K,
int64_t N) {
#if defined(CPU_CAPABILITY_AVX512)
TORCH_CHECK(
N % 16 == 0, "N needs to be multiple of 16 for transpose_pack_vnni4");
int64_t bk = 0;
int64_t _K = K / 4 * 4;
for (; bk < _K; bk += 4) {
int64_t bn = 0;
for (; bn < N; bn += 16) {
transpose_vnni4_pad_4x16_block(
src + bk * ld_src + bn, dst + bn * K + bk * 4, ld_src, K * 4);
}
}
// Handle leftover K rows (< 4)
if (K % 4 != 0) {
int krem = K - bk;
int64_t bn = 0;
for (; bn < N; bn += 16) {
transpose_vnni4_pad_4x16_block(
src + bk * ld_src + bn, dst + bn * K + bk * 4, ld_src, K * 4, krem);
}
}
#else
TORCH_CHECK(
false, "transpose_pack_vnni4 is only supported when AVX-512 is supported")
#endif
}
} // namespace CPU_CAPABILITY
} // namespace at::vec

View File

@ -281,9 +281,6 @@ bool CUDAHooks::compiledWithMIOpen() const {
bool CUDAHooks::supportsDilatedConvolutionWithCuDNN() const {
#if AT_CUDNN_ENABLED()
if (!hasCUDA()) {
return false;
}
// NOTE: extra parenthesis around numbers disable clang warnings about
// dead code
return true;
@ -294,9 +291,6 @@ bool CUDAHooks::supportsDilatedConvolutionWithCuDNN() const {
bool CUDAHooks::supportsDepthwiseConvolutionWithCuDNN() const {
#if AT_CUDNN_ENABLED()
if (!hasCUDA()) {
return false;
}
cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
// Check for Volta cores
if (prop->major >= 7) {
@ -311,9 +305,6 @@ bool CUDAHooks::supportsDepthwiseConvolutionWithCuDNN() const {
bool CUDAHooks::supportsBFloat16ConvolutionWithCuDNNv8() const {
#if AT_CUDNN_ENABLED()
if (!hasCUDA()) {
return false;
}
cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
// Check for Volta cores
if (prop->major >= 8) {

View File

@ -465,11 +465,8 @@ inline bool mps_conv_use_channels_last(const at::Tensor& input, const at::Tensor
return false;
}
auto is_channel_last = [](const at::Tensor& t) {
auto fmt = t.suggest_memory_format();
return fmt == at::MemoryFormat::ChannelsLast || fmt == at::MemoryFormat::ChannelsLast3d;
};
return is_channel_last(input) || is_channel_last(weight);
auto fmt = input.suggest_memory_format();
return fmt == at::MemoryFormat::ChannelsLast || fmt == at::MemoryFormat::ChannelsLast3d;
}
} // namespace at::native

View File

@ -32,6 +32,10 @@
#include <ATen/native/mkldnn/Utils.h>
#endif
#ifdef USE_MPS
#include <ATen/mps/MPSDevice.h>
#endif
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
@ -1425,8 +1429,12 @@ static inline at::MemoryFormat determine_backend_memory_format(
}
break;
case ConvBackend::Mps:
case ConvBackend::MpsTranspose:
if (mps_conv_use_channels_last(input, weight)) {
#ifdef USE_MPS
if (!mps::is_macos_13_or_newer(mps::MacOSVersion::MACOS_VER_15_0_PLUS)) {
break;
}
#endif
backend_memory_format = (k == 5) ? MemoryFormat::ChannelsLast3d : MemoryFormat::ChannelsLast;
}
break;

View File

@ -9,7 +9,6 @@
#include <ATen/native/TransposeType.h>
#include <ATen/native/Unfold3d.h>
#include <c10/util/irange.h>
#include <c10/util/safe_numerics.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@ -175,23 +174,6 @@ static inline void slow_conv3d_shape_check(
const int64_t input_height = input.size(dim_height);
const int64_t input_width = input.size(dim_width);
constexpr int64_t MAX_SAFE_PAD = (1LL << 61);
TORCH_CHECK_VALUE(
pad_height <= MAX_SAFE_PAD,
"Padding height too large: pad_height=",
pad_height);
TORCH_CHECK_VALUE(
pad_width <= MAX_SAFE_PAD,
"Padding width too large: pad_width=",
pad_width);
TORCH_CHECK_VALUE(
pad_depth <= MAX_SAFE_PAD,
"Padding depth too large: pad_depth=",
pad_depth);
const int64_t exact_input_depth = input_depth + 2 * pad_depth;
const int64_t exact_input_height = input_height + 2 * pad_height;
const int64_t exact_input_width = input_width + 2 * pad_width;
@ -239,14 +221,6 @@ static inline void slow_conv3d_shape_check(
output_width,
"). Output size is too small");
uint64_t kernel_product;
TORCH_CHECK(
!c10::mul_overflows(kernel_height, kernel_width, &kernel_product),
"Kernel height x width product is too large: kernel_height=",
kernel_height,
", kernel_width=",
kernel_width);
if (weight.defined()) {
int64_t n_input_plane = weight.size(1);
if (weight.dim() == 2) {

View File

@ -23,7 +23,6 @@
#include <ATen/ops/linspace.h>
#endif
#include <cmath>
#include <numeric>
#include <tuple>
#include <vector>
@ -203,46 +202,6 @@ select_outer_bin_edges(const Tensor& input, std::optional<c10::ArrayRef<double>>
return std::make_pair(leftmost_edges, rightmost_edges);
}
/* Bin edges correction based on the precision representation.
* To maintain the backward compatibility we take max(std::nextafter<>, +1)
* and min(std::nextafter<>, -1) for scalar types. For other types +/- 1 as usual.
*/
void bins_edges_correction(const ScalarType& t, double &leftmost_edge, double &rightmost_edge)
{
#define UPDATE_WITH_LIMIT(real_type, scalartype) \
case ScalarType::scalartype: \
leftmost_edge = std::min( \
static_cast<double>( \
std::nexttoward( \
static_cast<real_type>(leftmost_edge), \
std::numeric_limits<real_type>::lowest() \
) \
), \
leftmost_edge - 1. \
); \
rightmost_edge = std::max( \
static_cast<double>( \
std::nexttoward( \
static_cast<real_type>(rightmost_edge), \
std::numeric_limits<real_type>::max() \
) \
), \
rightmost_edge + 1. \
); \
break;
switch (t) {
UPDATE_WITH_LIMIT(double, Double)
UPDATE_WITH_LIMIT(float, Float)
default:
// Fallback to the default behavior for other types
leftmost_edge -= 1;
rightmost_edge += 1;
}
#undef UPDATE_WITH_LIMIT
}
/* histc's version of the logic for outermost bin edges.
*/
std::pair<double, double> histc_select_outer_bin_edges(const Tensor& input,
@ -257,7 +216,8 @@ std::pair<double, double> histc_select_outer_bin_edges(const Tensor& input,
}
if (leftmost_edge == rightmost_edge) {
bins_edges_correction(input.dtype().toScalarType(), leftmost_edge, rightmost_edge);
leftmost_edge -= 1;
rightmost_edge += 1;
}
TORCH_CHECK(!(std::isinf(leftmost_edge) || std::isinf(rightmost_edge) ||

View File

@ -42,19 +42,6 @@ void bfloat16_copy_kernel_cuda(TensorIteratorBase &iter) {
});
}
#ifdef USE_ROCM
void bfloat16tofloat32_copy_kernel_cuda(TensorIteratorBase &iter) {
gpu_kernel_nocast(iter, [] GPU_LAMBDA(at::BFloat16 value) {
return static_cast<float>(value);
});
}
void float16tofloat32_copy_kernel_cuda(TensorIteratorBase &iter) {
gpu_kernel_nocast(iter, [] GPU_LAMBDA(at::Half value) {
return static_cast<float>(value);
});
}
#endif
void float8_copy_kernel_cuda(TensorIteratorBase &iter) {
ScalarType dtype = iter.dtype(0);
ScalarType other_dtype = iter.dtype(1);
@ -200,17 +187,7 @@ void direct_copy_kernel_cuda(TensorIteratorBase &iter) {
} else {
float16_copy_kernel_cuda(iter);
}
}
#ifdef USE_ROCM
else if ((iter.dtype(1) == kBFloat16 || iter.dtype(1) == kHalf) && dtype == kFloat) {
if (iter.dtype(1) == kBFloat16) {
bfloat16tofloat32_copy_kernel_cuda(iter);
} else {
float16tofloat32_copy_kernel_cuda(iter);
}
}
#endif
else if (isBitsType(dtype)) {
} else if (isBitsType(dtype)) {
TORCH_CHECK(dtype == iter.dtype(1), "copy_() does not support casting "
"bits types to different bits types. Source dtype is ", iter.dtype(1), "target dtype is ", dtype);
AT_DISPATCH_BIT_TYPES(dtype, "copy_", [&] {

View File

@ -52,7 +52,9 @@ static void fill_depthwise_conv_desc(MPSGraphDepthwiseConvolution3DOpDescriptor*
NSUInteger dilationRateInX,
NSUInteger dilationRateInY,
NSUInteger paddingHorizontal,
NSUInteger paddingVertical) {
NSUInteger paddingVertical,
c10::MemoryFormat memory_format,
NSUInteger groups) {
descriptor_.strides =
@[ @1, [[NSNumber alloc] initWithInteger:strideInY], [[NSNumber alloc] initWithInteger:strideInX] ];
descriptor_.dilationRates =
@ -101,7 +103,7 @@ static void fill_conv_desc(MPSGraphConvolution2DOpDescriptor* descriptor_,
descriptor_.groups = groups;
}
static Tensor _mps_convolution_impl(const Tensor& input_t,
static Tensor _mps_convolution_impl(const Tensor& input_t_,
const Tensor& weight_t,
const std::optional<Tensor>& bias_opt,
IntArrayRef padding,
@ -109,15 +111,12 @@ static Tensor _mps_convolution_impl(const Tensor& input_t,
IntArrayRef dilation,
int64_t groups,
std::optional<IntArrayRef> input_shape) {
constexpr auto kChannelsLast = MemoryFormat::ChannelsLast;
constexpr auto kContiguous = MemoryFormat::Contiguous;
const bool is_macos_15_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS);
const bool is3DConv = input_t.dim() == 5;
const auto memory_format = input_t.suggest_memory_format();
const auto input_suggested_layout = memory_format == kChannelsLast && is_macos_15_plus ? kChannelsLast : kContiguous;
const bool is_channels_last = mps_conv_use_channels_last(input_t, weight_t) && !is3DConv;
const bool bias_defined = bias_opt ? bias_opt->defined() : false;
const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS);
Tensor input_t = input_t_;
bool is3DConv = input_t.dim() == 5;
if (!is_macOS_15_0_or_newer || is3DConv) {
input_t = input_t.contiguous();
}
TORCH_CHECK(isFloatingType(input_t.scalar_type()), "Convolution is supported only for Floating types");
@ -127,6 +126,15 @@ static Tensor _mps_convolution_impl(const Tensor& input_t,
checkAllSameType(c, {input, weight});
checkAllSameGPU(c, {input, weight});
bool bias_defined;
if (bias_opt == std::nullopt)
bias_defined = false;
else
bias_defined = bias_opt->defined();
auto memory_format = input_t.suggest_memory_format();
bool is_channels_last = (memory_format == at::MemoryFormat::ChannelsLast) && !is3DConv;
auto output_t =
at::empty(input_shape.has_value() ? input_shape.value()
: conv_output_size(input->sizes(), weight->sizes(), padding, stride, dilation),
@ -134,18 +142,12 @@ static Tensor _mps_convolution_impl(const Tensor& input_t,
std::nullopt,
kMPS,
std::nullopt,
is_channels_last ? kChannelsLast : kContiguous);
is_macOS_15_0_or_newer ? memory_format : MemoryFormat::Contiguous);
if (output_t.numel() == 0) {
return output_t;
}
TensorArg output{output_t, "result", 0};
// TODO: Remove me when MacOS-14 is no longer supported
std::optional<Tensor> output_c;
if (!is_macos_15_plus && is_channels_last) {
output_c = at::empty_like(output_t, output_t.options().memory_format(kContiguous));
}
if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_1_PLUS)) {
// On macOS < 15.1, MPS convolution kernel does not support output channels > 2^16
for (auto elem : output_t.sizes()) {
@ -184,22 +186,32 @@ static Tensor _mps_convolution_impl(const Tensor& input_t,
getArrayRefString(dilation),
getArrayRefString(padding),
groups,
input_suggested_layout == kChannelsLast,
is_channels_last,
mps::getTensorsStringKey({input_t, weight_t}),
bias_defined,
bias_shape_key);
auto inputShape = mps::getMPSShape(input_t, input_suggested_layout);
auto outputShape = mps::getMPSShape(output_t, input_suggested_layout);
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
bool isDepthwiseConv =
(groups > 1 && weight_t.size(1) == 1) && input_t.dim() >= 4 && weight_t.dim() >= 4 && !is_channels_last;
MPSShape* inputShape = mps::getMPSShape(input_t, memory_format);
MPSShape* outputShape = mps::getMPSShape(output_t, memory_format);
MPSNDArray* inputNDArray = nil;
MPSNDArray* outputNDArray = nil;
auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(input_t), inputShape);
auto weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight_t);
MPSGraphTensor* outputTensor = nil;
if (input_t.is_contiguous(memory_format) && output_t.is_contiguous(memory_format) && is_macOS_15_0_or_newer) {
inputNDArray = getMPSNDArray(input_t, inputShape);
outputNDArray = getMPSNDArray(output_t, outputShape);
}
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSShape* weightShape = mps::getMPSShape(weight_t);
bool isDepthwiseConv = ((groups > 1 && (weightShape[1].intValue == 1)) && inputShape.count >= 4 &&
weightShape.count >= 4 && !is_channels_last);
MPSGraphTensor* inputTensor =
mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(input_t.scalar_type()), inputShape);
MPSGraphTensor* weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight_t);
MPSGraphTensor* outputTensor;
if (is3DConv) {
auto conv3dDescriptor_ = [[MPSGraphConvolution3DOpDescriptor new] autorelease];
MPSGraphConvolution3DOpDescriptor* conv3dDescriptor_ = [[MPSGraphConvolution3DOpDescriptor new] autorelease];
fill_conv3d_desc(conv3dDescriptor_,
stride[2],
stride[1],
@ -217,9 +229,17 @@ static Tensor _mps_convolution_impl(const Tensor& input_t,
descriptor:conv3dDescriptor_
name:nil];
} else if (isDepthwiseConv) {
auto depthWiseConv3dDescriptor_ = [[MPSGraphDepthwiseConvolution3DOpDescriptor new] autorelease];
fill_depthwise_conv_desc(
depthWiseConv3dDescriptor_, stride[1], stride[0], dilation[1], dilation[0], padding[1], padding[0]);
MPSGraphDepthwiseConvolution3DOpDescriptor* depthWiseConv3dDescriptor_ =
[[MPSGraphDepthwiseConvolution3DOpDescriptor new] autorelease];
fill_depthwise_conv_desc(depthWiseConv3dDescriptor_,
stride[1],
stride[0],
dilation[1],
dilation[0],
padding[1],
padding[0],
memory_format,
groups);
MPSGraphTensor* weightTransposeTensor = [mpsGraph transposeTensor:weightTensor
dimension:-3
@ -238,7 +258,7 @@ static Tensor _mps_convolution_impl(const Tensor& input_t,
dilation[0],
padding[1],
padding[0],
input_suggested_layout,
memory_format,
groups);
outputTensor = [mpsGraph convolution2DWithSourceTensor:inputTensor
@ -250,6 +270,13 @@ static Tensor _mps_convolution_impl(const Tensor& input_t,
MPSGraphTensor* biasTensor = nil;
if (bias_defined) {
biasTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(bias_opt.value()));
}
if (is_channels_last && !is_macOS_15_0_or_newer) {
outputTensor = mps::convertNHWCtoNCHW(mpsGraph, outputTensor);
}
if (bias_defined) {
outputTensor = [mpsGraph additionWithPrimaryTensor:outputTensor secondaryTensor:biasTensor name:nil];
}
newCachedGraph->inputTensor_ = inputTensor;
@ -258,26 +285,27 @@ static Tensor _mps_convolution_impl(const Tensor& input_t,
newCachedGraph->outputTensor_ = outputTensor;
});
auto inputPlaceholder = input_suggested_layout == kContiguous
? Placeholder(cachedGraph->inputTensor_, output_c || is3DConv ? input_t.contiguous() : input_t)
: Placeholder(cachedGraph->inputTensor_, getMPSNDArray(input_t, inputShape));
auto outputPlaceholder = input_suggested_layout == kContiguous
? Placeholder(cachedGraph->outputTensor_, output_c ? *output_c : output_t)
: Placeholder(cachedGraph->outputTensor_, getMPSNDArray(output_t, outputShape));
auto weightsPlaceholder = Placeholder(cachedGraph->weightTensor_, output_c ? weight_t.contiguous() : weight_t);
auto inputPlaceholder = inputNDArray ? Placeholder(cachedGraph->inputTensor_, inputNDArray)
: Placeholder(cachedGraph->inputTensor_, input_t, inputShape);
auto weightsPlaceholder = Placeholder(cachedGraph->weightTensor_, weight_t);
auto biasPlaceholder = Placeholder();
// Reshape the bias to be broadcastable with output of conv2d or conv3d
if (bias_defined) {
if (is3DConv) {
biasPlaceholder = Placeholder(cachedGraph->biasTensor_, bias_opt->view({1, bias_shape[0], 1, 1, 1}));
} else if (input_suggested_layout == kChannelsLast) {
biasPlaceholder = Placeholder(cachedGraph->biasTensor_, bias_opt->view({1, 1, 1, bias_shape[0]}));
biasPlaceholder = Placeholder(cachedGraph->biasTensor_, (bias_opt.value()).view({1, bias_shape[0], 1, 1, 1}));
} else {
biasPlaceholder = Placeholder(cachedGraph->biasTensor_, bias_opt->view({1, bias_shape[0], 1, 1}));
if (is_channels_last && is_macOS_15_0_or_newer) {
biasPlaceholder = Placeholder(cachedGraph->biasTensor_, (bias_opt.value()).view({1, 1, 1, bias_shape[0]}));
} else {
biasPlaceholder = Placeholder(cachedGraph->biasTensor_, (bias_opt.value()).view({1, bias_shape[0], 1, 1}));
}
}
}
auto outputPlaceholder = outputNDArray ? Placeholder(cachedGraph->outputTensor_, outputNDArray)
: Placeholder(cachedGraph->outputTensor_, output_t);
auto feeds = [[[NSMutableDictionary alloc] initWithCapacity:3] autorelease];
NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds =
[[[NSMutableDictionary alloc] initWithCapacity:3] autorelease];
feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();
feeds[weightsPlaceholder.getMPSGraphTensor()] = weightsPlaceholder.getMPSGraphTensorData();
if (bias_defined) {
@ -287,10 +315,6 @@ static Tensor _mps_convolution_impl(const Tensor& input_t,
runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
}
if (output_c) {
output_t.copy_(*output_c);
}
return output_t;
}
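To make the bias handling above concrete: a minimal CPU-side sketch, assuming plain libtorch and no MPS, of reshaping a 1-D bias so it broadcasts against a 4-D convolution output in both contiguous and channels-last layouts. All tensor names here are illustrative, not taken from the kernel.
    #include <ATen/ATen.h>

    int main() {
      // Convolution-style output: N=2, C=3, H=4, W=4, plus a per-channel bias.
      at::Tensor out = at::randn({2, 3, 4, 4});
      at::Tensor bias = at::randn({3});

      // NCHW (contiguous): broadcast along the channel dimension.
      at::Tensor y_contig = out + bias.view({1, 3, 1, 1});

      // Channels-last keeps the same logical NCHW sizes (only strides change),
      // so the same {1, C, 1, 1} view still broadcasts; a graph fed raw NHWC
      // data would instead need a {1, 1, 1, C} view, as in the MPS path above.
      at::Tensor y_cl = out.contiguous(at::MemoryFormat::ChannelsLast) + bias.view({1, 3, 1, 1});

      TORCH_CHECK(at::allclose(y_contig, y_cl));
      return 0;
    }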
@ -327,21 +351,14 @@ static Tensor mps_convolution_backward_input(IntArrayRef input_size,
TensorArg grad_output{grad_output_t, "grad_output", 1}, weight{weight_t, "weight", 2};
checkAllSameType(c, {grad_output, weight});
checkAllSameGPU(c, {grad_output, weight});
constexpr auto kChannelsLast = at::MemoryFormat::ChannelsLast;
bool is_channels_last = mps_conv_use_channels_last(grad_output_t, weight_t) && !is3DConv;
auto grad_input_t =
at::empty(input_size, grad_output_t.options(), is_channels_last ? std::optional(kChannelsLast) : std::nullopt);
auto memory_format = grad_output_t.suggest_memory_format();
bool is_channels_last = (memory_format == at::MemoryFormat::ChannelsLast) && !is3DConv;
auto grad_input_t = at::empty(input_size, grad_output_t.options(), std::nullopt);
// Avoid "grad_input" when this is being used as transposed convolution
TensorArg grad_input{grad_input_t, "result", 0};
convolution_shape_check(c, grad_input, weight, grad_output, padding, stride, dilation, groups);
// TODO: Remove me when MacOS-14 is no longer supported
std::optional<Tensor> grad_input_c;
if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS) && is_channels_last) {
grad_input_c = at::empty_like(grad_input_t, grad_input_t.options().memory_format(MemoryFormat::Contiguous));
}
// Derive from MPSCachedGraph
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
@ -353,6 +370,7 @@ static Tensor mps_convolution_backward_input(IntArrayRef input_size,
// Add backward with input
@autoreleasepool {
MPSStream* stream = getCurrentMPSStream();
MPSShape* mps_input_shape = getMPSShape(input_size);
std::string key = fmt::format("mps_{}_convolution_backward_input:{}:{}:{}:{}:{}:{}",
is3DConv ? "3d_" : "",
@ -393,8 +411,15 @@ static Tensor mps_convolution_backward_input(IntArrayRef input_size,
} else if (isDepthwiseConv) {
MPSGraphDepthwiseConvolution3DOpDescriptor* depthWiseConv3dDescriptor_ =
[[MPSGraphDepthwiseConvolution3DOpDescriptor new] autorelease];
fill_depthwise_conv_desc(
depthWiseConv3dDescriptor_, stride[1], stride[0], dilation[1], dilation[0], padding[1], padding[0]);
fill_depthwise_conv_desc(depthWiseConv3dDescriptor_,
stride[1],
stride[0],
dilation[1],
dilation[0],
padding[1],
padding[0],
at::MemoryFormat::Contiguous,
groups);
MPSGraphTensor* weightTransposeTensor = [mpsGraph transposeTensor:weightTensor
dimension:-3
withDimension:-4
@ -429,18 +454,14 @@ static Tensor mps_convolution_backward_input(IntArrayRef input_size,
newCachedGraph->gradInputTensor_ = gradInputTensor;
});
auto gradOutputPlaceholder =
Placeholder(cachedGraph->gradOutputTensor_, grad_input_c ? grad_output_t.contiguous() : grad_output_t);
auto weightsPlaceholder = Placeholder(cachedGraph->weightTensor_, grad_input_c ? weight_t.contiguous() : weight_t);
auto outputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, grad_input_c ? *grad_input_c : grad_input_t);
auto gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output_t);
auto weightsPlaceholder = Placeholder(cachedGraph->weightTensor_, weight_t);
auto outputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, *grad_input);
auto feeds = dictionaryFromPlaceholders(gradOutputPlaceholder, weightsPlaceholder);
runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
}
if (grad_input_c) {
grad_input_t.copy_(*grad_input_c);
}
return grad_input_t;
return *grad_input;
}
static Tensor mps_convolution_backward_weights(IntArrayRef weight_size,
@ -453,11 +474,9 @@ static Tensor mps_convolution_backward_weights(IntArrayRef weight_size,
bool bias_defined) {
using namespace at::native::mps;
using namespace mps;
const bool is3DConv = input_t.dim() == 5;
bool is3DConv = input_t.dim() == 5;
TORCH_CHECK(isFloatingType(grad_output_t.scalar_type()), "Convolution is supported only for Floating types");
CheckedFrom c = "mps_convolution_backward_weights";
constexpr auto kChannelsLast = at::MemoryFormat::ChannelsLast;
bool is_channels_last = mps_conv_use_channels_last(input_t, grad_output_t) && !is3DConv;
// For uniformity with everything else, although it seems grad_weight
// would be unambiguous too.
@ -468,8 +487,7 @@ static Tensor mps_convolution_backward_weights(IntArrayRef weight_size,
checkAllSameGPU(c, {grad_output, input});
auto grad_weight_t =
at::empty(weight_size, grad_output_t.options(), is_channels_last ? std::optional(kChannelsLast) : std::nullopt);
at::empty(weight_size, grad_output_t.scalar_type(), std::nullopt, kMPS, std::nullopt, std::nullopt);
TensorArg grad_weight{grad_weight_t, "result", 0};
convolution_shape_check(c, input, grad_weight, grad_output, padding, stride, dilation, groups);
@ -482,23 +500,16 @@ static Tensor mps_convolution_backward_weights(IntArrayRef weight_size,
MPSGraphTensor* gradWeightTensor_ = nil;
};
// TODO: Remove me when MacOS-14 is no longer supported
std::optional<Tensor> grad_weight_c;
if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS) && is_channels_last) {
grad_weight_c = at::empty_like(grad_weight_t, grad_weight_t.options().memory_format(MemoryFormat::Contiguous));
}
@autoreleasepool {
MPSStream* stream = getCurrentMPSStream();
MPSShape* mps_weight_shape = getMPSShape(weight_size);
std::string key = fmt::format("mps_{}convolution_backward_weights:{}:{}:{}:{}:{}:{}",
std::string key = fmt::format("mps_{}convolution_backward_weights:{}:{}:{}:{}:{}",
is3DConv ? "3d_" : "",
getArrayRefString(stride),
getArrayRefString(dilation),
getArrayRefString(padding),
groups,
is_channels_last,
getTensorsStringKey({grad_output_t, input_t, grad_weight_t}));
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSShape* inputShape = getMPSShape(input_t);
@ -530,8 +541,15 @@ static Tensor mps_convolution_backward_weights(IntArrayRef weight_size,
} else if (isDepthwiseConv) {
MPSGraphDepthwiseConvolution3DOpDescriptor* depthWiseConv3dDescriptor_ =
[[MPSGraphDepthwiseConvolution3DOpDescriptor new] autorelease];
fill_depthwise_conv_desc(
depthWiseConv3dDescriptor_, stride[1], stride[0], dilation[1], dilation[0], padding[1], padding[0]);
fill_depthwise_conv_desc(depthWiseConv3dDescriptor_,
stride[1],
stride[0],
dilation[1],
dilation[0],
padding[1],
padding[0],
at::MemoryFormat::Contiguous,
groups);
NSNumber* outputFeatChannelDim = mps_weight_shape[0];
MPSShape* weightShapeTranspose = @[ @1, outputFeatChannelDim, mps_weight_shape[2], mps_weight_shape[3] ];
MPSGraphTensor* gradWeightTensorTranspose =
@ -565,19 +583,14 @@ static Tensor mps_convolution_backward_weights(IntArrayRef weight_size,
newCachedGraph->gradWeightTensor_ = gradWeightTensor;
});
auto gradOutputPlaceholder =
Placeholder(cachedGraph->gradOutputTensor_, grad_weight_c ? grad_output_t.contiguous() : grad_output_t);
auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, grad_weight_c ? input_t.contiguous() : input_t);
auto outputPlaceholder =
Placeholder(cachedGraph->gradWeightTensor_, grad_weight_c ? *grad_weight_c : grad_weight_t);
auto gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output_t);
auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input_t);
auto outputPlaceholder = Placeholder(cachedGraph->gradWeightTensor_, grad_weight_t);
auto feeds = dictionaryFromPlaceholders(gradOutputPlaceholder, inputPlaceholder);
runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
}
if (grad_weight_c) {
grad_weight_t.copy_(*grad_weight_c);
}
return grad_weight_t;
}

View File

@ -158,46 +158,12 @@ c10::intrusive_ptr<EmbeddingPackedParamsBase> PackedEmbeddingBagWeight::prepack(
return packed_ptr;
}
#ifdef USE_FBGEMM
namespace {
/// Number of columns in the rowwise min/max buffer passed to the quantization function(s)
constexpr int kRowwiseMinMaxNumCols = 2;
bool _validate_rowwise_min_max(
const at::Tensor& weight,
const std::optional<at::Tensor>& rowwise_min_max_opt) {
const auto is_valid_rowwise_min_max = rowwise_min_max_opt.has_value();
if (is_valid_rowwise_min_max) {
TORCH_CHECK(
(rowwise_min_max_opt->dim() == 2 &&
rowwise_min_max_opt->size(0) == weight.size(0) &&
rowwise_min_max_opt->size(1) == kRowwiseMinMaxNumCols),
"'rowwise_min_max' must be a 2D tensor with shape [num_rows(weight), 2].");
}
return is_valid_rowwise_min_max;
}
auto _get_rowwise_min_max_contig(
const std::optional<at::Tensor>& rowwise_min_max_opt) {
return rowwise_min_max_opt.has_value()
? rowwise_min_max_opt->expect_contiguous(rowwise_min_max_opt->suggest_memory_format())
: at::borrow_from_optional_tensor(rowwise_min_max_opt);
}
}
#endif // USE_FBGEMM
namespace at::native {
// Note - This is a temporary pack function for embedding bag which quantizes
// and packs the float weight tensor. In the next step it will be replaced by a
// quantize and pack function once we support FP scale and FP zero_point
//
// The optional rowwise_min_max argument is to support callers to pass in the min/max
// values of the weight tensor. If the rowwise_min_max is not provided, the min/max
// values will be computed from the weight tensor.
//
// Python example examining a packed 8bit zero_point and scale:
//
// >> x = torch.from_numpy(np.array([[[10, 20], [30, 40]],[[50, 60], [70, 80]]],
@ -255,10 +221,7 @@ namespace at::native {
//
// [[50. , 60.00000035],
// [70. , 80.00000035]]])
Tensor& qembeddingbag_byte_prepack_out(
Tensor& output,
const Tensor& weight,
const std::optional<Tensor>& rowwise_min_max_opt) {
Tensor& qembeddingbag_byte_prepack_out(Tensor& output, const Tensor& weight) {
// The "last" dimension of an N-Dimensioned batch of embedding bags is
// quantization channel. E.g. for a 2D embedding bag, this has
// [ row, col ] dimensions, for batched of embedding bags, dimensions might be
@ -293,16 +256,9 @@ Tensor& qembeddingbag_byte_prepack_out(
auto* output_data = output.data_ptr<uint8_t>();
#ifdef USE_FBGEMM
// Move these outside of the ifdef when we support non-FBGEMM flow.
const auto is_valid_rowwise_min_max = _validate_rowwise_min_max(weight, rowwise_min_max_opt);
const auto rowwise_min_max_contig = _get_rowwise_min_max_contig(rowwise_min_max_opt);
if (weight_contig->scalar_type() == at::ScalarType::Half) {
const auto weight_data =
static_cast<fbgemm::float16*>(weight_contig->data_ptr());
const auto rowwise_min_max_data = is_valid_rowwise_min_max
? static_cast<fbgemm::float16*>(rowwise_min_max_contig->data_ptr())
: nullptr;
at::parallel_for(
0, embedding_rows, 1, [&](int64_t start_idx, int64_t end_idx) {
fbgemm::FloatOrHalfToFused8BitRowwiseQuantizedSBFloat<
@ -310,21 +266,17 @@ Tensor& qembeddingbag_byte_prepack_out(
weight_data + start_idx * embedding_cols,
end_idx - start_idx,
embedding_cols,
output_data + start_idx * output_columns,
(is_valid_rowwise_min_max ? (rowwise_min_max_data + start_idx * kRowwiseMinMaxNumCols) : nullptr));
output_data + start_idx * output_columns);
});
} else {
const auto weight_data = weight_contig->data_ptr<float>();
const auto rowwise_min_max_data =
is_valid_rowwise_min_max ? rowwise_min_max_contig->data_ptr<float>() : nullptr;
at::parallel_for(
0, embedding_rows, 1, [&](int64_t start_idx, int64_t end_idx) {
fbgemm::FloatOrHalfToFused8BitRowwiseQuantizedSBFloat<float>(
weight_data + start_idx * embedding_cols,
end_idx - start_idx,
embedding_cols,
output_data + start_idx * output_columns,
(is_valid_rowwise_min_max ? (rowwise_min_max_data + start_idx * kRowwiseMinMaxNumCols) : nullptr));
output_data + start_idx * output_columns);
});
}
@ -374,22 +326,6 @@ Tensor qembeddingbag_byte_prepack(const Tensor& weight) {
return output;
}
static Tensor qembeddingbag_byte_prepack_with_rowwise_min_max(
const Tensor& weight,
const Tensor& rowwise_min_max) {
const auto weight_contig =
weight.expect_contiguous(weight.suggest_memory_format());
Tensor output = at::detail::empty_cpu(
{0},
at::kByte,
weight_contig->layout(),
weight_contig->device(),
std::nullopt,
std::nullopt);
qembeddingbag_byte_prepack_out(output, weight, rowwise_min_max);
return output;
}
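For orientation, the fused 8-bit rowwise layout produced here stores, per row, the quantized bytes followed by an FP32 scale and an FP32 minimum, which is where the `embedding_cols + 2 * sizeof(float)` sizing below comes from. A hedged, FBGEMM-free sketch of that per-row quantization (helper names and values are illustrative):
    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <cstring>
    #include <vector>

    // Quantize one float row to uint8 and append its FP32 scale and FP32 minimum.
    void quantize_row_8bit(const float* src, int64_t cols, uint8_t* dst) {
      const float row_min = *std::min_element(src, src + cols);
      const float row_max = *std::max_element(src, src + cols);
      const float range = row_max - row_min;
      const float scale = range == 0.f ? 1.f : range / 255.f;
      for (int64_t c = 0; c < cols; ++c) {
        dst[c] = static_cast<uint8_t>(std::lrintf((src[c] - row_min) / scale));
      }
      std::memcpy(dst + cols, &scale, sizeof(float));
      std::memcpy(dst + cols + sizeof(float), &row_min, sizeof(float));
    }

    int main() {
      const int64_t rows = 4, cols = 8;
      std::vector<float> weight(rows * cols, 0.5f);
      const int64_t out_cols = cols + 2 * sizeof(float);  // bytes per packed row
      std::vector<uint8_t> packed(rows * out_cols);
      for (int64_t r = 0; r < rows; ++r) {
        quantize_row_8bit(weight.data() + r * cols, cols, packed.data() + r * out_cols);
      }
      return 0;
    }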
Tensor qembeddingbag_byte_prepack_meta(const Tensor& weight) {
const auto weight_contig =
weight.expect_contiguous(weight.suggest_memory_format());
@ -399,7 +335,7 @@ Tensor qembeddingbag_byte_prepack_meta(const Tensor& weight) {
"'embedding_bag_byte_prepack' only support float32 or float16.");
const auto weight_sizes = weight.sym_sizes();
const auto cols_dim = weight.ndimension() - 1;
const auto& embedding_cols = weight_sizes[cols_dim];
const auto embedding_cols = weight_sizes[cols_dim];
// Add 8 bytes per column to store FP32 scale and zero_point per row.
const auto output_columns = embedding_cols + 2 * sizeof(float);
@ -423,8 +359,7 @@ Tensor _qembeddingbag_nbit_prepack_helper(
int bit_width,
const bool optimized_qparams,
const int64_t nbins,
const double ratio,
const std::optional<Tensor>& rowwise_min_max_opt = std::nullopt) {
const double ratio) {
TORCH_CHECK(
weight.scalar_type() == at::ScalarType::Float ||
weight.scalar_type() == at::ScalarType::Half,
@ -466,17 +401,10 @@ Tensor _qembeddingbag_nbit_prepack_helper(
auto* output_data = output.data_ptr<uint8_t>();
#ifdef USE_FBGEMM
// Move these outside of the ifdef when we support non-FBGEMM flow.
const auto is_valid_rowwise_min_max = _validate_rowwise_min_max(weight, rowwise_min_max_opt);
const auto rowwise_min_max_contig = _get_rowwise_min_max_contig(rowwise_min_max_opt);
if (!optimized_qparams) {
if (weight_contig.scalar_type() == at::ScalarType::Half) {
const auto weight_data =
static_cast<fbgemm::float16*>(weight_contig.data_ptr());
const auto rowwise_min_max_data = is_valid_rowwise_min_max
? static_cast<fbgemm::float16*>(rowwise_min_max_contig->data_ptr())
: nullptr;
at::parallel_for(
0, embedding_rows, 1, [&](int64_t start_idx, int64_t end_idx) {
fbgemm::FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf<
@ -485,13 +413,10 @@ Tensor _qembeddingbag_nbit_prepack_helper(
weight_data + start_idx * embedding_cols,
end_idx - start_idx,
static_cast<int>(embedding_cols),
output_data + start_idx * output_shape[1],
(is_valid_rowwise_min_max ? (rowwise_min_max_data + start_idx * kRowwiseMinMaxNumCols) : nullptr));
output_data + start_idx * output_shape[1]);
});
} else {
const auto weight_data = weight_contig.data_ptr<float>();
const auto rowwise_min_max_data =
is_valid_rowwise_min_max ? rowwise_min_max_contig->data_ptr<float>() : nullptr;
at::parallel_for(
0, embedding_rows, 1, [&](int64_t start_idx, int64_t end_idx) {
fbgemm::FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf<float>(
@ -499,8 +424,7 @@ Tensor _qembeddingbag_nbit_prepack_helper(
weight_data + start_idx * embedding_cols,
end_idx - start_idx,
static_cast<int>(embedding_cols),
output_data + start_idx * output_shape[1],
(is_valid_rowwise_min_max ? (rowwise_min_max_data + start_idx * kRowwiseMinMaxNumCols) : nullptr));
output_data + start_idx * output_shape[1]);
});
}
} else {
@ -590,16 +514,6 @@ Tensor qembeddingbag_4bit_prepack(
weight, 4 /*bit_width*/, optimized_qparams, nbins, ratio);
}
Tensor qembeddingbag_4bit_prepack_with_rowwise_min_max(
const Tensor& weight,
const Tensor& rowwise_min_max,
const bool optimized_qparams,
const int64_t nbins,
const double ratio) {
return _qembeddingbag_nbit_prepack_helper(
weight, 4 /*bit_width*/, optimized_qparams, nbins, ratio, rowwise_min_max);
}
// Applies 2-bit row-wise quantization by determining the range
// (maximum - minimum) and bias (minimum value) of each row in the input
// matrix, and then scaling each element to an 2-bit number between 0 and
@ -617,16 +531,6 @@ Tensor qembeddingbag_2bit_prepack(
weight, 2 /*bit_width*/, optimized_qparams, nbins, ratio);
}
Tensor qembeddingbag_2bit_prepack_with_rowwise_min_max(
const Tensor& weight,
const Tensor& rowwise_min_max,
const bool optimized_qparams,
const int64_t nbins,
const double ratio) {
return _qembeddingbag_nbit_prepack_helper(
weight, 2 /*bit_width*/, optimized_qparams, nbins, ratio, rowwise_min_max);
}
class QEmbeddingPackWeights final {
public:
static c10::intrusive_ptr<EmbeddingPackedParamsBase> run(const at::Tensor& weight) {
@ -638,21 +542,12 @@ TORCH_LIBRARY_IMPL(quantized, CPU, m) {
m.impl(
TORCH_SELECTIVE_NAME("quantized::embedding_bag_byte_prepack"),
TORCH_FN(qembeddingbag_byte_prepack));
m.impl(
TORCH_SELECTIVE_NAME("quantized::embedding_bag_byte_prepack_with_rowwise_min_max"),
TORCH_FN(qembeddingbag_byte_prepack_with_rowwise_min_max));
m.impl(
TORCH_SELECTIVE_NAME("quantized::embedding_bag_4bit_prepack"),
TORCH_FN(qembeddingbag_4bit_prepack));
m.impl(
TORCH_SELECTIVE_NAME("quantized::embedding_bag_4bit_prepack_with_rowwise_min_max"),
TORCH_FN(qembeddingbag_4bit_prepack_with_rowwise_min_max));
m.impl(
TORCH_SELECTIVE_NAME("quantized::embedding_bag_2bit_prepack"),
TORCH_FN(qembeddingbag_2bit_prepack));
m.impl(
TORCH_SELECTIVE_NAME("quantized::embedding_bag_2bit_prepack_with_rowwise_min_max"),
TORCH_FN(qembeddingbag_2bit_prepack_with_rowwise_min_max));
}
TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {

View File

@ -3,10 +3,7 @@
namespace at::native {
Tensor& qembeddingbag_byte_prepack_out(
Tensor& output,
const Tensor& weight,
const std::optional<Tensor>& rowwise_min_max_opt = std::nullopt);
Tensor& qembeddingbag_byte_prepack_out(Tensor& output, const Tensor& weight);
Tensor qembeddingbag_byte_prepack(const Tensor& weight);

View File

@ -121,12 +121,9 @@ TORCH_LIBRARY(quantized, m) {
m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_unpack(__torch__.torch.classes.quantized.EmbeddingPackedParamsBase W_prepack) -> Tensor W_origin"), {at::Tag::pt2_compliant_tag});
m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_byte_prepack(Tensor weight) -> Tensor"), {at::Tag::pt2_compliant_tag});
m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_byte_unpack(Tensor weight) -> Tensor"), {at::Tag::pt2_compliant_tag});
m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_byte_prepack_with_rowwise_min_max(Tensor weight, Tensor rowwise_min_max) -> Tensor"), {at::Tag::pt2_compliant_tag});
m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_4bit_prepack(Tensor weight, bool optimized_qparams=False, int nbins=200, float ratio=0.16) -> Tensor"), {at::Tag::pt2_compliant_tag});
m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_4bit_prepack_with_rowwise_min_max(Tensor weight, Tensor rowwise_min_max, bool optimized_qparams=False, int nbins=200, float ratio=0.16) -> Tensor"), {at::Tag::pt2_compliant_tag});
m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_4bit_unpack(Tensor weight) -> Tensor"), {at::Tag::pt2_compliant_tag});
m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_2bit_prepack(Tensor weight, bool optimized_qparams=False, int nbins=200, float ratio=0.16) -> Tensor"), {at::Tag::pt2_compliant_tag});
m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_2bit_prepack_with_rowwise_min_max(Tensor weight, Tensor rowwise_min_max, bool optimized_qparams=False, int nbins=200, float ratio=0.16) -> Tensor"), {at::Tag::pt2_compliant_tag});
m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_2bit_unpack(Tensor weight) -> Tensor"), {at::Tag::pt2_compliant_tag});
m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_byte_rowwise_offsets(Tensor weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor"), {at::Tag::pt2_compliant_tag});
m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_4bit_rowwise_offsets(Tensor weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor"), {at::Tag::pt2_compliant_tag});

View File

@ -120,7 +120,7 @@ at::Tensor _cslt_compress(const Tensor& sparse_input) {
// buffer (in bytes)
size_t orig_m = sparse_input.size(0);
size_t div = orig_m * sparse_input.itemsize();
size_t new_n = (compressed_size + div - 1) / div; // ceil(s,d) = (s+d-1)/d
size_t new_n = (compressed_size + div - 1) / div; // floor
auto compressed_tensor = sparse_input.new_empty({(int64_t)orig_m, (int64_t)new_n});
auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
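A one-line worked example of the `(s + d - 1) / d` rounding used for the buffer shape above, with illustrative values:
    #include <cassert>
    #include <cstddef>

    // For positive d, (s + d - 1) / d rounds s/d up to the next integer.
    constexpr std::size_t ceil_div(std::size_t s, std::size_t d) { return (s + d - 1) / d; }

    int main() {
      assert(ceil_div(10, 4) == 3);  // 2.5 rounds up to 3 columns
      assert(ceil_div(12, 4) == 3);  // exact division is unchanged
      return 0;
    }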
@ -155,7 +155,7 @@ std::tuple<at::Tensor, int64_t, int64_t, int64_t, int64_t> _cslt_sparse_mm_impl(
TORCH_CUDASPARSE_CHECK(cusparseLtInit(&handle));
handle_initialized = true;
}
// cuSPARSELt constructs
// cupsarselt constructs
cusparseLtMatmulDescriptor_t matmul;
cusparseLtMatmulPlan_t plan;
cusparseLtMatmulAlgSelection_t alg_sel;

View File

@ -176,28 +176,6 @@ bool check_head_dim_size_flash(sdp_params const& params, bool debug) {
}
return false;
}
if constexpr(caller_is_meff) {
bool is_half = (params.query.dtype() == at::kHalf) ||
(params.query.dtype() == at::kBFloat16);
const int64_t alignment = is_half ? 8 : 4;
if (!(query_size_last % alignment == 0 && query_size_last > 0 &&
value_size_last % alignment == 0 && value_size_last > 0)) {
if (debug) {
TORCH_WARN(
"Mem efficient attention requires last dimension of inputs to be divisible by ",
alignment,
". ",
"Got Query.size(-1): ",
query_size_last,
", Key.size(-1): ",
params.key.sym_size(-1),
", Value.size(-1): ",
params.value.sym_size(-1),
" instead.");
}
return false;
}
}
return true;
}
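A standalone sketch of the alignment rule spelled out in the warning above (fp16/bf16 inputs need the last dimension of Q and V to be a positive multiple of 8, fp32 a multiple of 4); the helper name and values are illustrative, not the SDPA API:
    #include <cstdint>
    #include <iostream>

    bool head_dim_aligned(int64_t query_last, int64_t value_last, bool is_half) {
      const int64_t alignment = is_half ? 8 : 4;
      return query_last > 0 && value_last > 0 &&
             query_last % alignment == 0 && value_last % alignment == 0;
    }

    int main() {
      std::cout << head_dim_aligned(64, 64, /*is_half=*/true) << "\n";   // 1
      std::cout << head_dim_aligned(72, 72, /*is_half=*/false) << "\n";  // 1
      std::cout << head_dim_aligned(70, 70, /*is_half=*/true) << "\n";   // 0
      return 0;
    }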

View File

@ -462,11 +462,10 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot
using sdp::aotriton_adapter::mk_aotensor;
using sdp::aotriton_adapter::mk_aoscalartensor;
using sdp::aotriton_adapter::mk_philoxtensor;
using sdp::aotriton_adapter::mk_atomictensor;
using sdp::aotriton_adapter::cast_dtype;
at::Tensor atomic_counter;
if (is_causal) {
atomic_counter = at::zeros({1}, q.options().dtype(at::kInt));
atomic_counter = at::zeros({1}, q.options());
}
aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype()));
auto seed = use_philox_state ? mk_philoxtensor(philox_state.seed_.ptr) : mk_aoscalartensor(seed_t);
@ -475,7 +474,7 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot
auto nullscalar = mk_philoxtensor(nullptr);
auto seed_output = use_philox_state ? mk_philoxtensor(seed_t.data_ptr<int64_t>()) : nullscalar;
auto offset_output = use_philox_state ? mk_philoxtensor(offset_t.data_ptr<int64_t>()) : nullscalar;
auto persistent_counter = mk_atomictensor(is_causal ? atomic_counter.data_ptr<int32_t>() : nullptr);
auto persistent_counter = is_causal ? mk_philoxtensor(atomic_counter.data_ptr<int64_t>()) : nullscalar;
if (uses_swa || AOTRITON_ALWAYS_V3_API) {
#if AOTRITON_V3_API
using aotriton::v3::flash::CausalType;

View File

@ -2,12 +2,22 @@
// ${generated_comment}
#include <ATen/FunctionalStorageImpl.h>
#include <ATen/Tensor.h>
namespace at {
namespace functionalization {
enum class InverseReturnMode {
/// Specifies that functional inverses should always return a view.
AlwaysView,
/// Specifies that functional inverses should always return a non-view / copy.
NeverView,
/// Specifies that functional inverses should return a view unless a (copying) scatter
/// inverse exists, in which case that will be used instead.
/// This avoids as_strided() calls that can be difficult for subclasses to handle.
ViewOrScatterInverse,
};
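A hedged sketch of the decision these modes encode; `has_scatter_inverse` and the returned strings are placeholders, not generated functionalization code:
    #include <iostream>

    enum class Mode { AlwaysView, NeverView, ViewOrScatterInverse };

    const char* pick_inverse(Mode mode, bool has_scatter_inverse) {
      switch (mode) {
        case Mode::AlwaysView:
          return "view inverse";
        case Mode::NeverView:
          return "copying (scatter) inverse";
        case Mode::ViewOrScatterInverse:
          // Prefer the scatter inverse when one exists, avoiding as_strided().
          return has_scatter_inverse ? "copying (scatter) inverse" : "view inverse";
      }
      return "view inverse";  // unreachable
    }

    int main() {
      std::cout << pick_inverse(Mode::ViewOrScatterInverse, /*has_scatter_inverse=*/true) << "\n";
      return 0;
    }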
struct FunctionalInverses {
${view_inverse_declarations}

View File

@ -4,7 +4,7 @@
#include <ATen/core/LegacyTypeDispatch.h>
#include <ATen/EmptyTensor.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/ViewMetaClasses.h>
#include <ATen/FunctionalInverses.h>
#include <ATen/MemoryOverlap.h>
#include <torch/library.h>

View File

@ -1,19 +0,0 @@
// ${generated_comment}
#include <ATen/FunctionalInverses.h>
#include <ATen/ViewMetaClasses.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Operators.h>
#include <ATen/NativeFunctions.h>
#else
${op_headers}
#endif
namespace at {
namespace functionalization {
${view_meta_implementations}
} // namespace functionalization
} // namespace at

View File

@ -1,12 +0,0 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
// ${generated_comment}
#include <ATen/FunctionalStorageImpl.h>
namespace at {
namespace functionalization {
${view_meta_declarations}
} // namespace functionalization
} // namespace at

View File

@ -1,11 +0,0 @@
#include <ATen/ViewMetaClasses.h>
#include <torch/csrc/functionalization/Module.h>
namespace torch::functionalization {
void initGenerated(PyObject* module) {
auto functionalization = py::handle(module).cast<py::module>();
$view_meta_bindings
}
} // namespace torch::functionalization

View File

@ -1561,38 +1561,6 @@ namespace {
<< "Failure Details:\nTest Seed to reproduce: " << seed;
}
}
#endif
#if defined(CPU_CAPABILITY_AVX512)
TYPED_TEST(Quantization8BitTests, TransposePackVNNI4) {
using VT = ValueType<TypeParam>;
constexpr auto K = 197;
constexpr auto N = 64;
constexpr auto L = K * N;
constexpr auto ld_src = N;
constexpr auto ld_dst = K * 4;
CACHE_ALIGN VT x[L];
CACHE_ALIGN VT y[L];
CACHE_ALIGN VT ref[L];
auto seed = TestSeed();
ValueGen<VT> generator(VT(-100), VT(100), seed);
for (const auto i : c10::irange(L)) {
x[i] = generator.get();
}
at::vec::transpose_pack_vnni4(x, y, ld_src, K, N);
int64_t _N = N / 4;
for (int64_t k = 0; k < K; k++) {
for(int64_t n = 0; n < _N; n++) {
for(int64_t l = 0; l < 4; l++) {
ref[n * ld_dst + k * 4 + l] =
c10::load(&(x[k * ld_src + n * 4 + l]));
}
}
}
for (const auto i : c10::irange(L)) {
ASSERT_EQ(y[i], ref[i])
<< "Failure Details:\nTest Seed to reproduce: " << seed;
}
}
#endif
TYPED_TEST(FunctionalTests, Map) {
using vec = TypeParam;

View File

@ -318,7 +318,7 @@ timm_vovnet,pass,0
torch_multimodal_clip,pass,0
torch_multimodal_clip,pass,3


View File

@ -391,8 +391,6 @@ def get_aten_generated_files(enabled_backends):
"CompositeExplicitAutogradFunctions_inl.h",
"CompositeExplicitAutogradNonFunctionalFunctions.h",
"CompositeExplicitAutogradNonFunctionalFunctions_inl.h",
"ViewMetaClasses.h",
"ViewMetaClasses.cpp",
"VmapGeneratedPlumbing.h",
"core/ATenOpList.cpp",
"core/TensorBody.h",
@ -1194,7 +1192,6 @@ def define_buck_targets(
"NativeMetaFunctions.h": ":gen_aten[NativeMetaFunctions.h]",
"Operators.h": ":gen_aten[Operators.h]",
"RedispatchFunctions.h": ":gen_aten[RedispatchFunctions.h]",
"ViewMetaClasses.h": ":gen_aten[ViewMetaClasses.h]",
"core/TensorBody.h": ":gen_aten[core/TensorBody.h]",
"core/aten_interned_strings.h": ":gen_aten[core/aten_interned_strings.h]",
"core/enum_tag.h": ":gen_aten[core/enum_tag.h]",

View File

@ -118,9 +118,6 @@ def define_targets(rules):
":LazyNonNativeIr.h",
":RegisterDispatchDefinitions.ini",
":RegisterDispatchKey.cpp",
":ViewMetaClassesPythonBinding.cpp",
":ViewMetaClasses.cpp",
":ViewMetaClasses.h",
":native_functions.yaml",
":shape_inference.h",
":tags.yaml",
@ -173,7 +170,6 @@ GENERATED_H = [
"FunctionalInverses.h",
"RedispatchFunctions.h",
"RegistrationDeclarations.h",
"ViewMetaClasses.h",
"VmapGeneratedPlumbing.h",
]
@ -250,7 +246,6 @@ GENERATED_CPP = [
"RegisterFunctionalization_1.cpp",
"RegisterFunctionalization_2.cpp",
"RegisterFunctionalization_3.cpp",
"ViewMetaClasses.cpp",
]
GENERATED_CPP_CORE = [
@ -312,7 +307,6 @@ _GENERATED_AUTOGRAD_PYTHON_CPP = [
"torch/csrc/autograd/generated/python_torch_functions_1.cpp",
"torch/csrc/autograd/generated/python_torch_functions_2.cpp",
"torch/csrc/autograd/generated/python_variable_methods.cpp",
"torch/csrc/functionalization/generated/ViewMetaClassesPythonBinding.cpp"
]
GENERATED_AUTOGRAD_PYTHON = _GENERATED_AUTOGRAD_PYTHON_HEADERS + _GENERATED_AUTOGRAD_PYTHON_CPP

View File

@ -1010,7 +1010,6 @@ libtorch_python_core_sources = [
"torch/csrc/utils/disable_torch_function.cpp",
"torch/csrc/utils/verbose.cpp",
"torch/csrc/cpu/Module.cpp",
"torch/csrc/functionalization/Module.cpp",
"torch/csrc/instruction_counter/Module.cpp",
"torch/nativert/python/Bindings.cpp",
] + lazy_tensor_core_python_sources
@ -1053,7 +1052,6 @@ def glob_libtorch_python_sources(gencode_pattern = ":generate-code[{}]"):
"torch/csrc/autograd/generated/python_torch_functions_1.cpp",
"torch/csrc/autograd/generated/python_torch_functions_2.cpp",
"torch/csrc/autograd/generated/python_variable_methods.cpp",
"torch/csrc/functionalization/generated/ViewMetaClassesPythonBinding.cpp",
]]
_libtorch_python_sources.extend(libtorch_python_core_sources)

View File

@ -3244,7 +3244,7 @@ class C10_TensorImpl_Size_Check_Dummy_Class : private TensorImpl {
are_equal<sizeof(autograd_meta_), 4, FieldNameEnum::autograd_meta_>();
are_equal<sizeof(extra_meta_), 4, FieldNameEnum::extra_meta_>();
are_equal<sizeof(version_counter_), 4, FieldNameEnum::version_counter_>();
are_equal<sizeof(pyobj_slot_), 4, FieldNameEnum::pyobj_slot_>();
are_equal<sizeof(pyobj_slot_), 8, FieldNameEnum::pyobj_slot_>();
is_le<sizeof(sizes_and_strides_), 88, FieldNameEnum::sizes_and_strides_>();
are_equal<sizeof(storage_offset_), 8, FieldNameEnum::storage_offset_>();
are_equal<sizeof(numel_), 8, FieldNameEnum::numel_>();
@ -3269,7 +3269,7 @@ class C10_TensorImpl_Size_Check_Dummy_Class : private TensorImpl {
is_le<sizeof(autograd_meta_), 16, FieldNameEnum::autograd_meta_>();
is_le<sizeof(extra_meta_), 16, FieldNameEnum::extra_meta_>();
are_equal<sizeof(version_counter_), 8, FieldNameEnum::version_counter_>();
are_equal<sizeof(pyobj_slot_), 8, FieldNameEnum::pyobj_slot_>();
are_equal<sizeof(pyobj_slot_), 16, FieldNameEnum::pyobj_slot_>();
are_equal<sizeof(sizes_and_strides_), 88, FieldNameEnum::sizes_and_strides_>();
are_equal<sizeof(storage_offset_), 8, FieldNameEnum::storage_offset_>();
are_equal<sizeof(numel_), 8, FieldNameEnum::numel_>();

View File

@ -0,0 +1,21 @@
#include <c10/core/impl/HermeticPyObjectTLS.h>
namespace c10::impl {
thread_local static std::atomic<bool> hermeticPyObjectState{false};
std::atomic<bool> HermeticPyObjectTLS::haveState_{false};
void HermeticPyObjectTLS::set_state(bool state) {
hermeticPyObjectState = state;
}
bool HermeticPyObjectTLS::get_tls_state() {
return hermeticPyObjectState;
}
void HermeticPyObjectTLS::init_state() {
haveState_ = true;
}
} // namespace c10::impl

View File

@ -0,0 +1,62 @@
#pragma once
#include <c10/macros/Export.h>
#include <atomic>
namespace c10::impl {
// This TLS controls whether or not we permanently associate PyObject
// with Tensor the first time it is allocated. When hermetic PyObject
// TLS is enabled (state is true), we DO NOT save PyObjects to Tensor,
// meaning you get a distinct PyObject whenever you execute the code in
// question.
struct C10_API HermeticPyObjectTLS {
static void set_state(bool state);
static bool get_state() {
// Hypothetical fastpath if torchdeploy/multipy // codespell:ignore multipy
// isn't used. Per
// https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p2055r0.pdf
// this qualifies relaxed access because it is a single-location data
// structure (only the boolean here).
//
// Forgetting about data races for a moment, is there a logical race?
//
// - Boolean only ever transitions from false to true. So the
// critical situation is when one interpreter is already running
// when a second interpreter switches haveState from false to true.
//
// - The first interpreter is indifferent whether or not it sees
// hasState true/false; obviously false works (this is what the
// interpreter was previously using; more directly, the interpreter
// calls into itself as the handler, so being hermetic is not
// required), and true simply means serviced python operator calls will
// be hermetic; in these cases it is expected to be functionally
// equivalent.
//
// - The second interpreter MUST see hasState true (as its requests will
// be forwarded to the first interpreter), but it is assumed that there
// is a synchronization between the interpreter initialization, and
// when we actually perform operations, so it is guaranteed to see
// hasState true.
//
// QED.
//
// This fastpath is currently disabled so that we can more easily test that
// hermetic mode works correctly even on stock build of PyTorch.
if (false && !haveState_.load(std::memory_order_relaxed))
return false;
return get_tls_state();
}
// Call this from the multipy/torchdeploy // codespell:ignore multipy
// top level
static void init_state();
private:
// This only flipped once from false to true during
// torchdeploy/multipy initialization, // codespell:ignore multipy
// and never again.
static std::atomic<bool> haveState_;
static bool get_tls_state();
};
} // namespace c10::impl
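A minimal standalone sketch of the pattern documented above, assuming only the standard library: a write-once global gate read with relaxed ordering in front of a thread-local flag. Names are illustrative, and unlike the snippet above the fast path here is left enabled.
    #include <atomic>
    #include <iostream>

    // Write-once gate: flips false -> true at most once and never goes back.
    std::atomic<bool> g_have_state{false};
    // Per-thread state, consulted only once the gate is open.
    thread_local bool t_state = false;

    bool get_state() {
      // Relaxed is enough for a single monotonic boolean: a reader either sees
      // the old value (cheap early-out) or the final value.
      if (!g_have_state.load(std::memory_order_relaxed))
        return false;
      return t_state;
    }

    int main() {
      std::cout << get_state() << "\n";  // 0: gate closed, fast path taken
      g_have_state.store(true, std::memory_order_relaxed);
      t_state = true;
      std::cout << get_state() << "\n";  // 1: gate open, thread-local value used
      return 0;
    }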

View File

@ -13,10 +13,11 @@ struct C10_API PyInterpreterHooksInterface {
// Get the PyInterpreter instance
// Stub implementation throws error when Python is not available
// We return nullptr rather than throwing an error since there are bits of c10
// that expect an empty PyObjectSlot when python is not available.
virtual PyInterpreter* getPyInterpreter() const {
return nullptr;
TORCH_CHECK(
false,
"PyTorch was compiled without Python support. "
"Cannot access Python interpreter from C++.");
}
};

View File

@ -2,7 +2,7 @@
namespace c10::impl {
PyObjectSlot::PyObjectSlot() : pyobj_(nullptr) {}
PyObjectSlot::PyObjectSlot() : pyobj_interpreter_(nullptr), pyobj_(nullptr) {}
PyObjectSlot::~PyObjectSlot() {
maybe_destroy_pyobj();
@ -10,9 +10,9 @@ PyObjectSlot::~PyObjectSlot() {
void PyObjectSlot::maybe_destroy_pyobj() {
if (owns_pyobj()) {
TORCH_INTERNAL_ASSERT(getGlobalPyInterpreter() != nullptr);
TORCH_INTERNAL_ASSERT(pyobj_interpreter_ != nullptr);
TORCH_INTERNAL_ASSERT(pyobj_ != nullptr);
(*getGlobalPyInterpreter())
(*pyobj_interpreter_.load(std::memory_order_acquire))
->decref(_unchecked_untagged_pyobj(), /*has_pyobj_slot*/ true);
// NB: this destructor can only be entered when there are no
// references to this C++ object (obviously), NOR any references
@ -25,7 +25,7 @@ void PyObjectSlot::maybe_destroy_pyobj() {
}
PyInterpreter* PyObjectSlot::pyobj_interpreter() {
return getGlobalPyInterpreter();
return pyobj_interpreter_.load(std::memory_order_acquire);
}
PyObject* PyObjectSlot::_unchecked_untagged_pyobj() const {
@ -35,7 +35,7 @@ PyObject* PyObjectSlot::_unchecked_untagged_pyobj() const {
}
PyInterpreter& PyObjectSlot::load_pyobj_interpreter() const {
auto interpreter = getGlobalPyInterpreter();
auto interpreter = pyobj_interpreter_.load(std::memory_order_acquire);
if (interpreter) {
return *interpreter;
}

View File

@ -1,21 +1,15 @@
#pragma once
#include <c10/core/impl/HermeticPyObjectTLS.h>
#include <c10/core/impl/PyInterpreter.h>
#include <c10/core/impl/PyInterpreterHooks.h>
#include <c10/util/python_stub.h>
#include <optional>
#include <atomic>
namespace c10::impl {
// Function pointer type for getting the global interpreter
using GetPyInterpreterFn = PyInterpreter* (*)();
// Global function pointer (set by csrc initialization)
C10_API extern GetPyInterpreterFn g_get_pyinterpreter_fn;
// Helper function to get the global interpreter
C10_API PyInterpreter* getGlobalPyInterpreter();
struct C10_API PyObjectSlot {
public:
PyObjectSlot();
@ -32,6 +26,8 @@ struct C10_API PyObjectSlot {
// NB: THIS FUNCTION CAN RAISE AN EXCEPTION. Make sure to clean up after
// PyObject if necessary!
void init_pyobj(PyObject* pyobj) {
pyobj_interpreter_.store(
getGlobalPyInterpreter(), std::memory_order_relaxed);
pyobj_ = pyobj;
}
@ -41,16 +37,36 @@ struct C10_API PyObjectSlot {
PyObject* _unchecked_untagged_pyobj() const;
// Test the interpreter / PyObj as they may be null
// Test the interpreter tag. If tagged for the current interpreter, return
// a non-nullopt (but possibly null) PyObject. If (possibly) untagged,
// returns a nullopt. If it is definitely invalid, raises an error.
//
// If `ignore_hermetic_tls` is false and this function is called from a
// hermetic context (ie, `HermeticPyObjectTLS::get_state()` is true), then
// nullopt is returned. If `ignore_hermetic_tls` is true, then the hermetic
// context is ignored, allowing you to check the interpreter tag of a
// nonhermetic PyObject from within a hermetic context. This is necessary
// because there are some cases where the deallocator function of a
// nonhermetic PyObject is called from within a hermetic context, so it must
// be properly treated as a nonhermetic PyObject.
//
// NB: this lives in header so that we can avoid actually creating the
// std::optional
std::optional<PyObject*> check_pyobj() const {
impl::PyInterpreter* interpreter = getGlobalPyInterpreter();
if (interpreter == nullptr || pyobj_ == nullptr) {
// @todo alban: I'm not too sure what's going on here, we can probably delete
// it but it's worthwhile making sure
std::optional<PyObject*> check_pyobj(bool ignore_hermetic_tls = false) const {
impl::PyInterpreter* interpreter =
pyobj_interpreter_.load(std::memory_order_acquire);
if (interpreter == nullptr) {
return std::nullopt;
}
return _unchecked_untagged_pyobj();
if (!ignore_hermetic_tls && c10::impl::HermeticPyObjectTLS::get_state()) {
return std::nullopt;
} else {
return _unchecked_untagged_pyobj();
}
}
PyInterpreter& load_pyobj_interpreter() const;
@ -60,6 +76,30 @@ struct C10_API PyObjectSlot {
void set_owns_pyobj(bool b);
private:
// This field contains the interpreter tag for this object. See
// Note [Python interpreter tag] for general context
//
// Note [Memory ordering on Python interpreter tag]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// What memory_order do we need when accessing this atomic? We don't
// need a single total modification order (as provided by
// memory_order_seq_cst) as pyobj_interpreter_ is monotonic: it can only
// transition from -1 to some positive integer and never changes afterwards.
// Because there is only one modification, it trivially already has a total
// modification order (e.g., we don't need fences or locked instructions on
// x86)
//
// In fact, one could make a reasonable argument that relaxed reads are OK,
// due to the presence of external locking (GIL) to ensure that interactions
// with other data structures are still correctly synchronized, so that
// we fall in the "Single-Location Data Structures" case as described in
// http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p2055r0.pdf
// However, on x86, it doesn't matter if I use acquire or relaxed on the load
// as I get the same assembly in both cases. So I just use the more
// conservative acquire (which will impede compiler optimizations but I don't
// care)
std::atomic<PyInterpreter*> pyobj_interpreter_;
// This field contains a reference to a PyObject representing this Tensor.
// If pyobj is nullptr, when we transfer Tensor to Python, we allocate a new
// PyObject for it and set this field. This field does not have to be
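The memory-ordering note above reduces to a write-once tag: set at most once, then only read, so an acquire load is conservative and a relaxed load is arguably enough under the external GIL. A small standalone sketch of that pattern, with a placeholder Interpreter type:
    #include <atomic>
    #include <cassert>

    struct Interpreter { int id; };

    // Write-once tag: starts null, set at most once, never changes afterwards.
    std::atomic<Interpreter*> g_tag{nullptr};

    void tag_once(Interpreter* interp) {
      Interpreter* expected = nullptr;
      // Only the first caller wins; the release pairs with the acquire load below.
      g_tag.compare_exchange_strong(expected, interp,
                                    std::memory_order_release,
                                    std::memory_order_relaxed);
    }

    Interpreter* read_tag() {
      return g_tag.load(std::memory_order_acquire);
    }

    int main() {
      static Interpreter main_interp{0};
      assert(read_tag() == nullptr);
      tag_once(&main_interp);
      assert(read_tag() == &main_interp);
      return 0;
    }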

View File

@ -316,7 +316,6 @@ set(GENERATED_CXX_PYTHON
"${TORCH_SRC_DIR}/csrc/autograd/generated/python_special_functions.cpp"
"${TORCH_SRC_DIR}/csrc/autograd/generated/python_return_types.cpp"
"${TORCH_SRC_DIR}/csrc/autograd/generated/python_enum_tag.cpp"
"${TORCH_SRC_DIR}/csrc/functionalization/generated/ViewMetaClassesPythonBinding.cpp"
)
set(GENERATED_H_PYTHON
@ -380,9 +379,6 @@ add_custom_command(
"${TORCH_ROOT}/aten/src/ATen/templates/LazyIr.h"
"${TORCH_ROOT}/aten/src/ATen/templates/LazyNonNativeIr.h"
"${TORCH_ROOT}/aten/src/ATen/templates/RegisterDispatchKey.cpp"
"${TORCH_ROOT}/aten/src/ATen/templates/ViewMetaClasses.h"
"${TORCH_ROOT}/aten/src/ATen/templates/ViewMetaClasses.cpp"
"${TORCH_ROOT}/aten/src/ATen/templates/ViewMetaClassesPythonBinding.cpp"
${autograd_python}
${autograd_yaml}
${autograd_templates}

View File

@ -38,7 +38,7 @@ def unroll(num_unrolls, IndexType, InType, OutType):
code = []
if num_unrolls == 1:
code.append(" // tail loop")
code.append(f" // tail loop")
code.append(" if (j < end_offset) {")
else:
code.append(f" // unrolling {num_unrolls} times")

View File

@ -153,6 +153,7 @@ _ZN3c104impl12PyObjectSlot10owns_pyobjEv
_ZN3c104impl12PyObjectSlot19maybe_destroy_pyobjEv
_ZN3c104impl12PyObjectSlotC1Ev
_ZN3c104impl12PyObjectSlotD2Ev
_ZN3c104impl19HermeticPyObjectTLS13get_tls_stateEv
_ZN3c104impl20TorchDispatchModeTLS13any_modes_setEb
_ZN3c104impl23ExcludeDispatchKeyGuardC1ENS_14DispatchKeySetE
_ZN3c104impl23ExcludeDispatchKeyGuardD2Ev

View File

@ -40,34 +40,7 @@ extensions = [
"sphinx.ext.intersphinx",
] + (["breathe", "exhale"] if run_doxygen else [])
intersphinx_mapping = {"pytorch": ("https://docs.pytorch.org/docs/main", None)}
# Configure Sphinx warnings and error handling
suppress_warnings = [
"ref.citation",
"ref.footnote",
"ref.doc",
"toc.excluded",
"toc.not_readable",
"misc.highlighting_failure",
]
# Configure Breathe
breathe_show_define_initializer = True
breathe_show_enumvalue_initializer = True
breathe_default_members = ("members", "undoc-members")
# Fix for Python 3.10+ compatibility with exhale 2.3.0
# MutableMapping was moved from collections to collections.abc in Python 3.10
try:
import collections
from collections.abc import MutableMapping
if not hasattr(collections, "MutableMapping"):
collections.MutableMapping = MutableMapping
except ImportError:
pass
intersphinx_mapping = {"pytorch": ("https://pytorch.org/docs/main", None)}
# Setup absolute paths for communicating with breathe / exhale where
# items are expected / should be trimmed by.
@ -128,21 +101,6 @@ exhale_args = {
Welcome to the developer reference for the PyTorch C++ API.
"""
),
############################################################################
# Duplicate handling and error management. #
############################################################################
# Note: Using Doxyfile instead of stdin configuration
# "exhaleDoxygenStdin" is not compatible with "exhaleUseDoxyfile"
# Handle unresolved references more gracefully
"unabridgedOrphanKinds": {
"function",
"define",
"enum",
"enumvalue",
"typedef",
"variable",
},
"fullToctreeMaxDepth": 2,
}
# Tell sphinx what the primary language being documented is.

View File

@ -1093,9 +1093,6 @@ The set of leaf modules can be customized by overriding
```{eval-rst}
.. autofunction:: torch.fx.replace_pattern
```
```{eval-rst}
.. autofunction:: torch.fx.traceback.annotate
```
<!-- The experimental and passes submodules are missing docs. -->
<!-- Adding it here for coverage but this doesn't add anything to the -->

View File

@ -156,7 +156,6 @@ def get_generate_code_bin_outs():
"autograd/generated/python_torch_functions_1.cpp": ["autograd/generated/python_torch_functions_1.cpp"],
"autograd/generated/python_torch_functions_2.cpp": ["autograd/generated/python_torch_functions_2.cpp"],
"autograd/generated/python_variable_methods.cpp": ["autograd/generated/python_variable_methods.cpp"],
"functionalization/generated/ViewMetaClassesPythonBinding.cpp": ["functionalization/generated/ViewMetaClassesPythonBinding.cpp"],
})
return outs

View File

@ -182,6 +182,7 @@ ignore = [
"SIM117",
"SIM118",
"UP007", # keep-runtime-typing
"UP038", # Was removed from newer versions, results in slower code
"UP045", # keep-runtime-typing
"TC006",
# TODO: Remove Python-3.10 specific suppressions

View File

@ -1704,18 +1704,7 @@ def main() -> None:
package_data = {
"torch": torch_package_data,
}
# some win libraries are excluded
# these are statically linked
exclude_windows_libs = [
"lib/dnnl.lib",
"lib/kineto.lib",
"lib/libprotobuf-lite.lib",
"lib/libprotobuf.lib",
"lib/libprotoc.lib",
]
exclude_package_data = {
"torch": exclude_windows_libs,
}
exclude_package_data = {}
if not BUILD_LIBTORCH_WHL:
package_data["torchgen"] = torchgen_package_data

View File

@ -1,7 +1,9 @@
if(WIN32)
set(TORCH_PYTHON_IMPORTED_LOCATION "${PYTORCH_INSTALL_DIR}/lib/${CMAKE_IMPORT_LIBRARY_PREFIX}torch_python${CMAKE_IMPORT_LIBRARY_SUFFIX}")
set(TORCH_PYTHON_IMPORTED_LOCATION "${PYTORCH_INSTALL_DIR}/lib/torch_python.lib")
elseif(APPLE)
set(TORCH_PYTHON_IMPORTED_LOCATION "${PYTORCH_INSTALL_DIR}/lib/libtorch_python.dylib")
else()
set(TORCH_PYTHON_IMPORTED_LOCATION "${PYTORCH_INSTALL_DIR}/lib/${CMAKE_SHARED_LIBRARY_PREFIX}torch_python${CMAKE_SHARED_LIBRARY_SUFFIX}")
set(TORCH_PYTHON_IMPORTED_LOCATION "${PYTORCH_INSTALL_DIR}/lib/libtorch_python.so")
endif()
add_library(torch_python SHARED IMPORTED)

View File

@ -11,12 +11,7 @@ from typing import Union
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable import checkpoint
from torch.distributed._composable.replicate_with_fsdp import replicate
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
_CHECKPOINT_PREFIX,
apply_activation_checkpointing,
)
from torch.distributed.fsdp import CPUOffloadPolicy, FSDPModule, OffloadPolicy
from torch.distributed.tensor import DTensor, init_device_mesh
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
@ -657,190 +652,5 @@ class TestReplicate1DTrainingCore(FSDPTest):
self.assertEqual(ref_loss, loss)
class TestReplicateTrainingCompose(FSDPTest):
@property
def world_size(self) -> int:
# Since these tests run with a larger transformer model, they may see
# some numeric drift with >2 GPUs
return min(torch.get_device_module(device_type).device_count(), 2)
@skip_if_lt_x_gpu(2)
@compiled_fsdp_test(compile_compute_on_module=Transformer)
def test_train_parity_with_activation_checkpointing(self):
"""
Tests train parity against DDP when composing with activation
checkpointing.
"""
self.run_subtests(
{
"reshard_after_forward": [True, False],
"checkpoint_impl": ["composable", "utils", "wrapper"],
"module_grouping": ["block", "mem_eff", "mem_eff_weight_tied"],
"test_device_type": [device_type.type],
},
self._test_train_parity_with_activation_checkpointing,
)
def _test_train_parity_with_activation_checkpointing(
self,
reshard_after_forward: Union[bool, int],
checkpoint_impl: str,
module_grouping: str,
test_device_type: str,
):
assert checkpoint_impl in ("composable", "utils", "wrapper")
testing_compile = replicate != torch.distributed._composable.replicate_with_fsdp
if testing_compile and checkpoint_impl == "composable":
return
torch.manual_seed(42)
vocab_size = 1024
with torch.device(device_type):
model_args = ModelArgs(
n_layers=3,
n_heads=4,
vocab_size=vocab_size,
max_seq_len=64,
dropout_p=0,
checkpoint_activations=(checkpoint_impl == "utils"),
# For the mem-efficient module grouping, we separate the
# embeddings from the output projection, which does not support
# weight tying
weight_tying=module_grouping != "mem_eff",
)
model = Transformer(model_args)
ref_model = copy.deepcopy(model).to(device_type)
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
# Apply activation checkpointing
prefixes_to_ignore = ()
if checkpoint_impl == "wrapper":
prefixes_to_ignore = (_CHECKPOINT_PREFIX,)
apply_activation_checkpointing(
model, check_fn=lambda m: isinstance(m, TransformerBlock)
)
elif checkpoint_impl == "composable":
for module in model.modules():
if isinstance(module, TransformerBlock):
checkpoint(module)
# Apply Replicate
device_mesh = init_device_mesh(
test_device_type,
(self.world_size, 1),
mesh_dim_names=("replicate", "shard"),
)
fsdp_kwargs = {
"reshard_after_forward": reshard_after_forward,
"device_mesh": device_mesh,
}
if module_grouping == "mem_eff":
assert model_args.n_layers == 3
replicate(model.layers[0], **fsdp_kwargs)
replicate([model.layers[1], model.layers[2]], **fsdp_kwargs)
replicate([model.tok_embeddings, model.pos_embeddings], **fsdp_kwargs)
# Embedding weights are not needed for embedding backward
model.tok_embeddings.set_unshard_in_backward(False)
replicate([model.norm, model.output], **fsdp_kwargs)
elif module_grouping == "mem_eff_weight_tied":
replicate([model.tok_embeddings, model.output], **fsdp_kwargs)
for layer in model.layers:
replicate(layer, **fsdp_kwargs)
elif module_grouping == "block":
for layer in model.layers:
replicate(layer, **fsdp_kwargs)
else:
raise NotImplementedError(f"Unknown module grouping: {module_grouping}")
replicate(model, **fsdp_kwargs)
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
torch.manual_seed(42 + self.rank)
# Reuse the same input across iterations to avoid loss explosion from
# trying to learn from random inputs
inp = torch.randint(0, vocab_size, (3, 64), device=device_type.type)
check_sharded_parity(
self, ref_model, model, prefixes_to_ignore=prefixes_to_ignore
)
for iter_idx in range(10):
losses: list[torch.Tensor] = []
for _model in (ref_model, model):
torch.manual_seed(iter_idx + 1) # for dropout determinism
losses.append(_model(inp).sum())
losses[-1].backward()
for param in ref_model.parameters():
if param.grad is not None:
dist.all_reduce(param.grad)
param.grad.div_(self.world_size)
if not testing_compile:
check_sharded_parity(
self, ref_model, model, prefixes_to_ignore=prefixes_to_ignore
)
self.assertEqual(losses[0], losses[1])
for _optim in (ref_optim, optim):
_optim.step()
_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
if not testing_compile:
check_sharded_parity(
self, ref_model, model, prefixes_to_ignore=prefixes_to_ignore
)
class TestReplicateSharedParams(FSDPTest):
@property
def world_size(self) -> int:
return min(4, torch.get_device_module(device_type).device_count())
@skip_if_lt_x_gpu(2)
def test_train_parity_with_shared_params(self):
self.run_subtests(
{
"reshard_after_forward": [False, True],
"use_activation_checkpointing": [False, True],
},
self._test_train_shared_params,
)
def _test_train_shared_params(
self,
reshard_after_forward: bool,
use_activation_checkpointing: bool,
):
torch.manual_seed(42)
model_args = ModelArgs(n_layers=3, dropout_p=0.0, weight_tying=True)
model = Transformer(model_args)
ref_model = copy.deepcopy(model).to(device_type)
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
for module in model.modules():
if isinstance(module, TransformerBlock):
if use_activation_checkpointing:
checkpoint(module)
replicate(module, reshard_after_forward=reshard_after_forward)
replicate(model, reshard_after_forward=reshard_after_forward)
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
torch.manual_seed(42 + self.rank + 1)
for iter_idx in range(10):
inp = torch.randint(
0, model_args.vocab_size, (2, 16), device=device_type.type
)
losses: list[torch.Tensor] = []
for _model in (ref_model, model):
losses.append(_model(inp).sum())
losses[-1].backward()
for param in ref_model.parameters():
if param.grad is not None:
dist.all_reduce(param.grad)
param.grad.div_(self.world_size)
for _optim in (ref_optim, optim):
_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
_optim.step()
self.assertEqual(losses[0], losses[1])
if __name__ == "__main__":
run_tests()
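The parity tests above build their baseline by hand: every rank runs a full-model backward on `ref_model`, then gradients are all-reduced and divided by the world size before the optimizer step, so the reference matches averaged data parallelism. A minimal sketch of that baseline pattern, assuming an already-initialized process group (illustration only, not part of the diff):

import torch
import torch.distributed as dist

def ddp_reference_step(model: torch.nn.Module, optim: torch.optim.Optimizer,
                       loss: torch.Tensor, world_size: int) -> None:
    # Mirror the manual averaging used for ref_model above: sum gradients
    # across ranks, then divide so the result matches an averaged all-reduce.
    loss.backward()
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad)
            param.grad.div_(world_size)
    optim.step()
    optim.zero_grad()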

View File

@ -1,4 +1,4 @@
# Owner(s): ["oncall: distributed"]
# Owner(s): ["module: unknown"]
import unittest
import torch

View File

@ -1,4 +1,4 @@
# Owner(s): ["module: fsdp"]
# Owner(s): ["module: unknown"]
import functools
import gc
from typing import Union

View File

@ -1,4 +1,4 @@
# Owner(s): ["oncall: distributed"]
# Owner(s): ["module: unknown"]
import gc
import unittest

View File

@ -1,4 +1,4 @@
# Owner(s): ["oncall: distributed"]
# Owner(s): ["module: unknown"]
from copy import copy

View File

@ -1,4 +1,4 @@
# Owner(s): ["oncall: distributed"]
# Owner(s): ["module: unknown"]
import unittest
from dataclasses import dataclass
from typing import Any, Callable, cast, Union

View File

@ -1,4 +1,4 @@
# Owner(s): ["oncall: distributed"]
# Owner(s): ["module: unknown"]
import unittest
import torch

View File

@ -1,4 +1,4 @@
# Owner(s): ["oncall: distributed"]
# Owner(s): ["module: unknown"]
import copy
import unittest

View File

@ -143,19 +143,6 @@ class FlightRecorderEventTest(TestCase):
match_one_event(e11, e12, membership, "0").state,
MatchState.FULLY_MATCHED,
)
e13 = create_one_event(
"gather",
("0", "default"),
[[4, 4]],
[[4, 4]],
"completed",
1,
output_dtypes="",
)
self.assertEqual(
match_one_event(e11, e13, membership, "0").state,
MatchState.FULLY_MATCHED,
)
def test_all_events(self):
for collective in sorted(COLLECTIVES):

View File

@ -202,62 +202,6 @@ class ScheduleTest(TestCase):
torch.distributed.destroy_process_group()
@parametrize(
"ScheduleClass",
[
Schedule1F1B,
ScheduleGPipe,
ScheduleInterleaved1F1B,
ScheduleInterleavedZeroBubble,
ScheduleLoopedBFS,
],
)
def test_schedule_eval_then_train(self, ScheduleClass):
"""
Test that simply runs evaluation followed by training.
"""
store = FakeStore()
torch.distributed.init_process_group(
backend="fake", rank=0, world_size=1, store=store
)
d_hid, batch_size = 512, 256
n_stages = 1
device = "cpu"
full_mod = MultiMLP(d_hid, n_layers=n_stages)
full_mod.to(device)
x = torch.randn(batch_size, d_hid, device=device)
target = torch.randn(batch_size, d_hid, device=device)
def loss_fn(y, target):
return torch.nn.functional.cross_entropy(y, target)
submod_name = "layers.0"
stage_module = full_mod.get_submodule(submod_name)
# Create a pipeline stage to wrap that submodule
num_microbatches = 2
stages = [PipelineStage(stage_module, 0, n_stages, device)]
if issubclass(ScheduleClass, PipelineScheduleSingle):
stages = stages[0]
# Attach to a schedule
schedule = ScheduleClass(stages, num_microbatches, loss_fn=loss_fn)
# Run eval
for _ in range(2):
# Zero gradients
stage_module.zero_grad()
losses = []
schedule.eval(x, target=target, losses=losses)
# Run training
try:
for _ in range(2):
losses = []
schedule.step(x, target=target, losses=losses)
finally:
torch.distributed.destroy_process_group()
def test_zero_bubble_schedule_errors_with_compile(self):
"""
Test that zero bubble schedules raise an error when used with torch.compile.

View File

@ -248,16 +248,6 @@ class TestDTensorDebugMode(TestCase):
"redistribute_input(1, [S(0)] -> [R])" in debug_mode.debug_string()
)
def test_debug_mode_higher_order_cond(self):
"""Test DebugMode with higher order operation."""
x = torch.randn(1, 8, requires_grad=True)
with DebugMode(record_torchfunction=True) as debug_mode:
torch.cond(torch.tensor(True), lambda x: x + 1, lambda x: x - 1, [x])
# Verify that cond operations are captured in debug mode
self.assertIn("torch.ops.higher_order.cond", debug_mode.debug_string())
instantiate_parametrized_tests(TestDTensorDebugMode)

View File

@ -352,7 +352,7 @@ class MicroPipelineTPTest(TestCase):
@parametrize("scatter_dim", [0, 1, 2])
@fresh_cache()
def test_fuse_scaled_matmul_reduce_scatter(self, A_dims, scatter_dim):
if scatter_dim >= A_dims - 1:
if scatter_dim >= A_dims:
return
group = dist.group.WORLD
@ -402,7 +402,7 @@ class MicroPipelineTPTest(TestCase):
@runOnRocmArch(MI300_ARCH)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@parametrize("scatter_dim", [0, 1])
@parametrize("scatter_dim", [0, 1, 2])
@fresh_cache()
def test_fuse_scaled_matmul_reduce_scatter_rowwise_scales_reshape_mm_reshape(
self, scatter_dim

View File

@ -880,34 +880,6 @@ class DistMathOpsTest(DTensorTestBase):
out_full = out_dt.full_tensor()
self.assertEqual(global_bins, out_full)
@with_comms
def test_logsumexp(self):
mesh = self.build_device_mesh()
comm_mode = CommDebugMode()
inp = torch.rand(3, 5, device=self.device_type)
shard_dim = 0
input_dtensor = distribute_tensor(
inp, device_mesh=mesh, placements=[Shard(shard_dim)]
)
logsumexp_dims = [0, 1]
for dim in logsumexp_dims:
output = torch.logsumexp(inp, dim=dim)
with comm_mode:
output_dtensor = torch.logsumexp(input_dtensor, dim=dim)
if dim == shard_dim:
self.assertEqual(comm_mode.get_total_counts(), 1)
self.assertEqual(
comm_mode.get_comm_counts()[funcol.all_gather_into_tensor],
1,
)
self.assertTrue(output_dtensor.placements[0].is_replicate())
else:
self.assertEqual(comm_mode.get_total_counts(), 0)
self.assertTrue(output_dtensor.placements[0].is_shard(shard_dim))
self.assertEqual(output_dtensor.full_tensor(), output)
if __name__ == "__main__":
run_tests()
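The removed `test_logsumexp` above encodes a general rule: a sharded reduction only needs communication when the reduction dimension is the sharded one, which is why the test expects exactly one all-gather in that case and none otherwise. A small plain-tensor illustration of why, with made-up shapes and no DTensor involved:

import torch

x = torch.arange(6.0).reshape(3, 2)
top, bottom = x[:2], x[2:]  # pretend rows are sharded across two ranks

# Reducing over the sharded dim needs every shard (hence the all-gather above).
assert torch.allclose(torch.logsumexp(x, dim=0),
                      torch.logsumexp(torch.cat([top, bottom]), dim=0))

# Reducing over the other dim is shard-local, so the output stays sharded.
assert torch.allclose(torch.logsumexp(x, dim=1)[:2], torch.logsumexp(top, dim=1))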

View File

@ -31,7 +31,6 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
skip_unless_torch_gpu,
with_comms,
)
from torch.utils._typing_utils import not_none
def get_generator_seed_for_device_type(device_type: str) -> int:
@ -550,9 +549,7 @@ class DistTensorRandomOpTest(DTensorTestBase):
# local_shard_list_on_dim[i] has the list of all shards on that dim
# as a tuple (local_shard_offset, local_shard_size)
dtensor_shape = dtensor.shape
local_shard_list_on_dim: list[list[tuple[int, int]]] = [
[(0, l)] for l in dtensor_shape
]
local_shard_list_on_dim = [[(0, l)] for l in dtensor_shape]
for idx, placement in enumerate(placements):
if isinstance(placement, Shard):
mesh_dim_size = device_mesh.size(idx)
@ -568,7 +565,7 @@ class DistTensorRandomOpTest(DTensorTestBase):
shard_idx_on_dim,
)
local_shard_list_on_dim[shard_dim].append(
(not_none(shard_offset), shard_size)
(shard_offset, shard_size)
)
local_shard_comb = itertools.product(*local_shard_list_on_dim)

View File

@ -691,25 +691,6 @@ class TestStridedSharding(DTensorTestBase):
)
self.assertEqual(full_tensor, x)
@with_comms
def test_2d_mesh_uneven_strided_shard(self):
mesh = init_device_mesh(
self.device_type,
(self.world_size // 2, 2),
mesh_dim_names=("fsdp", "tp"),
)
for size in (2, 3, 5, 11):
tensor = torch.arange(size, device=self.device_type).view(1, -1)
dtensor = distribute_tensor(
tensor,
device_mesh=mesh,
placements=(Replicate(), Replicate()),
).redistribute(
mesh, placements=(_StridedShard(dim=1, split_factor=2), Shard(1))
)
self.assertEqual(dtensor.full_tensor(), tensor)
class Test2DStridedLocalShard(DTensorTestBase):
@property

View File

@ -8,7 +8,6 @@ import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
from torch._C._distributed_c10d import Backend as C10dBackend
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed._mesh_layout import _MeshLayout as _Layout
from torch.distributed.device_mesh import _mesh_resources, DeviceMesh, init_device_mesh
from torch.distributed.distributed_c10d import (
_get_default_group,
@ -28,7 +27,7 @@ from torch.distributed.tensor._collective_utils import (
)
from torch.distributed.tensor.placement_types import _Partial, Shard
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests, TEST_XPU, TestCase
from torch.testing._internal.common_utils import run_tests, TEST_XPU
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
with_comms,
@ -863,7 +862,7 @@ class TestDeviceMeshGetItem(DTensorTestBase):
# Test flatten into an existing mesh_dim_name inside the mesh
with self.assertRaisesRegex(
ValueError,
RuntimeError,
"already exists for submesh of the DeviceMesh",
):
mesh_3d._flatten("dp")
@ -903,18 +902,6 @@ class TestDeviceMeshGetItem(DTensorTestBase):
cp_tp_mesh._flatten("dummy")
self.assertEqual(mesh_3d["dummy"].mesh_dim_names[0], "dummy")
# Test flatten into an existing mesh_dim_name inside the mesh
with self.assertRaisesRegex(
ValueError,
"dp already exists for submesh of the DeviceMesh",
):
mesh_3d._flatten("dp")
with self.assertRaisesRegex(
ValueError,
"Flatten mesh with mesh_dim_name dp_tp has been created before",
):
mesh_3d["cp", "tp"]._flatten("dp_tp")
@with_comms(eager_init=True)
def test_flatten_mesh_4d(self):
mesh_shape = (2, 2, 2, 1)
@ -1301,204 +1288,5 @@ class DeviceMeshCollectiveTest(DTensorTestBase):
self.assertEqual(received_tensor, torch.ones(3, 3) * self.rank)
class CuTeLayoutTest(TestCase):
def test_coalesce(self):
# ((3,2),(2,1)) -> (6,1)
l = _Layout((3, 2), (2, 1))
l = l.coalesce()
self.assertEqual(list(l.sizes_and_strides), [(6, 1)])
# ((2,12),(3,4),(4,1)) -> (24,1)
l = _Layout((2, 3, 4), (12, 4, 1))
l = l.coalesce()
self.assertEqual(list(l.sizes_and_strides), [(24, 1)])
def test_coalesce_non_coalescible(self):
# ((3,4),(2,1)) stays as-is (4 ≠ 2*1)
l = _Layout((3, 2), (4, 1))
l = l.coalesce()
self.assertEqual(list(l.sizes_and_strides), [(3, 4), (2, 1)])
def test_complement_n_group_layout(self):
# complement((4,2), 8) = (2,1); together form (8,1)
pg_layout = _Layout(
(4,),
(2,),
)
outer = pg_layout.complement(world_size=8)
self.assertEqual(list(outer.sizes_and_strides), [(2, 1)])
self.assertEqual(
pg_layout.all_ranks_from_zero(),
[0, 2, 4, 6],
)
groups = [
[o + i for i in pg_layout.all_ranks_from_zero()]
for o in outer.all_ranks_from_zero()
]
self.assertEqual(
groups,
[
[0, 2, 4, 6],
[1, 3, 5, 7],
],
)
self.assertEqual(
pg_layout.global_ranks(8),
[
[0, 2, 4, 6],
[1, 3, 5, 7],
],
)
# complement((4,2), 16) = ((2,8), (2,1)); together form (16,1)
outer = pg_layout.complement(world_size=16)
self.assertEqual(list(outer.sizes_and_strides), [(2, 8), (2, 1)])
self.assertEqual(
outer.all_ranks_from_zero(),
[0, 1, 8, 9],
)
self.assertEqual(
pg_layout.global_ranks(16),
[
[0, 2, 4, 6],
[1, 3, 5, 7],
[8, 10, 12, 14],
[9, 11, 13, 15],
],
)
# Complement ((2,4), (2,1)) under world_size=16 → complement ((2,8), (2,2))
pg_layout = _Layout((2, 2), (4, 1))
self.assertEqual(
pg_layout.all_ranks_from_zero(),
[0, 1, 4, 5],
)
outer = pg_layout.complement(world_size=16)
self.assertEqual(list(outer.sizes_and_strides), [(2, 8), (2, 2)])
self.assertEqual(
outer.all_ranks_from_zero(),
[0, 2, 8, 10],
)
self.assertEqual(
pg_layout.global_ranks(16),
[
[0, 1, 4, 5],
[2, 3, 6, 7],
[8, 9, 12, 13],
[10, 11, 14, 15],
],
)
# Test layout_to_global_ranks and layout_to_all_ranks_from_zero
pg_layout = _Layout((2, 2), (4, 2))
self.assertEqual(
pg_layout.all_ranks_from_zero(),
[0, 2, 4, 6],
)
self.assertEqual(
pg_layout.global_ranks(16),
[
[0, 2, 4, 6],
[1, 3, 5, 7],
[8, 10, 12, 14],
[9, 11, 13, 15],
],
)
outer = pg_layout.complement(world_size=16)
self.assertEqual(list(outer.sizes_and_strides), [(2, 8), (2, 1)])
# Test when stride is not monotonically decreasing, the complement layout
# is same as the one sorted its stride.
pg_layout_r = _Layout((2, 2), (2, 4))
outer = pg_layout_r.complement(world_size=16)
self.assertEqual(list(outer.sizes_and_strides), [(2, 8), (2, 1)])
self.assertEqual(
pg_layout_r.global_ranks(16),
[
[0, 4, 2, 6],
[1, 5, 3, 7],
[8, 12, 10, 14],
[9, 13, 11, 15],
],
)
# Test just all_ranks_from_zero and global_ranks.
pg_layout = _Layout((4,), (2,))
self.assertEqual(
pg_layout.all_ranks_from_zero(),
[0, 2, 4, 6],
)
self.assertEqual(
pg_layout.global_ranks(16),
[
[0, 2, 4, 6],
[1, 3, 5, 7],
[8, 10, 12, 14],
[9, 11, 13, 15],
],
)
def test_composition(self):
# self = ((4,2), (2,1)), l = (2,1) → self o l = (2,1)
orig_l = _Layout((4, 2), (2, 1))
right_l = _Layout((2,), (1,))
composed_layout = orig_l.composition(right_l)
self.assertEqual(list(composed_layout.sizes_and_strides), [(2, 1)])
self.assertEqual(
composed_layout.global_ranks(8),
[
[0, 1],
[2, 3],
[4, 5],
[6, 7],
],
)
# self = (4,2), l = (2,1) → self o l = (2,2)
orig_l = _Layout((4,), (2,))
right_l = _Layout((2,), (1,))
composed_layout = orig_l.composition(right_l)
self.assertEqual(list(composed_layout.sizes_and_strides), [(2, 2)])
self.assertEqual(
composed_layout.global_ranks(8),
[
[0, 2],
[1, 3],
[4, 6],
[5, 7],
],
)
# self = (4,2), l = ((2,2), (2,1)) → self o l = ((2,4), (2,2))
# This is to mimic the un-flatten from a 2D mesh to a 1D mesh.
right_l = _Layout((2, 2), (2, 1))
composed_layout = orig_l.composition(right_l)
self.assertEqual(list(composed_layout.sizes_and_strides), [(2, 4), (2, 2)])
self.assertEqual(
composed_layout[0].global_ranks(8),
[
[0, 4],
[1, 5],
[2, 6],
[3, 7],
],
)
self.assertEqual(
composed_layout[1].global_ranks(8),
[
[0, 2],
[1, 3],
[4, 6],
[5, 7],
],
)
# Error case.
orig_l = _Layout((4, 2), (4, 1))
with self.assertRaises(
AssertionError,
):
right_l = _Layout((2,), (3,))
orig_l.composition(right_l)
if __name__ == "__main__":
run_tests()
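The expected values in the removed CuTeLayoutTest above all follow from one rule: a layout's ranks are the dot products of every coordinate in its sizes with its strides. A minimal pure-Python illustration of that rule (an assumption for illustration, not the `_MeshLayout` implementation):

import itertools

def all_ranks_from_zero(sizes, strides):
    # Enumerate each coordinate of the shape and map it to a linear rank
    # via the stride dot product, last dimension varying fastest.
    return [
        sum(i * s for i, s in zip(coord, strides))
        for coord in itertools.product(*(range(n) for n in sizes))
    ]

# Matches the assertions above.
assert all_ranks_from_zero((4,), (2,)) == [0, 2, 4, 6]
assert all_ranks_from_zero((2, 2), (4, 1)) == [0, 1, 4, 5]
assert all_ranks_from_zero((2, 2), (4, 2)) == [0, 2, 4, 6]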

View File

@ -299,33 +299,28 @@ class NVSHMEMAll2AllTest(MultiProcContinuousTest):
torch.randn(max_inp_numel, dtype=dtype, device=self.device)
)
out = symm_mem.empty(max_out_numel, dtype=dtype, device=self.device).fill_(-1)
in_splits = symm_mem.empty(
self.world_size, dtype=torch.int64, device=self.device
)
out_splits_offsets = symm_mem.empty(
(2, self.world_size), dtype=torch.int64, device=self.device
in_out_splits = symm_mem.empty(
(3, self.world_size), dtype=torch.int64, device=self.device
)
# Row 0 is input splits
in_splits.copy_(inp_splits)
in_out_splits[0].copy_(inp_splits)
# Sync all ranks to ensure remote tensors are allocated
dist.barrier()
torch.ops.symm_mem.all_to_all_vdev(
inp, out, in_splits, out_splits_offsets, group_name
)
torch.ops.symm_mem.all_to_all_vdev(inp, out, in_out_splits, group_name)
# Check input splits (row 0) -- should not change
torch.testing.assert_close(in_splits, inp_splits)
torch.testing.assert_close(in_out_splits[0], inp_splits)
# Check output splits (row 1)
torch.testing.assert_close(out_splits_offsets[0], out_splits)
torch.testing.assert_close(in_out_splits[1], out_splits)
# Check output offsets (row 2)
out_offsets = torch.cumsum(out_splits, dim=0) # inclusive scan
# output offsets from `all_to_all_vdev` are an exclusive scan
self.assertEqual(out_splits_offsets[1][0], 0)
torch.testing.assert_close(out_splits_offsets[1][1:], out_offsets[:-1])
self.assertEqual(in_out_splits[2][0], 0)
torch.testing.assert_close(in_out_splits[2][1:], out_offsets[:-1])
# Check data
expected = torch.empty(out_numel, dtype=dtype, device=self.device)
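The offset checks above hinge on the difference between an inclusive and an exclusive scan: `torch.cumsum` is inclusive, while `all_to_all_vdev` reports exclusive offsets (each rank's starting index). A small sketch of the relationship, with made-up splits for illustration:

import torch

out_splits = torch.tensor([3, 1, 4, 2])      # hypothetical splits for 4 ranks
inclusive = torch.cumsum(out_splits, dim=0)  # tensor([ 3,  4,  8, 10])

exclusive = torch.zeros_like(inclusive)
exclusive[1:] = inclusive[:-1]               # tensor([0, 3, 4, 8])

# Same shift the test performs: offsets[0] == 0 and offsets[1:] == cumsum[:-1].
assert exclusive[0] == 0
assert torch.equal(exclusive[1:], inclusive[:-1])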

View File

@ -2,8 +2,6 @@
# To run:
# python test/distributed/test_nvshmem_triton.py
import sys
import triton.language as tl
import torch
@ -11,7 +9,6 @@ import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
import torch.distributed._symmetric_memory._nvshmem_triton as nvshmem
from torch._inductor.runtime.triton_compat import triton
from torch.distributed._symmetric_memory._nvshmem_triton import requires_nvshmem
from torch.testing._internal.common_distributed import MultiProcContinuousTest
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
@ -23,9 +20,12 @@ from torch.testing._internal.common_utils import (
from torch.testing._internal.inductor_utils import IS_H100, requires_triton
if not symm_mem.is_nvshmem_available():
print("NVSHMEM not available, skipping tests")
sys.exit(0)
# Decorators
def requires_nvshmem():
return skip_but_pass_in_sandcastle_if(
not symm_mem.is_nvshmem_available(),
"test_nvshmem requires NVSHMEM, skipping tests",
)
def requires_h100():
@ -41,11 +41,8 @@ device_module = torch.get_device_module(device_type)
# Shared Triton JIT kernels
@requires_nvshmem
@triton.jit
def my_put_kernel(
def nvshmem_put_kernel(
dest,
src,
nelems,
@ -54,9 +51,8 @@ def my_put_kernel(
nvshmem.put(dest, src, nelems, pe)
@requires_nvshmem
@triton.jit
def my_get_kernel(
def nvshmem_get_kernel(
dest,
src,
nelems,
@ -65,9 +61,8 @@ def my_get_kernel(
nvshmem.get(dest, src, nelems, pe)
@requires_nvshmem
@triton.jit
def my_putmem_signal_block_kernel(
def nvshmem_putmem_signal_block_kernel(
dst,
src,
size_bytes,
@ -79,15 +74,13 @@ def my_putmem_signal_block_kernel(
nvshmem.putmem_signal_block(dst, src, size_bytes, signal, sig_val, sig_op, peer)
@requires_nvshmem
@triton.jit
def my_signal_wait_until_kernel(signal, cmp_op, cmp_val):
def nvshmem_signal_wait_until_kernel(signal, cmp_op, cmp_val):
nvshmem.signal_wait_until(signal, cmp_op, cmp_val)
@requires_nvshmem
@triton.jit
def my_signal_op_kernel(
def nvshmem_signal_op_kernel(
sig_addr,
signal,
sig_op,
@ -96,9 +89,8 @@ def my_signal_op_kernel(
nvshmem.signal_op(sig_addr, signal, sig_op, peer)
@requires_nvshmem
@triton.jit
def my_wait_until_kernel(
def nvshmem_wait_until_kernel(
ivar,
cmp_op,
cmp_val,
@ -106,15 +98,13 @@ def my_wait_until_kernel(
nvshmem.wait_until(ivar, cmp_op, cmp_val)
@requires_nvshmem
@triton.jit
def my_fence_kernel():
def nvshmem_fence_kernel():
nvshmem.fence()
@requires_nvshmem
@triton.jit
def my_put_with_fence_kernel(
def nvshmem_put_with_fence_kernel(
dst1,
src1,
dst2,
@ -136,9 +126,8 @@ def my_put_with_fence_kernel(
nvshmem.put(flag_dst, flag_src, 1, peer)
@requires_nvshmem
@triton.jit
def my_put_with_quiet_kernel(
def nvshmem_put_with_quiet_kernel(
dst,
src,
flag_dst,
@ -155,9 +144,8 @@ def my_put_with_quiet_kernel(
nvshmem.put(flag_dst, flag_src, 1, peer)
@requires_nvshmem
@triton.jit
def my_barrier_test_kernel(
def nvshmem_barrier_test_kernel(
dst,
src,
nelems,
@ -190,15 +178,13 @@ def my_barrier_test_kernel(
tl.store(p_dst, received + 1)
@requires_nvshmem
@triton.jit
def my_barrier_all_kernel():
def nvshmem_barrier_all_kernel():
nvshmem.barrier_all()
@requires_nvshmem
@triton.jit
def my_sync_test_kernel(
def nvshmem_sync_test_kernel(
local_data,
remote_data,
nelems,
@ -224,9 +210,8 @@ def my_sync_test_kernel(
# because sync_all() made those local stores visible
@requires_nvshmem
@triton.jit
def my_alltoall_kernel(
def nvshmem_alltoall_kernel(
team_handle,
dst,
src,
@ -235,9 +220,8 @@ def my_alltoall_kernel(
nvshmem.alltoall(team_handle, dst, src, nelems_per_pe)
@requires_nvshmem
@triton.jit
def my_broadcast_kernel(
def nvshmem_broadcast_kernel(
team_handle,
dst,
src,
@ -247,9 +231,8 @@ def my_broadcast_kernel(
nvshmem.broadcast(team_handle, dst, src, nelems, pe_root)
@requires_nvshmem
@triton.jit
def my_reduce_kernel(
def nvshmem_reduce_kernel(
team_handle,
dest_tensor,
source_tensor,
@ -260,6 +243,7 @@ def my_reduce_kernel(
@instantiate_parametrized_tests
@requires_nvshmem()
class NVSHMEMTritonTest(MultiProcContinuousTest):
def _init_device(self) -> None:
# TODO: relieve this (seems to hang if without)
@ -278,6 +262,9 @@ class NVSHMEMTritonTest(MultiProcContinuousTest):
torch.manual_seed(42 + self.rank)
self._init_device()
# Enable NVSHMEM for Triton
nvshmem_lib = nvshmem.enable_triton()
group_name = dist.distributed_c10d._get_default_group().group_name
symm_mem.enable_symm_mem_for_group(group_name)
rank = self.rank
@ -307,11 +294,12 @@ class NVSHMEMTritonTest(MultiProcContinuousTest):
peer = 1 - rank
if rank == 0:
# Rank 0 puts its data to Rank 1
my_put_kernel[(1,)](
nvshmem_put_kernel[(1,)](
dst,
src,
nelems,
peer,
extern_libs=nvshmem_lib,
)
# Synchronize after operation
@ -331,6 +319,7 @@ class NVSHMEMTritonTest(MultiProcContinuousTest):
torch.manual_seed(42 + self.rank)
self._init_device()
nvshmem_lib = nvshmem.enable_triton()
group_name = dist.distributed_c10d._get_default_group().group_name
symm_mem.enable_symm_mem_for_group(group_name)
rank = self.rank
@ -352,11 +341,12 @@ class NVSHMEMTritonTest(MultiProcContinuousTest):
peer = 1 - rank
if rank == 1:
# Rank 1 gets data from rank 0 using tensor-aware API
my_get_kernel[(1,)](
nvshmem_get_kernel[(1,)](
out,
inp,
numel,
peer,
extern_libs=nvshmem_lib,
)
if rank == 1:
torch.testing.assert_close(
@ -370,6 +360,7 @@ class NVSHMEMTritonTest(MultiProcContinuousTest):
torch.manual_seed(42 + self.rank)
self._init_device()
nvshmem_lib = nvshmem.enable_triton()
group_name = dist.distributed_c10d._get_default_group().group_name
symm_mem.enable_symm_mem_for_group(group_name)
rank = self.rank
@ -392,11 +383,12 @@ class NVSHMEMTritonTest(MultiProcContinuousTest):
peer = (rank - 1) % world_size
# All ranks execute the get operation using tensor-aware API
my_get_kernel[(1,)](
nvshmem_get_kernel[(1,)](
out,
inp,
numel,
peer,
extern_libs=nvshmem_lib,
)
expected_value = peer
@ -411,6 +403,8 @@ class NVSHMEMTritonTest(MultiProcContinuousTest):
torch.manual_seed(42 + self.rank)
self._init_device()
nvshmem_lib = nvshmem.enable_triton()
group_name = dist.distributed_c10d._get_default_group().group_name
symm_mem.enable_symm_mem_for_group(group_name)
rank = self.rank
@ -437,7 +431,7 @@ class NVSHMEMTritonTest(MultiProcContinuousTest):
if rank == 0:
# Rank 0 puts into Rank 1
my_putmem_signal_block_kernel[(1, 1, 1)](
nvshmem_putmem_signal_block_kernel[(1, 1, 1)](
out,
inp,
size_bytes=msg_size_bytes,
@ -445,14 +439,16 @@ class NVSHMEMTritonTest(MultiProcContinuousTest):
sig_val=SIGNAL_VAL,
sig_op=NVSHMEM_SIGNAL_SET,
peer=peer,
extern_libs=nvshmem_lib,
)
if rank == 1:
# Wait until signal flag is set by Rank 0
my_signal_wait_until_kernel[(1,)](
nvshmem_signal_wait_until_kernel[(1,)](
flag,
cmp_op=NVSHMEM_CMP_EQ,
cmp_val=SIGNAL_VAL,
extern_libs=nvshmem_lib,
)
# After wait completes, verify data and flag contents
torch.testing.assert_close(
@ -469,6 +465,8 @@ class NVSHMEMTritonTest(MultiProcContinuousTest):
torch.manual_seed(42 + self.rank)
self._init_device()
nvshmem_lib = nvshmem.enable_triton()
group_name = dist.distributed_c10d._get_default_group().group_name
symm_mem.enable_symm_mem_for_group(group_name)
rank = self.rank
@ -495,7 +493,7 @@ class NVSHMEMTritonTest(MultiProcContinuousTest):
if rank == 0:
# Rank 0 puts into Rank 1
my_putmem_signal_block_kernel[(1, 1, 1)](
nvshmem_putmem_signal_block_kernel[(1, 1, 1)](
out,
inp,
size_bytes=msg_size_bytes,
@ -503,13 +501,15 @@ class NVSHMEMTritonTest(MultiProcContinuousTest):
sig_val=SIGNAL_VAL,
sig_op=NVSHMEM_SIGNAL_ADD,
peer=peer,
extern_libs=nvshmem_lib,
)
if rank == 1:
my_signal_wait_until_kernel[(1, 1, 1)](
nvshmem_signal_wait_until_kernel[(1, 1, 1)](
flag,
cmp_op=NVSHMEM_CMP_EQ,
cmp_val=SIGNAL_VAL,
extern_libs=nvshmem_lib,
)
torch.testing.assert_close(
out, val * torch.ones(numel, dtype=dtype, device=self.device)
@ -525,6 +525,7 @@ class NVSHMEMTritonTest(MultiProcContinuousTest):
torch.manual_seed(42 + self.rank)
self._init_device()
nvshmem_lib = nvshmem.enable_triton()
group_name = dist.distributed_c10d._get_default_group().group_name
symm_mem.enable_symm_mem_for_group(group_name)
@ -543,12 +544,15 @@ class NVSHMEMTritonTest(MultiProcContinuousTest):
[FLAG_FINAL_VALUE], dtype=torch.int32, device=self.device
)
nvshmem_barrier_all_kernel[(1,)](extern_libs=nvshmem_lib)
if rank == 0:
# Rank 0 (the waiter)
my_wait_until_kernel[(1,)](
nvshmem_wait_until_kernel[(1,)](
flag,
cmp_op=NVSHMEM_CMP_EQ,
cmp_val=FLAG_FINAL_VALUE,
extern_libs=nvshmem_lib,
)
# Verification
@ -560,11 +564,12 @@ class NVSHMEMTritonTest(MultiProcContinuousTest):
if rank == 1:
# Rank 1 (the signaler)
# Launch a kernel to put the value to Rank 0's flag tensor.
my_put_kernel[(1,)](
nvshmem_put_kernel[(1,)](
flag, # Destination symmetric tensor on the remote PE
expected_flag, # Source data tensor (local)
1, # Number of elements
peer, # The target PE (Rank 0)
extern_libs=nvshmem_lib,
)
@skipIfRocm
@ -572,6 +577,8 @@ class NVSHMEMTritonTest(MultiProcContinuousTest):
@requires_h100()
def test_triton_signal_wait_until(self) -> None:
self._init_device()
# Enable NVSHMEM for Triton
nvshmem_lib = nvshmem.enable_triton()
group_name = dist.distributed_c10d._get_default_group().group_name
symm_mem.enable_symm_mem_for_group(group_name)
rank = self.rank
@ -601,7 +608,7 @@ class NVSHMEMTritonTest(MultiProcContinuousTest):
if rank == 0:
# Producer (rank 0): Puts data into rank 1's `out` buffer and then sets the flag
my_putmem_signal_block_kernel[(1, 1, 1)](
nvshmem_putmem_signal_block_kernel[(1, 1, 1)](
out,
inp,
size_bytes=msg_size_bytes,
@ -609,13 +616,15 @@ class NVSHMEMTritonTest(MultiProcContinuousTest):
sig_val=COMPLETION_FLAG_VAL,
sig_op=NVSHMEM_SIGNAL_SET,
peer=peer,
extern_libs=nvshmem_lib,
)
elif rank == 1:
# Consumer (rank 1): Waits on the signal variable using `signal_wait_until`.
my_signal_wait_until_kernel[(1, 1, 1)](
nvshmem_signal_wait_until_kernel[(1, 1, 1)](
flag,
cmp_op=NVSHMEM_CMP_EQ,
cmp_val=COMPLETION_FLAG_VAL,
extern_libs=nvshmem_lib,
)
# After the wait returns, verify data and flag
torch.testing.assert_close(
@ -642,6 +651,7 @@ class NVSHMEMTritonTest(MultiProcContinuousTest):
"""
torch.manual_seed(42 + self.rank)
self._init_device()
nvshmem_lib = nvshmem.enable_triton()
group_name = dist.distributed_c10d._get_default_group().group_name
symm_mem.enable_symm_mem_for_group(group_name)
rank = self.rank
@ -672,7 +682,7 @@ class NVSHMEMTritonTest(MultiProcContinuousTest):
NVSHMEM_CMP_EQ = 0 # compare equal
if rank == 0:
my_put_with_fence_kernel[(1,)](
nvshmem_put_with_fence_kernel[(1,)](
out1,
inp1,
out2,
@ -681,13 +691,15 @@ class NVSHMEMTritonTest(MultiProcContinuousTest):
flag_update_val,
nelems=numel,
peer=peer,
extern_libs=nvshmem_lib,
)
elif rank == 1:
# Wait until flag is set by Rank 0
my_wait_until_kernel[(1,)](
nvshmem_wait_until_kernel[(1,)](
flag,
cmp_op=NVSHMEM_CMP_EQ,
cmp_val=flag_val,
extern_libs=nvshmem_lib,
)
# Verify ordered data arrival.
@ -707,6 +719,7 @@ class NVSHMEMTritonTest(MultiProcContinuousTest):
def test_triton_quiet(self) -> None:
torch.manual_seed(42 + self.rank)
self._init_device()
nvshmem_lib = nvshmem.enable_triton()
group_name = dist.distributed_c10d._get_default_group().group_name
symm_mem.enable_symm_mem_for_group(group_name)
rank = self.rank
@ -732,19 +745,21 @@ class NVSHMEMTritonTest(MultiProcContinuousTest):
dist.barrier()
if rank == 1:
my_put_with_quiet_kernel[(1,)](
nvshmem_put_with_quiet_kernel[(1,)](
out,
inp,
flag,
flag_update_val,
nelems=numel,
peer=peer,
extern_libs=nvshmem_lib,
)
elif rank == 0:
my_wait_until_kernel[(1,)](
nvshmem_wait_until_kernel[(1,)](
flag,
cmp_op=NVSHMEM_CMP_EQ,
cmp_val=flag_val,
extern_libs=nvshmem_lib,
)
torch.testing.assert_close(
out, val * torch.ones(numel, dtype=dtype, device=self.device)
@ -757,6 +772,7 @@ class NVSHMEMTritonTest(MultiProcContinuousTest):
def test_triton_barrier(self) -> None:
torch.manual_seed(42 + self.rank)
self._init_device()
nvshmem_lib = nvshmem.enable_triton()
group_name = dist.distributed_c10d._get_default_group().group_name
symm_mem.enable_symm_mem_for_group(group_name)
rank = self.rank
@ -768,10 +784,11 @@ class NVSHMEMTritonTest(MultiProcContinuousTest):
symm_mem.rendezvous(src, group=group_name)
symm_mem.rendezvous(dst, group=group_name)
my_barrier_test_kernel[(1,)](
nvshmem_barrier_test_kernel[(1,)](
dst,
src,
nelems=numel,
extern_libs=nvshmem_lib,
launch_cooperative_grid=True,
num_ctas=1,
)
@ -793,6 +810,7 @@ class NVSHMEMTritonTest(MultiProcContinuousTest):
torch.manual_seed(42 + self.rank)
self._init_device()
nvshmem_lib = nvshmem.enable_triton()
group_name = dist.distributed_c10d._get_default_group().group_name
symm_mem.enable_symm_mem_for_group(group_name)
rank = self.rank
@ -806,10 +824,11 @@ class NVSHMEMTritonTest(MultiProcContinuousTest):
symm_mem.rendezvous(remote_data, group=group_name)
# Launch kernel with cooperative grid
my_sync_test_kernel[(1,)](
nvshmem_sync_test_kernel[(1,)](
local_data,
remote_data,
nelems=numel,
extern_libs=nvshmem_lib,
launch_cooperative_grid=True,
num_ctas=1,
)
@ -836,6 +855,7 @@ class NVSHMEMTritonTest(MultiProcContinuousTest):
def test_triton_alltoall(self) -> None:
torch.manual_seed(42 + self.rank)
self._init_device()
nvshmem_lib = nvshmem.enable_triton()
group_name = dist.distributed_c10d._get_default_group().group_name
symm_mem.enable_symm_mem_for_group(group_name)
world_size = dist.get_world_size()
@ -860,11 +880,12 @@ class NVSHMEMTritonTest(MultiProcContinuousTest):
dist.barrier()
team_handle = 0 # NVSHMEM_TEAM_WORLD handle is 0
# Launch the kernel using new tensor-aware API
my_alltoall_kernel[(1,)](
nvshmem_alltoall_kernel[(1,)](
team_handle,
dst,
src,
nelems_per_pe,
extern_libs=nvshmem_lib,
launch_cooperative_grid=True,
)
# Synchronize after alltoall
@ -883,6 +904,7 @@ class NVSHMEMTritonTest(MultiProcContinuousTest):
def test_triton_broadcast(self) -> None:
torch.manual_seed(42 + self.rank)
self._init_device()
nvshmem_lib = nvshmem.enable_triton()
group_name = dist.distributed_c10d._get_default_group().group_name
symm_mem.enable_symm_mem_for_group(group_name)
rank = self.rank
@ -913,12 +935,13 @@ class NVSHMEMTritonTest(MultiProcContinuousTest):
# Execute broadcast
team_handle = 0 # NVSHMEM_TEAM_WORLD
my_broadcast_kernel[(1,)](
nvshmem_broadcast_kernel[(1,)](
team_handle,
dst,
src,
nelems,
pe_root,
extern_libs=nvshmem_lib,
launch_cooperative_grid=True,
)
@ -951,6 +974,7 @@ class NVSHMEMTritonTest(MultiProcContinuousTest):
def test_triton_sum_reduce(self, dtype) -> None:
torch.manual_seed(42 + self.rank)
self._init_device()
nvshmem_lib = nvshmem.enable_triton()
group_name = dist.distributed_c10d._get_default_group().group_name
symm_mem.enable_symm_mem_for_group(group_name)
world_size = dist.get_world_size()
@ -977,12 +1001,13 @@ class NVSHMEMTritonTest(MultiProcContinuousTest):
# Execute sum reduction across all ranks
team_handle = 0 # NVSHMEM_TEAM_WORLD
my_reduce_kernel[(1,)](
nvshmem_reduce_kernel[(1,)](
team_handle,
dst,
src,
nreduce,
operation="sum",
extern_libs=nvshmem_lib,
launch_cooperative_grid=True,
)
@ -1013,6 +1038,7 @@ class NVSHMEMTritonTest(MultiProcContinuousTest):
def test_triton_minmax_reduce(self, dtype) -> None:
torch.manual_seed(42 + self.rank)
self._init_device()
nvshmem_lib = nvshmem.enable_triton()
group_name = dist.distributed_c10d._get_default_group().group_name
symm_mem.enable_symm_mem_for_group(group_name)
world_size = dist.get_world_size()
@ -1054,21 +1080,23 @@ class NVSHMEMTritonTest(MultiProcContinuousTest):
dist.barrier()
# Execute MIN reduction
team_handle = 0
my_reduce_kernel[(1,)](
nvshmem_reduce_kernel[(1,)](
team_handle,
dst_min,
src_min,
nreduce,
operation="min",
extern_libs=nvshmem_lib,
launch_cooperative_grid=True,
)
# Execute MAX reduction
my_reduce_kernel[(1,)](
nvshmem_reduce_kernel[(1,)](
team_handle,
dst_max,
src_max,
nreduce,
operation="max",
extern_libs=nvshmem_lib,
launch_cooperative_grid=True,
)
dist.barrier()
@ -1099,6 +1127,7 @@ class NVSHMEMTritonTest(MultiProcContinuousTest):
def test_triton_prod_reduce(self, dtype) -> None:
torch.manual_seed(42 + self.rank)
self._init_device()
nvshmem_lib = nvshmem.enable_triton()
group_name = dist.distributed_c10d._get_default_group().group_name
symm_mem.enable_symm_mem_for_group(group_name)
world_size = dist.get_world_size()
@ -1138,12 +1167,13 @@ class NVSHMEMTritonTest(MultiProcContinuousTest):
# Execute product reduction across all ranks
team_handle = 0 # NVSHMEM_TEAM_WORLD
my_reduce_kernel[(1,)](
nvshmem_reduce_kernel[(1,)](
team_handle,
dst,
src,
nreduce,
operation="prod",
extern_libs=nvshmem_lib,
launch_cooperative_grid=True,
)

View File

@ -505,7 +505,7 @@ class AsyncTPTest(MultiProcContinuousTest):
not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch"
)
@skip_if_lt_x_gpu(2)
@parametrize("scatter_dim", [0, 1, 2])
@parametrize("scatter_dim", [0, 1])
def test_fused_matmul_reduce_scatter(self, scatter_dim: int) -> None:
self._init_process()

View File

@ -519,7 +519,11 @@ class AOTAutogradCacheTests(InductorTestCase):
@functorch_config.patch(
{"enable_autograd_cache": True, "view_replay_for_aliased_outputs": True}
)
def test_view_replay(self):
def test_view_replay_bypass(self):
"""
Should bypass when view replay is turned on
"""
def fn(a):
tmp = a.detach()
a.mul_(2)
@ -527,25 +531,10 @@ class AOTAutogradCacheTests(InductorTestCase):
with torch.autograd._force_original_view_tracking(True):
compiled_fn = torch.compile(fn)
compiled_fn(torch.rand(2, 3))
def run_and_check(miss, hit, bypass):
self._clear_dynamo_and_codecache()
inp = torch.rand(2, 3)
compiled_inp = inp.clone().detach()
with torch.autograd._force_original_view_tracking(True):
out = fn(inp)
compiled_out = compiled_fn(compiled_inp)
self.assertEqual(out, compiled_out)
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], miss)
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], hit)
self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], bypass)
run_and_check(miss=1, hit=0, bypass=0)
run_and_check(miss=1, hit=1, bypass=0)
run_and_check(miss=1, hit=2, bypass=0)
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 1)
@inductor_config.patch("fx_graph_remote_cache", False)
@inductor_config.patch("fx_graph_cache", True)

View File

@ -1,10 +1,8 @@
# Owner(s): ["module: dynamo"]
import inspect
import os
import pickle
from contextlib import contextmanager
from unittest.mock import patch
import torch
import torch._dynamo.testing
@ -31,27 +29,8 @@ class CustomCompiledFunction(torch._dynamo.aot_compile.SerializableCallable):
@classmethod
def serialize_compile_artifacts(cls, fn) -> bytes:
import sympy
from torch._subclasses import FakeTensorMode
from torch.fx._graph_pickler import Options
state = fn.__dict__.copy()
graph_reducer_override = GraphPickler.reducer_override
def _graph_reducer_override(self, obj):
if (
inspect.isclass(obj)
and issubclass(obj, sympy.Function)
and hasattr(obj, "_torch_unpickler")
):
return obj._torch_unpickler, (obj._torch_handler_name,)
if isinstance(obj, FakeTensorMode):
return type(None), ()
return graph_reducer_override(self, obj)
with patch.object(GraphPickler, "reducer_override", _graph_reducer_override):
state["gm"] = GraphPickler.dumps(state["gm"], Options(ops_filter=None))
state["gm"] = GraphPickler.dumps(state["gm"])
return pickle.dumps(state)
@classmethod
@ -75,14 +54,6 @@ class SimpleLinearModule(torch.nn.Module):
return self.linear(x)
class RepeatInterleaveModule(torch.nn.Module):
def forward(self, x):
chunk = x.chunk(2, dim=-1)
y = chunk[0]
y_repeat = y.repeat_interleave(2, dim=-1)
return y_repeat
@torch._dynamo.config.patch("enable_aot_compile", True)
@instantiate_parametrized_tests
class TestAOTCompile(torch._inductor.test_case.TestCase):
@ -143,34 +114,6 @@ class TestAOTCompile(torch._inductor.test_case.TestCase):
actual = compiled_fn(mod, *inputs)
self.assertEqual(expected, actual)
def test_aot_compile_repeat_interleave(self):
mod = RepeatInterleaveModule()
def backend(gm, example_inputs):
return CustomCompiledFunction(gm, example_inputs)
inputs = (torch.randn(2, 4),)
# The first dim should be dynamic to repro the issue of repeat_interleave
# torch._dynamo.mark_dynamic(inputs[0], [0])
compiled_fn = torch.compile(
mod,
fullgraph=True,
backend=backend,
).forward.aot_compile((inputs, {}))
expected = mod(*inputs)
actual = compiled_fn(mod, *inputs)
self.assertEqual(expected, actual)
compiled_fn.save_compiled_function(self.path())
torch._dynamo.reset()
with torch.compiler.set_stance("fail_on_recompile"):
with open(self.path(), "rb") as f:
compiled_fn = torch.compiler.load_compiled_function(f)
actual = compiled_fn(mod, *inputs)
self.assertEqual(expected, actual)
def test_decorated_function_aot(self):
def check_inputs(fn):
def _fn(*args, **kwargs):

View File

@ -80,7 +80,7 @@ def fn():
self.assertEqual(fn.__code__.co_lnotab, result[1].co_lnotab)
@unittest.skipIf(
sys.version_info >= (3, 11),
sys.version_info < (3, 10) or sys.version_info >= (3, 11),
"linetable test for Python 3.10",
)
def test_linetable_310_writer(self):
@ -95,6 +95,19 @@ def fn():
result = bytecode_transformation.assemble(inst, fn.__code__.co_firstlineno)
self.assertTrue(result[1] == fn.__code__.co_linetable)
@unittest.skipIf(sys.version_info >= (3, 10), "use lnotab when python < 3.10")
def test_lnotab_writer(self):
def fn():
a = 10
b = 20
c = a + b
f = "lnotab_writer"
return f"Test if {f} generates correct co_lnotab: {c}"
inst = dis.get_instructions(fn)
result = bytecode_transformation.assemble(inst, fn.__code__.co_firstlineno)
self.assertTrue(result[1] == fn.__code__.co_lnotab)
def test_if_tensor_is_none(self):
"""
Python 3.11 adds new jump instructions that check if

View File

@ -410,6 +410,10 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
combs.append(torch.ones(size))
return combs
@unittest.skipIf(
sys.version_info < (3, 10),
"itertools.pairwise was added at Python 3.10",
)
@make_test
def test_itertools_pairwise(a):
pairs = []
@ -4694,6 +4698,10 @@ class DefaultsTests(torch._dynamo.test_case.TestCase):
self.assertEqual(len(lst), 2)
self.assertEqual(lst[0], lst[1])
@unittest.skipIf(
sys.version_info < (3, 10),
"zip strict kwargs not implemented for Python < 3.10",
)
def test_zip_strict(self):
def fn(x, ys, zs):
x = x.clone()
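Both skips added above gate on Python 3.10, where `itertools.pairwise` and `zip(..., strict=True)` first appeared. For reference, a pre-3.10 pairwise can be sketched from the documented tee/zip recipe (illustrative, not part of the test suite):

import itertools
import sys

def pairwise(iterable):
    # Equivalent of itertools.pairwise for Python < 3.10.
    a, b = itertools.tee(iterable)
    next(b, None)
    return zip(a, b)

pw = itertools.pairwise if sys.version_info >= (3, 10) else pairwise
assert list(pw([1, 2, 3])) == [(1, 2), (2, 3)]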

View File

@ -8005,11 +8005,8 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
torch._dynamo.decorators.mark_unbacked(b, 1)
func(a, b)
func(torch.rand(4, 5), torch.rand(4, 5))
# This does not raise an error right now because of a recompilation.
# https://github.com/pytorch/pytorch/issues/163785
# with self.assertRaises(AssertionError):
# func(torch.rand(1, 1), torch.rand(2, 1))
func(torch.rand(1, 1), torch.rand(2, 1))
with self.assertRaises(RuntimeError):
func(torch.rand(1, 1), torch.rand(2, 1))
@torch._dynamo.config.patch(capture_scalar_outputs=True)
def test_sym_constrain_range_on_replaced_unbacked_symbol(self):

View File

@ -443,15 +443,59 @@ def run(cnt):
f(t(2, 4), t(2, 2))
f(t(4, 2), t(2, 2))
# with both default remote present, we ignore extra remote.
# with default remote (dynamic x) + extra remote (dynamic y),
# we should be able to wobble x & y with no recompiles.
self.reset()
cnts.clear()
with torch.compiler.config.patch(pgo_extra_read_key="sticky_1"):
f(t(2, 2), t(2, 2))
f(t(6, 8), t(2, 2))
f(t(2, 4), t(4, 2))
f(t(4, 2), t(2, 4))
self.assertEqual(cnts.frame_count, 1)
f(t(2, 2), t(2, 4))
self.assertEqual(cnts.frame_count, 2)
def test_profile_merges(self):
from torch._dynamo.pgo import auto_dynamic, merge_pgo_entry
@torch.compile(backend="eager", fullgraph=True)
def f(ints, t_scalar, tensors):
# arbitrary compute
return ints[0] + ints[1], t_scalar + 1, [t + 1 for t in tensors]
# single static run
f(
[0, 2],
torch.tensor(0),
[
torch.randn(2),
torch.randn(2, 2),
torch.randn(4, 4),
],
)
# collect profiles
profile = next(
iter(torch._dynamo.pgo.get_code_state().values())
).automatic_dynamic
i0, i1 = profile["L['ints'][0]"], profile["L['ints'][1]"]
ts = profile["L['t_scalar']"]
t0, t1, t2 = (
profile["L['tensors'][0]"],
profile["L['tensors'][1]"],
profile["L['tensors'][2]"],
)
# merging same scalar, or tensor into scalar -> no-op
merge_pgo_entry(i0, i0)
merge_pgo_entry(ts, i0)
merge_pgo_entry(t0, i0)
self.assertEqual(i0.scalar, 0)
# merging different scalars -> dynamic
merge_pgo_entry(i1, i0)
self.assertEqual(i0.scalar, auto_dynamic)
# merging different rank tensors -> static
merge_pgo_entry(t0, t2)
self.assertEqual(t2.size, (4, 4))
# merging same rank tensors -> dynamic
merge_pgo_entry(t1, t2)
self.assertEqual(t2.size, (auto_dynamic, auto_dynamic))
if __name__ == "__main__":

View File

@ -32,8 +32,7 @@ class TestExperiment(TestCase):
m = Module()
example_inputs = (torch.randn(3),)
m(*example_inputs)
with torch._export.config.patch(use_new_tracer_experimental=True):
ep = torch.export.export(m, example_inputs, strict=True)
ep = torch.export.export(m, example_inputs, strict=True)
joint_ep = _export_forward_backward(ep)
self.assertExpectedInline(
str(joint_ep.graph_module.code).strip(),

View File

@ -21,7 +21,6 @@ from unittest.mock import MagicMock, patch
import torch
import torch._dynamo as torchdynamo
import torch.fx.traceback as fx_traceback
import torch.nn.functional as F
import torch.utils._pytree as pytree
from functorch.experimental.control_flow import cond, map
@ -1087,93 +1086,6 @@ graph():
args = (torch.randn(15, 3, 256, 256), torch.ones(15, 32, 256, 256))
self.assertEqual(gm(*args), m(*args))
# stride() is called for an undefined tensor
@testing.expectedFailureCppRuntimeNonStrict
def test_native_multi_attention_head(self):
embed_dim = 64
num_heads = 4
bs = 16
sl = 8
device = "cpu"
q = 6 * torch.rand(bs, sl, embed_dim, device=device, dtype=torch.float32) - 3
k = q
v = q
qkv = torch.nn.Linear(
embed_dim, 3 * embed_dim, device=device, dtype=torch.float32
)
proj = torch.nn.Linear(embed_dim, embed_dim, device=device, dtype=torch.float32)
class NativeMHA(torch.nn.Module):
def __init__(
self,
embed_dim,
num_heads,
qkv,
proj,
need_weights,
average_attn_weights,
mask_type,
):
super().__init__()
self.qkv = qkv
self.proj = proj
self.embed_dim = embed_dim
self.num_heads = num_heads
self.need_weights = need_weights
self.average_attn_weights = average_attn_weights
self.mask_type = mask_type
def forward(self, q, k, v, key_padding_mask):
return torch._native_multi_head_attention(
q,
k,
v,
self.embed_dim,
self.num_heads,
self.qkv.weight,
self.qkv.bias,
self.proj.weight,
self.proj.bias,
key_padding_mask,
need_weights=False,
average_attn_weights=False,
mask_type=1, # mask_type = 1 => src_key_padding_mask, mask_type = 0 => src_mask
)
for mask_type in (0, 1):
for need_weights in (True, False):
for average_attn_weights in (True, False):
npt = NativeMHA(
embed_dim=embed_dim,
num_heads=num_heads,
qkv=qkv,
proj=proj,
need_weights=need_weights,
average_attn_weights=average_attn_weights,
mask_type=mask_type,
)
sample_input = (q, k, v, None)
ep = export(
npt,
args=sample_input,
dynamic_shapes={
"q": {
0: Dim("dim0_q", max=1024),
},
"k": {
0: Dim("dim0_k", max=1024),
},
"v": {
0: Dim("dim0_v", max=1024),
},
"key_padding_mask": None,
},
)
self.assertEqual(ep.module()(*sample_input), npt(*sample_input))
def test_unused_constant(self):
class M(torch.nn.Module):
def forward(self, x):
@ -2007,8 +1919,8 @@ class GraphModule(torch.nn.Module):
# z = 3
return x + y + z
with self.assertWarnsRegex(
UserWarning,
with self.assertRaisesRegex(
ValueError,
"The tensor attribute self.buf was assigned during export",
):
export(M(), (torch.randn(2, 3),), strict=False)
@ -2065,8 +1977,8 @@ class GraphModule(torch.nn.Module):
# z = 3 + 3
return x + y + z
with self.assertWarnsRegex(
UserWarning,
with self.assertRaisesRegex(
ValueError,
"The tensor attributes self.tensors\\[0\\], self.tensors\\[1\\] were assigned during export",
):
export(M(), (torch.randn(2, 3),), strict=False)
@ -15159,39 +15071,6 @@ def forward(self, x):
test_serdes=True,
)
# TODO: following tests should be fixed
@testing.expectedFailureTrainingIRToRunDecomp
@testing.expectedFailureTrainingIRToRunDecompNonStrict
def test_preserve_annotation(self):
class M(torch.nn.Module):
def forward(self, x):
with fx_traceback.annotate({"pp_stage": 0}):
with fx_traceback.annotate({"fdsp_bucket": 0}):
x = x + 1
x = x - 2
with fx_traceback.annotate({"cuda_stream": 2, "fsdp_bucket": 1}):
x = x * 2
x = x / 3
return x
m = M()
with fx_traceback.preserve_node_meta():
ep = export(m, (torch.randn(10),))
for node in ep.graph.nodes:
if node.target == torch.ops.aten.add.default:
self.assertTrue(node.meta["custom"], {"pp_stage": 0, "fdsp_bucket": 0})
if node.target == torch.ops.aten.sub.default:
self.assertTrue(node.meta["custom"], {"pp_stage": 0})
if node.target == torch.ops.aten.mul.default:
self.assertTrue(
node.meta["custom"],
{"pp_stage": 0, "cuda_stream": 2, "fsdp_bucket": 1},
)
if node.target == torch.ops.aten.div.default:
self.assertTrue(node.meta["custom"], {})
def test_dynamic_shapes_serdes_generic(self):
from torch._export.serde.dynamic_shapes import (
_dump_dynamic_shapes,
@ -15898,50 +15777,6 @@ class GraphModule(torch.nn.Module):
]
self.assertEqual(len(shift_op), 1)
def test_export_rnn_variants_with_warning(self):
"""
Test that when exporting RNN, LSTM, and GRU models in non-strict mode, it:
1. Produces expected warnings about tensor attributes being assigned during export
2. Does not leak fake tensors in the model's flat weights
3. Does not produce extra tensor constants in the graph signature
"""
rnn_types = [
(torch.nn.RNN, "RNN"),
(torch.nn.LSTM, "LSTM"),
(torch.nn.GRU, "GRU"),
]
for rnn_class, rnn_name in rnn_types:
with self.subTest(rnn_type=rnn_name):
m = rnn_class(
input_size=2, hidden_size=4, num_layers=1, batch_first=True
)
sample_inputs = (torch.randn(1, 2, 2),)
eager_out = m(*sample_inputs)
# Verify that export produces the expected warning about tensor attributes
with self.assertWarnsRegex(
UserWarning,
r"The tensor attributes self\._flat_weights\[0\], self\._flat_weights\[1\], "
r"self\._flat_weights\[2\], self\._flat_weights\[3\] were assigned during export.*",
):
ep = torch.export.export(m, sample_inputs, strict=False)
ep_out = ep.module()(*sample_inputs)
self.assertEqual(eager_out, ep_out)
# Verify no fake tensor leakage: flat weights should be real tensors
for flat_weight in m._flat_weights:
self.assertTrue(
not isinstance(
flat_weight, torch._subclasses.fake_tensor.FakeTensor
)
)
# Verify no tensor constants in graph signature
self.assertEqual(len(ep.graph_signature.lifted_tensor_constants), 0)
@contextmanager
def distributed_env(self, world_size):
try:

View File

@ -57,6 +57,7 @@ fake_export_failures = {
xfail("nn.functional.grid_sample"),
xfail("to_sparse"),
# cannot xfail as it is passing for cpu-only build
skip("nn.functional.conv2d"),
skip("nn.functional.scaled_dot_product_attention"),
# following are failing due to OptionalDeviceGuard
xfail("__getitem__"),
@ -80,6 +81,7 @@ def _test_export_helper(self, dtype, op):
sample_inputs_itr = op.sample_inputs("cpu", dtype, requires_grad=False)
mode = FakeTensorMode(allow_non_fake_inputs=True)
converter = mode.fake_tensor_converter
# intentionally avoid cuda:0 to flush out some bugs
target_device = "cuda:1"

View File

@ -9,7 +9,6 @@
from contextlib import ExitStack
import torch
import torch.fx.traceback as fx_traceback
import torch.nn as nn
import torch.utils._pytree as pytree
from torch._decomp import decomposition_table
@ -762,52 +761,6 @@ class inner_f(torch.nn.Module):
compiled_fn(*dict(model.named_parameters()).values(), inputs).sum().backward()
self.assertIsNotNone(model.linear.weight.grad)
def test_preserve_annotate_simple(self):
"""Test basic linear module with aot_export_joint_with_descriptors"""
class SimpleLinear(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(3, 2)
def forward(self, x):
with fx_traceback.annotate({"pp_stage": 0}):
y = self.linear(x)
return y - 1
inputs = (torch.randn(4, 3),)
for with_export in [True, False]:
with ExitStack() as stack:
model = None
with fx_traceback.preserve_node_meta():
if with_export:
ep = torch.export.export(SimpleLinear(), inputs)
model = ep.module()
else:
model = SimpleLinear()
joint_with_descriptors = aot_export_joint_with_descriptors(
stack, model, inputs, decompositions=decomposition_table
)
for node in joint_with_descriptors.graph_module.graph.nodes:
if (
node.target
in (
torch.ops.prims.transpose.default,
torch.ops.aten.mm.default,
torch.ops.prims.mul.default,
torch.ops.prims.broadcast_in_dim.default,
torch.ops.prims.add.default,
)
# TODO: add annotation to backward graph nodes
and node.meta.get("partitioner_tag") != "is_backward"
):
self.assertTrue(node.meta["custom"], {"pp_stage": 0})
if node.target == torch.ops.aten.sub.default:
self.assertTrue(node.meta.get("custom", {}), {})
if __name__ == "__main__":
run_tests()

View File

@ -8500,6 +8500,7 @@ class TestAOTAutogradWithCache(TestAOTAutogradWithDynamo):
{
"enable_autograd_cache": True,
"strict_autograd_cache": True,
"view_replay_for_aliased_outputs": False,
}
)
@torch._inductor.config.patch("fx_graph_cache", True)

View File

@ -5074,52 +5074,6 @@ class AOTInductorTestsTemplate:
self.check_model(Model(N, K, self.device), example_inputs)
def test_aoti_user_defined_triton_kernel_profiling(self):
if self.device != GPU_TYPE or self.device == "mps":
raise unittest.SkipTest("requires GPU")
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x, y):
out = torch.zeros_like(x)
add_kernel[(4,)](x, y, out, n_elements=4, BLOCK_SIZE=16)
return out
example_inputs = (
torch.randn(4, 4, device=self.device),
torch.randn(4, 4, device=self.device),
)
with (
config.patch({"cpp.enable_kernel_profile": True}),
torch.profiler.profile(
record_shapes=True,
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
) as prof,
):
self.check_model(Model(), example_inputs)
with common_utils.TemporaryFileName(mode="w+") as fname:
prof.export_chrome_trace(fname)
with open(fname) as f:
import json
j = json.load(f)
op_events = [
e
for e in j["traceEvents"]
if e.get("name", "") == "kernels_.add_kernel_0"
]
self.assertEqual(len(op_events), 1)
self.assertEqual(
op_events[0]["args"].get("Input Args", ""),
["in_ptr0", "in_ptr1", "out_ptr", "n_elements"],
)
def test_aoti_debug_printer_user_defined_triton_kernel(self):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
@ -7194,37 +7148,6 @@ class AOTInductorTestsTemplate:
for lib in torch_libs:
self.assertTrue(lib not in line)
def test_unbounded_expr_substitutions(self):
class Model(torch.nn.Module):
def forward(self, x, y, a, b):
u0, s0 = a.item(), b.item()
u_max = max(u0, 15)
# construct the equality rule Max(15, u0) == s0 * Max(15, u0)
torch._check(u_max == s0 * u_max)
# size x - [Max(u0, 15), 64]
x = x.expand(u_max, *x.shape).clone()
return x @ y
model = Model()
example_inputs = (
torch.randn((64,), dtype=torch.bfloat16, device=self.device),
torch.randn((64, 16), dtype=torch.bfloat16, device=self.device),
torch.tensor(19, device=self.device),
torch.tensor(1, device=self.device),
)
torch._dynamo.mark_dynamic(example_inputs[-1], 0)
so_path, code = run_and_get_cpp_code(
AOTIRunnerUtil.legacy_compile, model, example_inputs
)
compiled = AOTIRunnerUtil.legacy_load(self.device, so_path)
compiled_outputs = compiled(*example_inputs)
eager_outputs = model(*example_inputs)
torch.testing.assert_close(eager_outputs, compiled_outputs)
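The `torch._check` in the removed test above is a compact way of pinning `s0`: since `u_max = max(u0, 15)` is strictly positive, `u_max == s0 * u_max` can only hold when `s0 == 1`, which is exactly the value fed in via `torch.tensor(1)`. In plain Python terms (illustrative arithmetic only):

u0, s0 = 19, 1               # the scalar inputs used in the removed test
u_max = max(u0, 15)          # 19, strictly positive
assert u_max == s0 * u_max   # holds only for s0 == 1 once u_max > 0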
class AOTInductorLoggingTest(LoggingTestCase):
@make_logging_test(dynamic=logging.DEBUG)

Some files were not shown because too many files have changed in this diff.