Compare commits

..

3 Commits

Author SHA1 Message Date
9e2e034736 Update
[ghstack-poisoned]
2025-11-04 19:59:19 -08:00
fa5d6a24bd Update
[ghstack-poisoned]
2025-11-04 17:10:49 -08:00
e57275b49c Update (base update)
[ghstack-poisoned]
2025-11-04 17:10:49 -08:00
309 changed files with 1879 additions and 6992 deletions

View File

@ -13,4 +13,3 @@ exclude:
- "**/benchmarks/**"
- "**/test_*.py"
- "**/*_test.py"
- "tools/**"

View File

@ -149,7 +149,7 @@ FROM cpu_final as rocm_final
ARG ROCM_VERSION=6.0
ARG PYTORCH_ROCM_ARCH
ENV PYTORCH_ROCM_ARCH ${PYTORCH_ROCM_ARCH}
ARG DEVTOOLSET_VERSION=13
ARG DEVTOOLSET_VERSION=11
ENV LDFLAGS="-Wl,-rpath=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/lib64 -Wl,-rpath=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/lib"
# Somewhere in ROCm stack, we still use non-existing /opt/rocm/hip path,
# below workaround helps avoid error

View File

@ -1,11 +1,15 @@
sphinx==7.2.6
sphinx==5.3.0
#Description: This is used to generate PyTorch docs
#Pinned versions: 7.2.6
#Pinned versions: 5.3.0
pytorch_sphinx_theme2==0.2.0
#Description: This is needed to generate PyTorch docs
#Pinned versions: 0.2.0
standard-imghdr==3.13.0; python_version >= "3.13"
#Description: This is needed by Sphinx, so it needs to be added here.
# The reasons are as follows:
# 1) This module has been removed from the Python standard library since Python 3.13(https://peps.python.org/pep-0594/#imghdr);
# 2) The current version of Sphinx (5.3.0) is not compatible with Python 3.13.
# Once Sphinx is upgraded to a version compatible with Python 3.13 or later, we can remove this dependency.
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git@71e55749be14ceb56e7f8211a9fb649866b87ad4#egg=pytorch_sphinx_theme2
# TODO: sphinxcontrib.katex 0.9.0 adds a local KaTeX server to speed up pre-rendering
# but it doesn't seem to work and hangs around idly. The initial thought that it is probably
# something related to Docker setup. We can investigate this later.
@ -32,17 +36,17 @@ tensorboard==2.18.0 ; python_version >= "3.13"
#Description: This is used to generate PyTorch docs
#Pinned versions: 2.13.0
breathe==4.36.0
breathe==4.34.0
#Description: This is used to generate PyTorch C++ docs
#Pinned versions: 4.36.0
#Pinned versions: 4.34.0
exhale==0.3.7
exhale==0.2.3
#Description: This is used to generate PyTorch C++ docs
#Pinned versions: 0.3.7
#Pinned versions: 0.2.3
docutils==0.20
docutils==0.16
#Description: This is used to generate PyTorch C++ docs
#Pinned versions: 0.20
#Pinned versions: 0.16
bs4==0.0.1
#Description: This is used to generate PyTorch C++ docs
@ -52,13 +56,13 @@ IPython==8.12.0
#Description: This is used to generate PyTorch functorch docs
#Pinned versions: 8.12.0
myst-nb==1.3.0
myst-nb==0.17.2
#Description: This is used to generate PyTorch functorch and torch.compile docs.
#Pinned versions: 1.3.0
#Pinned versions: 0.17.2
# The following are required to build torch.distributed.elastic.rendezvous.etcd* docs
python-etcd==0.4.5
sphinx-copybutton==0.5.0
sphinx-design==0.6.1
sphinx-design==0.4.0
sphinxcontrib-mermaid==1.0.0
myst-parser==4.0.1
myst-parser==0.18.1

View File

@ -89,41 +89,23 @@ if [ "$is_main_doc" = true ]; then
make coverage
# Now we have the coverage report, we need to make sure it is empty.
# Sphinx 7.2.6+ format: python.txt contains a statistics table with a TOTAL row
# showing the undocumented count in the third column.
# Example: | TOTAL | 99.83% | 2 |
# Count the number of lines in the file and turn that number into a variable
# $lines. The `cut -f1 ...` is to only parse the number, not the filename
# Skip the report header by subtracting 2: the header will be output even if
# there are no undocumented items.
#
# Also: see docs/source/conf.py for "coverage_ignore*" items, which should
# be documented then removed from there.
# Extract undocumented count from TOTAL row in Sphinx 7.2.6 statistics table
# The table format is: | Module | Coverage | Undocumented |
# Extract the third column (undocumented count) from the TOTAL row
undocumented=$(grep "| TOTAL" build/coverage/python.txt | awk -F'|' '{print $4}' | tr -d ' ')
if [ -z "$undocumented" ] || ! [[ "$undocumented" =~ ^[0-9]+$ ]]; then
lines=$(wc -l build/coverage/python.txt 2>/dev/null |cut -f1 -d' ')
undocumented=$((lines - 2))
if [ $undocumented -lt 0 ]; then
echo coverage output not found
exit 1
elif [ "$undocumented" -gt 0 ]; then
set +x # Disable command echoing for cleaner output
echo ""
echo "====================="
echo "UNDOCUMENTED OBJECTS:"
echo "====================="
echo ""
# Find the line number of the TOTAL row and print only what comes after it
total_line=$(grep -n "| TOTAL" build/coverage/python.txt | cut -d: -f1)
if [ -n "$total_line" ]; then
# Print only the detailed list (skip the statistics table)
tail -n +$((total_line + 2)) build/coverage/python.txt
else
# Fallback to showing entire file if TOTAL line not found
cat build/coverage/python.txt
fi
echo ""
elif [ $undocumented -gt 0 ]; then
echo undocumented objects found:
cat build/coverage/python.txt
echo "Make sure you've updated relevant .rsts in docs/source!"
echo "You can reproduce locally by running 'cd docs && make coverage && tail -n +\$((grep -n \"| TOTAL\" build/coverage/python.txt | cut -d: -f1) + 2)) build/coverage/python.txt'"
set -x # Re-enable command echoing
echo "You can reproduce locally by running 'cd docs && make coverage && cat build/coverage/python.txt'"
exit 1
fi
else

View File

@ -337,7 +337,7 @@ test_python() {
test_python_smoke() {
# Smoke tests for H100/B200
time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune inductor/test_cutedsl_grouped_mm $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
assert_git_not_dirty
}
@ -1653,7 +1653,7 @@ test_operator_microbenchmark() {
cd "${TEST_DIR}"/benchmarks/operator_benchmark
for OP_BENCHMARK_TESTS in matmul mm addmm bmm conv; do
for OP_BENCHMARK_TESTS in matmul mm addmm bmm; do
$TASKSET python -m pt.${OP_BENCHMARK_TESTS}_test --tag-filter long \
--output-json-for-dashboard "${TEST_REPORTS_DIR}/operator_microbenchmark_${OP_BENCHMARK_TESTS}_compile.json" \
--benchmark-name "PyTorch operator microbenchmark" --use-compile

View File

@ -28,7 +28,7 @@ CUDA_ARCHES_FULL_VERSION = {
"12.6": "12.6.3",
"12.8": "12.8.1",
"12.9": "12.9.1",
"13.0": "13.0.0",
"13.0": "13.0.2",
}
CUDA_ARCHES_CUDNN_VERSION = {
"12.6": "9",

View File

@ -8,7 +8,6 @@ on:
- docker.Makefile
- .github/workflows/docker-release.yml
- .github/scripts/generate_docker_release_matrix.py
- .github/scripts/generate_binary_build_matrix.py
push:
branches:
- nightly

View File

@ -1,10 +1,9 @@
name: inductor-rocm
on:
schedule:
- cron: 0 * * * *
push:
branches:
- main
- release/*
tags:
- ciflow/inductor-rocm/*

View File

@ -115,10 +115,10 @@ jobs:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
test-matrix: |
{ include: [
{ config: "inductor_amx", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "inductor_amx", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "inductor_avx2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.avx2" },
{ config: "inductor_avx2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.avx2" },
{ config: "inductor_amx", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "inductor_amx", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "inductor_avx2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" },
{ config: "inductor_avx2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" },
]}
secrets: inherit

View File

@ -84,13 +84,13 @@ jobs:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
test-matrix: |
{ include: [
{ config: "cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "dynamic_cpu_inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "dynamic_cpu_inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "dynamic_cpu_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "dynamic_cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "dynamic_cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
{ config: "cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "dynamic_cpu_inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "dynamic_cpu_inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "dynamic_cpu_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "dynamic_cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "dynamic_cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "inductor_torchbench_cpu_smoketest_perf", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.24xl.spr-metal" },
]}
build-additional-packages: "vision audio torchao"

View File

@ -76,12 +76,11 @@ jobs:
# NOTE: mypy needs its own job because it depends on --all-files, without assessing all files it sometimes
# fails to find types when it should
# NOTE: We should be able to disable this and consolidate with Pyrefly
lintrunner-pyrefly:
lintrunner-mypy:
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
name: lintrunner-pyrefly-${{ needs.get-changed-files.outputs.changed-files == '*' && 'all' || 'partial' }}
name: lintrunner-mypy-${{ needs.get-changed-files.outputs.changed-files == '*' && 'all' || 'partial' }}
needs: [get-label-type, get-changed-files]
# Only run if there are changed files relevant to pyrefly
# Only run if there are changed files relevant to mypy
if: |
github.repository_owner == 'pytorch' && (
needs.get-changed-files.outputs.changed-files == '*' ||
@ -99,8 +98,8 @@ jobs:
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
script: |
CHANGED_FILES="${{ needs.get-changed-files.outputs.changed-files }}"
echo "Running pyrefly"
ADDITIONAL_LINTRUNNER_ARGS="--take PYREFLY --all-files" .github/scripts/lintrunner.sh
echo "Running mypy"
ADDITIONAL_LINTRUNNER_ARGS="--take MYPY,MYPYSTRICT --all-files" .github/scripts/lintrunner.sh
lintrunner-noclang:
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
@ -119,9 +118,9 @@ jobs:
CHANGED_FILES="${{ needs.get-changed-files.outputs.changed-files }}"
echo "Running all other linters"
if [ "$CHANGED_FILES" = '*' ]; then
ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,PYREFLY --all-files" .github/scripts/lintrunner.sh
ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT,PYREFLY --all-files" .github/scripts/lintrunner.sh
else
ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,PYREFLY ${CHANGED_FILES}" .github/scripts/lintrunner.sh
ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT,PYREFLY ${CHANGED_FILES}" .github/scripts/lintrunner.sh
fi
quick-checks:

View File

@ -41,7 +41,7 @@ jobs:
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge"
build-environment: linux-jammy-py3.10-gcc11
docker-image-name: ci-image:pytorch-linux-jammy-py3.10-gcc11
secrets: inherit

View File

@ -66,10 +66,10 @@ jobs:
{ config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "docs_test", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "jit_legacy", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "backwards_compat", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
{ config: "backwards_compat", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "distributed", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "distributed", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "numpy_2_x", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
{ config: "numpy_2_x", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
]}
secrets: inherit
@ -167,8 +167,8 @@ jobs:
docker-image-name: ci-image:pytorch-linux-jammy-py3-clang12-onnx
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
{ config: "default", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
{ config: "default", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "default", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
]}
secrets: inherit

View File

@ -3,13 +3,13 @@ name: rocm
on:
push:
branches:
- main
- release/*
tags:
- ciflow/rocm/*
workflow_dispatch:
schedule:
- cron: 29 8 * * * # about 1:29am PDT
- cron: 0 * * * *
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}

View File

@ -204,7 +204,6 @@ jobs:
{ include: [
{ config: "default", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
{ config: "default", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
{ config: "distributed", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.gfx942.4" },
]}
secrets: inherit
@ -222,7 +221,7 @@ jobs:
build-environment: linux-jammy-rocm-py3.10
docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
tests-to-include: "test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs test_autograd inductor/test_torchinductor distributed/test_c10d_common distributed/test_c10d_nccl"
tests-to-include: "test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs test_autograd inductor/test_torchinductor"
secrets: inherit
inductor-build:

1
.gitignore vendored
View File

@ -127,7 +127,6 @@ torch/test/
torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h
torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h
torch/version.py
torch/_inductor/kernel/vendored_templates/*
minifier_launcher.py
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_fwd_d*
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd_d*

View File

@ -121,6 +121,94 @@ command = [
]
is_formatter = true
[[linter]]
code = 'MYPY'
include_patterns = [
'setup.py',
'functorch/dim/**/*.py',
'torch/**/*.py',
'torch/**/*.pyi',
'caffe2/**/*.py',
'caffe2/**/*.pyi',
'test/test_bundled_images.py',
'test/test_bundled_inputs.py',
'test/test_complex.py',
'test/test_datapipe.py',
'test/test_futures.py',
'test/test_numpy_interop.py',
'test/test_torch.py',
'test/test_type_hints.py',
'test/test_type_info.py',
'test/test_utils.py',
]
exclude_patterns = [
'**/fb/**',
]
command = [
'python3',
'tools/linter/adapters/mypy_linter.py',
'--config=mypy.ini',
'--',
'@{{PATHSFILE}}'
]
init_command = [
'python3',
'tools/linter/adapters/pip_init.py',
'--dry-run={{DRYRUN}}',
'numpy==1.26.4 ; python_version >= "3.10" and python_version <= "3.11"',
'numpy==2.1.0 ; python_version >= "3.12"',
'expecttest==0.3.0',
'mypy==1.16.0',
'sympy==1.13.3',
'types-requests==2.27.25',
'types-pyyaml==6.0.2',
'types-tabulate==0.8.8',
'types-protobuf==5.29.1.20250403',
'types-setuptools==79.0.0.20250422',
'types-jinja2==2.11.9',
'types-colorama==0.4.6',
'filelock==3.18.0',
'junitparser==2.1.1',
'rich==14.1.0',
'pyyaml==6.0.2',
'optree==0.13.0',
'dataclasses-json==0.6.7',
'pandas==2.2.3',
]
[[linter]]
code = 'MYPYSTRICT'
include_patterns = [
'.github/**/*.py',
'benchmarks/instruction_counts/**/*.py',
'tools/**/*.py',
'torchgen/**/*.py',
'torch/utils/_pytree.py',
'torch/utils/_cxx_pytree.py',
'torch/utils/benchmark/utils/common.py',
'torch/utils/benchmark/utils/timer.py',
'torch/utils/benchmark/utils/valgrind_wrapper/**/*.py',
]
exclude_patterns = [
# (linbinyu) copied from internal repo
'**/fb/**',
'tools/code_analyzer/gen_operators_yaml.py',
'tools/dynamo/verify_dynamo.py',
'tools/gen_vulkan_spv.py',
'tools/test/gen_operators_yaml_test.py',
'tools/test/gen_oplist_test.py',
'tools/test/test_selective_build.py',
'tools/experimental/torchfuzz/**',
]
command = [
'python3',
'tools/linter/adapters/mypy_linter.py',
'--config=mypy-strict.ini',
'--code=MYPYSTRICT',
'--',
'@{{PATHSFILE}}'
]
[[linter]]
code = 'PYREFLY'
@ -142,7 +230,6 @@ init_command = [
'python3',
'tools/linter/adapters/pip_init.py',
'--dry-run={{DRYRUN}}',
'numpy==1.26.4 ; python_version >= "3.10" and python_version <= "3.11"',
'numpy==2.1.0 ; python_version >= "3.12"',
'expecttest==0.3.0',
'pyrefly==0.36.2',
@ -211,6 +298,7 @@ exclude_patterns = [
'**/*pb.h',
'**/*inl.h',
'aten/src/ATen/cpu/FlushDenormal.cpp',
'aten/src/ATen/cpu/Utils.cpp',
'aten/src/ATen/cpu/vml.h',
'aten/src/ATen/CPUFixedAllocator.h',
'aten/src/ATen/Parallel*.h',
@ -229,6 +317,8 @@ exclude_patterns = [
'c10/util/win32-headers.h',
'c10/test/**/*.h',
'third_party/**/*',
'torch/csrc/api/include/torch/nn/modules/common.h',
'torch/csrc/api/include/torch/linalg.h',
'torch/csrc/autograd/generated/**',
'torch/csrc/distributed/**/*.cu',
'torch/csrc/distributed/c10d/WinSockUtils.hpp',
@ -240,6 +330,7 @@ exclude_patterns = [
'torch/csrc/utils/generated_serialization_types.h',
'torch/csrc/utils/pythoncapi_compat.h',
'torch/csrc/inductor/aoti_runtime/sycl_runtime_wrappers.h',
'aten/src/ATen/ExpandBase.h',
]
init_command = [
'python3',

View File

@ -1,7 +1,7 @@
# Security Policy
- [**Reporting a Vulnerability**](#reporting-a-vulnerability)
- [**Using PyTorch Securely**](#using-pytorch-securely)
- [**Using Pytorch Securely**](#using-pytorch-securely)
- [Untrusted models](#untrusted-models)
- [TorchScript models](#torchscript-models)
- [Untrusted inputs](#untrusted-inputs)
@ -10,28 +10,28 @@
- [**CI/CD security principles**](#cicd-security-principles)
## Reporting Security Issues
Beware that none of the topics under [Using PyTorch Securely](#using-pytorch-securely) are considered vulnerabilities of PyTorch.
Beware that none of the topics under [Using Pytorch Securely](#using-pytorch-securely) are considered vulnerabilities of Pytorch.
However, if you believe you have found a security vulnerability in PyTorch, we encourage you to let us know right away. We will investigate all legitimate reports and do our best to quickly fix the problem.
Please report security issues using https://github.com/pytorch/pytorch/security/advisories/new
All reports submitted through the security advisories mechanism would **either be made public or dismissed by the team within 90 days of the submission**. If advisory has been closed on the grounds that it is not a security issue, please do not hesitate to create an [new issue](https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml) as it is still likely a valid issue within the framework.
All reports submitted thru the security advisories mechanism would **either be made public or dismissed by the team within 90 days of the submission**. If advisory has been closed on the grounds that it is not a security issue, please do not hesitate to create an [new issue](https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml) as it is still likely a valid issue within the framework.
Please refer to the following page for our responsible disclosure policy, reward guidelines, and those things that should not be reported:
https://www.facebook.com/whitehat
## Using PyTorch Securely
**PyTorch models are programs**, so treat its security seriously -- running untrusted models is equivalent to running untrusted code. In general we recommend that model weights and the python code for the model are distributed independently. That said, be careful about where you get the python code from and who wrote it (preferentially check for a provenance or checksums, do not run any pip installed package).
## Using Pytorch Securely
**Pytorch models are programs**, so treat its security seriously -- running untrusted models is equivalent to running untrusted code. In general we recommend that model weights and the python code for the model are distributed independently. That said, be careful about where you get the python code from and who wrote it (preferentially check for a provenance or checksums, do not run any pip installed package).
### Untrusted models
Be careful when running untrusted models. This classification includes models created by unknown developers or utilizing data obtained from unknown sources[^data-poisoning-sources].
**Prefer to execute untrusted models within a secure, isolated environment such as a sandbox** (e.g., containers, virtual machines). This helps protect your system from potentially malicious code. You can find further details and instructions in [this page](https://developers.google.com/code-sandboxing).
**Be mindful of risky model formats**. Give preference to share and load weights with the appropriate format for your use case. [Safetensors](https://huggingface.co/docs/safetensors/en/index) gives the most safety but is the most restricted in what it supports. [`torch.load`](https://pytorch.org/docs/stable/generated/torch.load.html#torch.load) has a significantly larger surface of attack but is more flexible in what it can serialize. See the documentation for more details.
**Be mindful of risky model formats**. Give preference to share and load weights with the appropriate format for your use case. [safetensors](https://huggingface.co/docs/safetensors/en/index) gives the most safety but is the most restricted in what it supports. [`torch.load`](https://pytorch.org/docs/stable/generated/torch.load.html#torch.load) has a significantly larger surface of attack but is more flexible in what it can serialize. See the documentation for more details.
Even for more secure serialization formats, unexpected inputs to the downstream system can cause diverse security threats (e.g. denial of service, out of bound reads/writes) and thus we recommend extensive validation of any untrusted inputs.
@ -43,7 +43,7 @@ Important Note: The trustworthiness of a model is not binary. You must always de
### TorchScript models
TorchScript models should be treated the same way as locally executable code from an unknown source. Only run TorchScript models if you trust the provider. Please note, that tools for introspecting TorchScript models (such as `torch.utils.model_dump`) may also execute partial or full code stored in those models, therefore they should be used only if you trust the provider of the binary you are about to load.
TorchScript models should treated the same way as locally executable code from an unknown source. Only run TorchScript models if you trust the provider. Please note, that tools for introspecting TorchScript models (such as `torch.utils.model_dump`) may also execute partial or full code stored in those models, therefore they should be used only if you trust the provider of the binary you are about to load.
### Untrusted inputs during training and prediction
@ -59,9 +59,9 @@ If applicable, prepare your model against bad inputs and prompt injections. Some
### Data privacy
**Take special security measures if you train your models with sensitive data**. Prioritize [sandboxing](https://developers.google.com/code-sandboxing) your models and:
- Do not feed sensitive data to an untrusted model (even if runs in a sandboxed environment)
- If you consider publishing a model that was partially trained with sensitive data, be aware that data can potentially be recovered from the trained weights (especially if the model overfits).
**Take special security measures if your model if you train models with sensitive data**. Prioritize [sandboxing](https://developers.google.com/code-sandboxing) your models and:
- Do not feed sensitive data to untrusted model (even if runs in a sandboxed environment)
- If you consider publishing a model that was partially trained with sensitive data, be aware that data can potentially be recovered from the trained weights (especially if model overfits).
### Using distributed features

View File

@ -260,7 +260,7 @@ IF(USE_FBGEMM_GENAI)
if(USE_CUDA)
# To avoid increasing the build time/binary size unnecessarily, use an allow-list of kernels to build.
# If you want to integrate a kernel from FBGEMM into torch, you have to add it here.
set(FBGEMM_CUTLASS_KERNELS_REGEX ".*(mx8mx8bf16_grouped|f4f4bf16_grouped|f4f4bf16).*")
set(FBGEMM_CUTLASS_KERNELS_REGEX ".*(mx8mx8bf16_grouped|f4f4bf16_grouped).*")
file(GLOB_RECURSE fbgemm_genai_native_cuda_cu
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/*.cu"
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/**/*.cu")

View File

@ -23,6 +23,8 @@ C10_DIAGNOSTIC_POP()
#endif
namespace at {
namespace {
/*
These const variables defined the fp32 precisions for different backend
We have "generic", "cuda", "mkldnn" backend now and we can choose fp32
@ -39,6 +41,16 @@ namespace at {
->rnn
*/
C10_ALWAYS_INLINE void warn_deprecated_fp32_precision_api(){
TORCH_WARN_ONCE(
"Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' "
"or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, "
"torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see "
"https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices"
);
}
} // namespace
Float32Backend str2backend(const std::string& name) {
if (name == "generic")
return Float32Backend::GENERIC;
@ -194,6 +206,7 @@ bool Context::allowTF32CuDNN(std::optional<Float32Op> op) const {
} else {
return float32Precision(Float32Backend::CUDA, op.value()) == Float32Precision::TF32;
}
warn_deprecated_fp32_precision_api();
return allow_tf32_cudnn;
}
@ -201,6 +214,7 @@ void Context::setAllowTF32CuDNN(bool b) {
setFloat32Precision(Float32Backend::CUDA, Float32Op::RNN, b ? Float32Precision::TF32 : Float32Precision::NONE);
setFloat32Precision(Float32Backend::CUDA, Float32Op::CONV, b ? Float32Precision::TF32 : Float32Precision::NONE);
allow_tf32_cudnn = b;
warn_deprecated_fp32_precision_api();
}
void Context::setSDPPriorityOrder(const std::vector<int64_t>& order) {
@ -311,6 +325,7 @@ bool Context::allowTF32CuBLAS() const {
"Current status indicate that you have used mix of the legacy and new APIs to set the TF32 status for cublas matmul. ",
"We suggest only using the new API to set the TF32 flag. See also: ",
"https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices");
warn_deprecated_fp32_precision_api();
return allow_tf32_new;
}
@ -334,6 +349,7 @@ Float32MatmulPrecision Context::float32MatmulPrecision() const {
"Current status indicate that you have used mix of the legacy and new APIs to set the matmul precision. ",
"We suggest only using the new API for matmul precision. See also: ",
"https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices");
warn_deprecated_fp32_precision_api();
return float32_matmul_precision;
}
@ -361,6 +377,7 @@ Float32Precision Context::float32Precision(Float32Backend backend, Float32Op op)
void Context::setFloat32MatmulPrecision(const std::string &s) {
auto match = [this](const std::string & s_) {
warn_deprecated_fp32_precision_api();
// TODO: consider if CuDNN field needs to also be set for potential future CuDNN ops like multi-headed attention
if (s_ == "highest") {
float32_matmul_precision = at::Float32MatmulPrecision::HIGHEST;

View File

@ -72,16 +72,10 @@ TORCH_LIBRARY_IMPL(aten, VmapMode, m) {
m.impl("random_", unsupportedRandomOp_<Tensor&, std::optional<Generator>>);
m.impl("rand_like", unsupportedRandomOp<const Tensor&, TENSOROPTIONS, std::optional<MemoryFormat>>);
m.impl("rand_like.generator", unsupportedRandomOp<const Tensor&, std::optional<Generator>, TENSOROPTIONS, std::optional<MemoryFormat>>);
m.impl("randn_like", unsupportedRandomOp<const Tensor&, TENSOROPTIONS, std::optional<MemoryFormat>>);
m.impl("randn_like.generator", unsupportedRandomOp<const Tensor&, std::optional<Generator>, TENSOROPTIONS, std::optional<MemoryFormat>>);
m.impl("randint_like", unsupportedRandomOp<const Tensor&, int64_t, TENSOROPTIONS, std::optional<MemoryFormat>>);
m.impl("randint_like.Tensor", unsupportedRandomOp<const Tensor&, const Tensor&, TENSOROPTIONS, std::optional<MemoryFormat>>);
m.impl("randint_like.low_dtype", unsupportedRandomOp<const Tensor&, int64_t, int64_t, TENSOROPTIONS, std::optional<MemoryFormat>>);
m.impl("randint_like.generator", unsupportedRandomOp<const Tensor&, int64_t, std::optional<Generator>, TENSOROPTIONS, std::optional<MemoryFormat>>);
m.impl("randint_like.Tensor_generator", unsupportedRandomOp<const Tensor&, const Tensor&, std::optional<Generator>, TENSOROPTIONS, std::optional<MemoryFormat>>);
m.impl("randint_like.low_generator_dtype", unsupportedRandomOp<const Tensor&, int64_t, int64_t, std::optional<Generator>, TENSOROPTIONS, std::optional<MemoryFormat>>);
m.impl("rand", unsupportedRandomOp<IntArrayRef, TENSOROPTIONS>);
m.impl("rand.generator", unsupportedRandomOp<IntArrayRef, std::optional<Generator>, TENSOROPTIONS>);

View File

@ -191,37 +191,22 @@ inline void convert(const at::Half* src, bool* dst, int64_t n) {
}
#endif
template <typename to_type>
inline void convertFromBf16Impl(
const c10::BFloat16* __restrict src,
to_type* __restrict dst,
int64_t n) {
const uint16_t* srcPtr = reinterpret_cast<const uint16_t*>(src);
uint64_t len = static_cast<uint64_t>(n);
for (uint64_t i = 0; i < len; i++) {
uint32_t tmp = static_cast<uint32_t>(srcPtr[i]) << 16;
float tmpF;
__builtin_memcpy(&tmpF, &tmp, sizeof(float));
dst[i] = static_cast<to_type>(tmpF);
}
}
#define CONVERT_FROM_BF16_TEMPLATE(to_type) \
template <> \
inline void convert(const c10::BFloat16* src, to_type* dst, int64_t n) { \
return convertFromBf16Impl<to_type>(src, dst, n); \
}
CONVERT_FROM_BF16_TEMPLATE(uint8_t)
CONVERT_FROM_BF16_TEMPLATE(int8_t)
CONVERT_FROM_BF16_TEMPLATE(int16_t)
CONVERT_FROM_BF16_TEMPLATE(int32_t)
CONVERT_FROM_BF16_TEMPLATE(int64_t)
CONVERT_FROM_BF16_TEMPLATE(float)
CONVERT_FROM_BF16_TEMPLATE(double)
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
CONVERT_FROM_BF16_TEMPLATE(float16_t)
#endif
#ifdef __ARM_FEATURE_BF16
CONVERT_TEMPLATE(bfloat16_t, uint8_t)
CONVERT_TEMPLATE(bfloat16_t, int8_t)
CONVERT_TEMPLATE(bfloat16_t, int16_t)
CONVERT_TEMPLATE(bfloat16_t, int32_t)
CONVERT_TEMPLATE(bfloat16_t, int64_t)
CONVERT_TEMPLATE(bfloat16_t, bfloat16_t)
CONVERT_TEMPLATE(bfloat16_t, float)
CONVERT_TEMPLATE(bfloat16_t, double)
CONVERT_TEMPLATE(uint8_t, bfloat16_t)
CONVERT_TEMPLATE(int8_t, bfloat16_t)
CONVERT_TEMPLATE(int16_t, bfloat16_t)
CONVERT_TEMPLATE(int32_t, bfloat16_t)
CONVERT_TEMPLATE(int64_t, bfloat16_t)
CONVERT_TEMPLATE(float, bfloat16_t)
CONVERT_TEMPLATE(double, bfloat16_t)
inline void convertBoolToBfloat16Impl(
const bool* __restrict src,
@ -262,6 +247,8 @@ inline void convert(const c10::BFloat16* src, bool* dst, int64_t n) {
#endif
#endif
template <typename src_t>
struct VecConvert<
float,

View File

@ -1,6 +1,6 @@
#include <ATen/cuda/CUDAGreenContext.h>
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12030) && !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
#if defined(CUDA_VERSION) && !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
#include <c10/cuda/driver_api.h>
#include <stdexcept>
#include <vector>

View File

@ -9,8 +9,8 @@
#include <c10/core/Allocator.h>
#include <ATen/detail/AcceleratorHooksInterface.h>
#include <c10/util/python_stub.h>
#include <ATen/detail/AcceleratorHooksInterface.h>
#include <string>
namespace at {
@ -26,7 +26,8 @@ constexpr const char* MTIA_HELP =
struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {
// this fails the implementation if MTIAHooks functions are called, but
// MTIA backend is not present.
#define FAIL_MTIAHOOKS_FUNC(func) TORCH_CHECK(false, "Cannot execute ", func, "() without MTIA backend.");
#define FAIL_MTIAHOOKS_FUNC(func) \
TORCH_CHECK(false, "Cannot execute ", func, "() without MTIA backend.");
~MTIAHooksInterface() override = default;
@ -91,7 +92,7 @@ struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {
return c10::Stream::unpack3(-1, 0, c10::DeviceType::MTIA);
}
virtual void setCurrentStream(const c10::Stream& /*stream*/) const {
virtual void setCurrentStream(const c10::Stream& /*stream*/ ) const {
FAIL_MTIAHOOKS_FUNC(__func__);
}
@ -123,9 +124,11 @@ struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {
FAIL_MTIAHOOKS_FUNC(__func__);
}
virtual void recordMemoryHistory(const std::optional<std::string>& /*enabled*/,
const std::string& /*stacks*/,
size_t /*max_entries*/) const {
virtual void recordMemoryHistory(
const std::optional<std::string>& /*enabled*/,
const std::string& /*stacks*/,
size_t /*max_entries*/) const {
FAIL_MTIAHOOKS_FUNC(__func__);
}
@ -156,10 +159,6 @@ struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {
return -1;
}
virtual void mtiagraphDestroy(int64_t handle) const {
FAIL_MTIAHOOKS_FUNC(__func__);
}
virtual void mtiagraphCaptureBegin(int64_t handle, MempoolId_t pool) const {
FAIL_MTIAHOOKS_FUNC(__func__);
}
@ -188,7 +187,8 @@ struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {
struct TORCH_API MTIAHooksArgs {};
TORCH_DECLARE_REGISTRY(MTIAHooksRegistry, MTIAHooksInterface, MTIAHooksArgs);
#define REGISTER_MTIA_HOOKS(clsname) C10_REGISTER_CLASS(MTIAHooksRegistry, clsname, clsname)
#define REGISTER_MTIA_HOOKS(clsname) \
C10_REGISTER_CLASS(MTIAHooksRegistry, clsname, clsname)
namespace detail {
TORCH_API const MTIAHooksInterface& getMTIAHooks();

View File

@ -2917,7 +2917,9 @@ static Tensor& linalg_eig_make_complex_eigenvectors(Tensor& complex_vectors, con
DEFINE_DISPATCH(linalg_eig_stub);
static std::tuple<Tensor&, Tensor&> linalg_eig_out_info(const Tensor& input, Tensor& values, Tensor& vectors, Tensor& infos, bool compute_eigenvectors) {
auto options = input.options();
// MAGMA doesn't have GPU interface for GEEV routine, it requires inputs to be on CPU
// therefore we create all intermediate tensors on CPU
auto options = input.options().device(at::kCPU);
// These internal asserts make explicit the assumptions in the implementation
// Error check with the actual error messages are done on the higher level of the hierarchy of calls
@ -2926,13 +2928,16 @@ static std::tuple<Tensor&, Tensor&> linalg_eig_out_info(const Tensor& input, Ten
// for real-valued 'input', eigenvalues can be real-valued or complex-valued
TORCH_INTERNAL_ASSERT_DEBUG_ONLY((toComplexType(input.scalar_type()) == values.scalar_type()) || (input.scalar_type() == values.scalar_type()));
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(values.device() == at::kCPU);
// for real-valued 'input', eigenvectors can be real-valued or complex-valued
if (compute_eigenvectors) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY((toComplexType(input.scalar_type()) == vectors.scalar_type()) || (input.scalar_type() == vectors.scalar_type()));
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(vectors.device() == at::kCPU);
}
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(infos.scalar_type() == at::kInt);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(infos.device() == at::kCPU);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(infos.numel() == std::max<int64_t>(1, batchCount(input)));
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(infos.is_contiguous());
@ -2981,7 +2986,15 @@ static std::tuple<Tensor&, Tensor&> linalg_eig_out_info(const Tensor& input, Ten
}
}
linalg_eig_stub(input.device().type(), real_imag_values, maybe_complex_vectors, infos, input, compute_eigenvectors);
// MAGMA uses a hybrid CPU-GPU algorithm that performs well only for large matrices
// See: https://github.com/pytorch/pytorch/pull/52491#issuecomment-795685687
// Here we call CPU path for matrices smaller than 2048x2048
// that should be in general significantly faster than calling MAGMA
if (input.size(-1) <= 2048) {
linalg_eig_stub(at::kCPU, real_imag_values, maybe_complex_vectors, infos, input.to(kCPU), compute_eigenvectors);
} else {
linalg_eig_stub(input.device().type(), real_imag_values, maybe_complex_vectors, infos, input, compute_eigenvectors);
}
// if input is not complex we need to do some post-processing
if (!input.is_complex()) {
@ -3006,14 +3019,7 @@ static std::tuple<Tensor&, Tensor&> linalg_eig_out_info(const Tensor& input, Ten
}
if (compute_eigenvectors) {
if (vectors.is_complex()) {
// We move to the CPU because linalg_eig_make_complex_eigenvectors requires it.
// Performance note: this function could be implemented via a TensorIterator,
// which would avoid an explicit host-device synchronization.
auto vectors_cpu = vectors.cpu();
auto values_cpu = values.cpu();
auto maybe_complex_vectors_cpu = maybe_complex_vectors.cpu();
vectors_cpu = linalg_eig_make_complex_eigenvectors(vectors_cpu, values_cpu, maybe_complex_vectors_cpu);
vectors.copy_(vectors_cpu);
vectors = linalg_eig_make_complex_eigenvectors(vectors, values, maybe_complex_vectors);
} else {
TORCH_CHECK(false, "torch.linalg.eig: imaginary part of eigenvectors is non-zero, can't safely cast eigenvectors to non-complex dtype.")
}
@ -3033,7 +3039,8 @@ std::tuple<Tensor&, Tensor&> linalg_eig_out(const Tensor& input, Tensor& values,
checkSameDevice("torch.linalg.eig", values, input, "eigenvalues");
checkSameDevice("torch.linalg.eig", vectors, input, "eigenvectors");
auto options = input.options();
// MAGMA doesn't have GPU interface for GEEV routine, it requires inputs to be on CPU
auto options = input.options().device(at::kCPU);
auto infos = at::zeros({std::max<int64_t>(1, batchCount(input))}, options.dtype(kInt));
// if result is not empty and not in batched column major format we have to allocate a temporary tensor
@ -3122,7 +3129,8 @@ Tensor& linalg_eigvals_out(const Tensor& input, Tensor& values) {
checkLinalgCompatibleDtype("torch.linalg.eigvals", values.scalar_type(), toComplexType(input.scalar_type()), "eigenvalues");
checkSameDevice("torch.linalg.eigvals", values, input, "eigenvalues");
auto options = input.options();
// MAGMA doesn't have GPU interface for GEEV routine, it requires inputs to be on CPU
auto options = input.options().device(at::kCPU);
auto infos = at::zeros({std::max<int64_t>(1, batchCount(input))}, options.dtype(kInt));
bool values_expected_type = (values.scalar_type() == toComplexType(input.scalar_type()));
@ -3151,7 +3159,6 @@ Tensor& linalg_eigvals_out(const Tensor& input, Tensor& values) {
}
Tensor vectors;
vectors = at::empty({0}, input.options());
if (values_tmp_needed) {
Tensor values_tmp = at::empty({0}, options.dtype(values_type));
std::tie(values_tmp, std::ignore) = linalg_eig_out_info(input, values_tmp, vectors, infos, /*compute_eigenvectors=*/false);

View File

@ -11,7 +11,6 @@
#include <ATen/SparseCsrTensorUtils.h>
#include <ATen/TensorOperators.h>
#include <ATen/TracerMode.h>
#include <ATen/core/Generator.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/UnaryOps.h>
#include <c10/core/ScalarType.h>
@ -1090,7 +1089,6 @@ Tensor& rand_out(
Tensor rand_like(
const Tensor& self,
std::optional<Generator> generator,
std::optional<ScalarType> dtype,
std::optional<Layout> layout,
std::optional<Device> device,
@ -1102,24 +1100,7 @@ Tensor rand_like(
pin_memory);
auto result = at::empty_like(self, options, optional_memory_format);
return result.uniform_(0, 1, std::move(generator));
}
Tensor rand_like(
const Tensor& self,
std::optional<ScalarType> dtype,
std::optional<Layout> layout,
std::optional<Device> device,
std::optional<bool> pin_memory,
std::optional<c10::MemoryFormat> optional_memory_format) {
return native::rand_like(
self,
static_cast<std::optional<Generator>>(std::nullopt),
dtype,
layout,
device,
pin_memory,
optional_memory_format);
return result.uniform_(0, 1, std::nullopt);
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ randint ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@ -1216,9 +1197,7 @@ Tensor& randint_out(
Tensor randint_like(
const Tensor& self,
int64_t low,
int64_t high,
std::optional<Generator> generator,
std::optional<ScalarType> dtype,
std::optional<Layout> layout,
std::optional<Device> device,
@ -1230,7 +1209,29 @@ Tensor randint_like(
pin_memory);
auto result = at::empty_like(self, options, optional_memory_format);
return result.random_(low, high, std::move(generator));
return result.random_(0, high, std::nullopt);
}
Tensor randint_like(
const Tensor& self,
const Tensor& high,
std::optional<ScalarType> dtype,
std::optional<Layout> layout,
std::optional<Device> device,
std::optional<bool> pin_memory,
std::optional<c10::MemoryFormat> optional_memory_format) {
TORCH_CHECK(
high.numel() == 1 && high.ndimension() == 0 && high.device().is_cpu(),
"high must be a scalar tensor and on CPU");
int64_t high_scalar = high.item<int64_t>();
return at::native::randint_like(
self,
high_scalar,
dtype,
layout,
device,
pin_memory,
optional_memory_format);
}
Tensor randint_like(
@ -1242,108 +1243,13 @@ Tensor randint_like(
std::optional<Device> device,
std::optional<bool> pin_memory,
std::optional<c10::MemoryFormat> optional_memory_format) {
return native::randint_like(
self,
low,
high,
static_cast<std::optional<Generator>>(std::nullopt),
dtype,
layout,
device,
pin_memory,
optional_memory_format);
}
Tensor randint_like(
const Tensor& self,
int64_t high,
std::optional<ScalarType> dtype,
std::optional<Layout> layout,
std::optional<Device> device,
std::optional<bool> pin_memory,
std::optional<c10::MemoryFormat> optional_memory_format) {
// See [Note: hacky wrapper removal for TensorOptions]
return native::randint_like(
self,
0,
high,
static_cast<std::optional<Generator>>(std::nullopt),
dtype,
layout,
device,
pin_memory,
optional_memory_format);
}
TensorOptions options =
TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(
pin_memory);
Tensor randint_like(
const Tensor& self,
int64_t high,
std::optional<Generator> generator,
std::optional<ScalarType> dtype,
std::optional<Layout> layout,
std::optional<Device> device,
std::optional<bool> pin_memory,
std::optional<c10::MemoryFormat> optional_memory_format) {
// See [Note: hacky wrapper removal for TensorOptions]
return native::randint_like(
self,
0,
high,
generator,
dtype,
layout,
device,
pin_memory,
optional_memory_format);
}
Tensor randint_like(
const Tensor& self,
const Tensor& high,
std::optional<ScalarType> dtype,
std::optional<Layout> layout,
std::optional<Device> device,
std::optional<bool> pin_memory,
std::optional<c10::MemoryFormat> optional_memory_format) {
TORCH_CHECK(
high.numel() == 1 && high.ndimension() == 0 && high.device().is_cpu(),
"high must be a scalar tensor and on CPU");
int64_t high_scalar = high.item<int64_t>();
return at::native::randint_like(
self,
0,
high_scalar,
static_cast<std::optional<Generator>>(std::nullopt),
dtype,
layout,
device,
pin_memory,
optional_memory_format);
}
Tensor randint_like(
const Tensor& self,
const Tensor& high,
std::optional<Generator> generator,
std::optional<ScalarType> dtype,
std::optional<Layout> layout,
std::optional<Device> device,
std::optional<bool> pin_memory,
std::optional<c10::MemoryFormat> optional_memory_format) {
TORCH_CHECK(
high.numel() == 1 && high.ndimension() == 0 && high.device().is_cpu(),
"high must be a scalar tensor and on CPU");
int64_t high_scalar = high.item<int64_t>();
return at::native::randint_like(
self,
0,
high_scalar,
generator,
dtype,
layout,
device,
pin_memory,
optional_memory_format);
auto result = at::empty_like(self, options, optional_memory_format);
return result.random_(low, high, std::nullopt);
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ randn ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@ -1421,7 +1327,6 @@ Tensor& normal_out(
Tensor randn_like(
const Tensor& self,
std::optional<Generator> generator,
std::optional<ScalarType> dtype,
std::optional<Layout> layout,
std::optional<Device> device,
@ -1433,24 +1338,7 @@ Tensor randn_like(
pin_memory);
auto result = at::empty_like(self, options, optional_memory_format);
return result.normal_(0, 1, std::move(generator));
}
Tensor randn_like(
const Tensor& self,
std::optional<ScalarType> dtype,
std::optional<Layout> layout,
std::optional<Device> device,
std::optional<bool> pin_memory,
std::optional<c10::MemoryFormat> optional_memory_format) {
return native::randn_like(
self,
static_cast<std::optional<Generator>>(std::nullopt),
dtype,
layout,
device,
pin_memory,
optional_memory_format);
return result.normal_(0, 1, std::nullopt);
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ randperm ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -77,7 +77,9 @@ static void reduced_float_copy_kernel(TensorIteratorBase &iter, bool requires_ne
int64_t grain_size = at::internal::GRAIN_SIZE;
auto loop = [strides_in, requires_neg](char** data, const int64_t* strides, int64_t size0, int64_t size1) {
auto loop = [strides_in, requires_neg](char** base, const int64_t* strides, int64_t size0, int64_t size1) {
std::array<char*, 2> data;
std::copy_n(base, 2, data.data());
const int64_t *outer_strides = &strides[2];
for ([[maybe_unused]] const auto it : c10::irange(size1)) {
@ -144,7 +146,9 @@ static void reduced_float_copy_kernel(TensorIteratorBase &iter, bool requires_ne
int64_t grain_size = at::internal::GRAIN_SIZE;
auto loop = [strides_in, requires_neg](char** data, const int64_t* strides, int64_t size0, int64_t size1) {
auto loop = [strides_in, requires_neg](char** base, const int64_t* strides, int64_t size0, int64_t size1) {
std::array<char*, 2> data;
std::copy_n(base, 2, data.data());
const int64_t *outer_strides = &strides[2];
for ([[maybe_unused]] const auto it : c10::irange(size1)) {

View File

@ -493,33 +493,40 @@ void cpu_hflip_vec(at::TensorIterator& iter) {
for ([[maybe_unused]] const auto j : c10::irange(size1)) {
// vectorized loop with negative stride for output
char** C10_RESTRICT data_ = data_arr.data();
int64_t n = size0;
char* C10_RESTRICT data[ntensors];
for (const auto arg : c10::irange(ntensors)) {
data[arg] = data_[arg];
}
int64_t i = 0;
// data_arr[0] unaligned pre-pass
// data[0] unaligned pre-pass
int64_t offset = (j * n + (n - i - Vec::size())) % 32;
offset = (offset >= n) ? n : offset;
for (; i < offset; i++) {
scalar_t* out_ptr = (scalar_t*)(data_arr[0] - i * stride);
*out_ptr = c10::load((scalar_t *)(data_arr[1] + i * stride));
scalar_t* out_ptr = (scalar_t*)(data[0] - i * stride);
*out_ptr = c10::load((scalar_t *)(data[1] + i * stride));
}
// Empirically found that it is faster to process 3 data items together vs 2 or 4
for (; i <= n - 3 * Vec::size(); i += 3 * Vec::size()) {
auto out1 = Vec::loadu(data_arr[1] + i * stride);
auto out2 = Vec::loadu(data_arr[1] + (i + Vec::size()) * stride);
auto out3 = Vec::loadu(data_arr[1] + (i + 2 * Vec::size()) * stride);
auto out1 = Vec::loadu(data[1] + i * stride);
auto out2 = Vec::loadu(data[1] + (i + Vec::size()) * stride);
auto out3 = Vec::loadu(data[1] + (i + 2 * Vec::size()) * stride);
// flip the vector: 1234 -> 4321
out1 = flip(out1);
out2 = flip(out2);
out3 = flip(out3);
out1.store(data_arr[0] - (i + Vec::size() - 1) * stride);
out2.store(data_arr[0] - (i + 2 * Vec::size() - 1) * stride);
out3.store(data_arr[0] - (i + 3 * Vec::size() - 1) * stride);
out1.store(data[0] - (i + Vec::size() - 1) * stride);
out2.store(data[0] - (i + 2 * Vec::size() - 1) * stride);
out3.store(data[0] - (i + 3 * Vec::size() - 1) * stride);
}
if (i < n) {
for (; i < n; i++) {
scalar_t* out_ptr = (scalar_t*)(data_arr[0] - i * stride);
*out_ptr = c10::load((scalar_t *)(data_arr[1] + i * stride));
scalar_t* out_ptr = (scalar_t*)(data[0] - i * stride);
*out_ptr = c10::load((scalar_t *)(data[1] + i * stride));
}
}
@ -553,8 +560,15 @@ void cpu_vflip_memcpy(at::TensorIterator& iter) {
const int64_t stride = strides[0];
for ([[maybe_unused]] const auto j : c10::irange(size1)) {
char** C10_RESTRICT data_ = data_arr.data();
int64_t n = size0;
memcpy(data_arr[0], data_arr[1], n * stride);
char* C10_RESTRICT data[ntensors];
for (const auto arg : c10::irange(ntensors)) {
data[arg] = data_[arg];
}
memcpy(data[0], data[1], n * stride);
// advance:
for (const auto arg : c10::irange(data_arr.size())) {

View File

@ -92,8 +92,7 @@ void addcdiv_cpu_kernel(TensorIteratorBase& iter, const Scalar& value) {
void smooth_l1_backward_cpu_kernel(TensorIterator& iter, const Scalar& norm, double beta) {
ScalarType dtype = iter.dtype(0);
if (at::isReducedFloatingType(dtype)) {
AT_DISPATCH_REDUCED_FLOATING_TYPES(dtype, "smooth_l1_backward_cpu_out", [&]() {
if (dtype == kBFloat16) {
auto norm_val = norm.to<float>();
float beta_val(beta);
auto norm_val_vec = Vectorized<float>(norm_val);
@ -102,9 +101,9 @@ void smooth_l1_backward_cpu_kernel(TensorIterator& iter, const Scalar& norm, dou
const auto zero_vec = Vectorized<float>(0);
const auto pos_1_vec = Vectorized<float>(1);
cpu_kernel_vec(iter,
[=](scalar_t input, scalar_t target, scalar_t grad_output) -> scalar_t {
[=](BFloat16 input, BFloat16 target, BFloat16 grad_output) -> BFloat16 {
const auto x = float(input) - float(target);
if (x <= -beta) {
if (x <= -beta){
return -norm_val * float(grad_output);
}else if (x >= beta){
return norm_val * float(grad_output);
@ -113,14 +112,14 @@ void smooth_l1_backward_cpu_kernel(TensorIterator& iter, const Scalar& norm, dou
}
},
[norm_val_vec, beta_val_vec, neg_1_vec, zero_vec, pos_1_vec](
Vectorized<scalar_t> input, Vectorized<scalar_t> target, Vectorized<scalar_t> grad_output) -> Vectorized<scalar_t> {
Vectorized<BFloat16> input, Vectorized<BFloat16> target, Vectorized<BFloat16> grad_output) -> Vectorized<BFloat16> {
// using two blendv calls to simulate the 3 cases
// 1 if x >= beta
// -1 if x <= -beta
// x / beta if |x| < beta
auto [input0, input1] = convert_to_float(input);
auto [target0, target1] = convert_to_float(target);
auto [grad_output0, grad_output1] = convert_to_float(grad_output);
auto [input0, input1] = convert_bfloat16_float(input);
auto [target0, target1] = convert_bfloat16_float(target);
auto [grad_output0, grad_output1] = convert_bfloat16_float(grad_output);
auto x = input0 - target0;
auto pos_or_neg_1_vec = Vectorized<float>::blendv(
neg_1_vec, pos_1_vec, x > zero_vec);
@ -136,10 +135,9 @@ void smooth_l1_backward_cpu_kernel(TensorIterator& iter, const Scalar& norm, dou
output = Vectorized<float>::blendv(
x / beta_val_vec, pos_or_neg_1_vec, x_abs >= beta_val_vec);
input1 = norm_val_vec * output * grad_output1;
return convert_from_float<scalar_t>(input0, input1);
return convert_float_bfloat16(input0, input1);
}
);
});
} else {
AT_DISPATCH_ALL_TYPES(dtype, "smooth_l1_backward_cpu_out", [&] {
auto norm_val = norm.to<scalar_t>();

View File

@ -205,8 +205,8 @@ static bool isInputCompliesAddmmCudaLt(Tensor& result, const Tensor& self, const
// and the leading stride is at least max(1, other dim length), so we might
// end up with contiguous cols but not rows (i.e. holes between different rows)
// and vice versa.
&& mat2_sizes[0] < 65535 * 32 && mat2_sizes[1] < 65535 * 32
&& mat1_sizes[0] < 65535 * 32 && mat1_sizes[1] < 65535 * 32
&& mat2_sizes[0] < 65535 * 32 && mat2_sizes[1] < 65535 * 32 &&
mat1_sizes[0] < 65535 * 32 && mat1_sizes[1] < 65535 * 32 &&
&& (
// filter by dtype
(scalar_type != at::ScalarType::Half && scalar_type != at::ScalarType::BFloat16) ||

View File

@ -54,6 +54,7 @@ namespace {
using DtypeScale = float;
using DtypeAccum = float;
using DtypeEpilogue = float;
using DtypeOutput = cutlass::bfloat16_t;
using Multiply = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies,
@ -67,6 +68,12 @@ using Add = cutlass::epilogue::fusion::Sm90Compute<
DtypeEpilogue,
cutlass::FloatRoundStyle::round_to_nearest>;
using Cast = cutlass::epilogue::fusion::Sm90Compute<
cutlass::epilogue::thread::Identity,
DtypeOutput,
DtypeEpilogue,
cutlass::FloatRoundStyle::round_to_nearest>;
template <bool LargeTile, bool FastAccum>
struct Schedule;
@ -113,8 +120,7 @@ template <
typename FastAccum,
typename DtypeA,
typename DtypeB,
typename DtypeBias,
typename DtypeOutput>
typename DtypeBias>
void f8f8bf16_rowwise_impl(
at::Tensor XQ, // FP8
at::Tensor WQ, // FP8
@ -175,11 +181,6 @@ void f8f8bf16_rowwise_impl(
WScale,
cutlass::epilogue::fusion::Sm90EVT<Multiply, XScale, Accum>>;
using Cast = cutlass::epilogue::fusion::Sm90Compute<
cutlass::epilogue::thread::Identity,
DtypeOutput,
DtypeEpilogue,
cutlass::FloatRoundStyle::round_to_nearest>;
using EpilogueEVT = cutlass::epilogue::fusion::Sm90EVT<
Cast,
cutlass::epilogue::fusion::Sm90EVT<
@ -312,8 +313,7 @@ template <
typename FastAccum,
typename DtypeA,
typename DtypeB,
typename DtypeBias,
typename DtypeOutput>
typename DtypeBias>
void f8f8bf16_rowwise_impl_sm100_sm120(
at::Tensor XQ, // FP8
at::Tensor WQ, // FP8
@ -372,11 +372,6 @@ void f8f8bf16_rowwise_impl_sm100_sm120(
WScale,
cutlass::epilogue::fusion::Sm90EVT<Multiply, XScale, Accum>>;
using Cast = cutlass::epilogue::fusion::Sm90Compute<
cutlass::epilogue::thread::Identity,
DtypeOutput,
DtypeEpilogue,
cutlass::FloatRoundStyle::round_to_nearest>;
using EpilogueEVT = cutlass::epilogue::fusion::Sm90EVT<
Cast,
cutlass::epilogue::fusion::Sm90EVT<
@ -503,8 +498,7 @@ template <
typename FastAccum,
typename DtypeA,
typename DtypeB,
typename DtypeBias,
typename DtypeOutput>
typename DtypeBias>
void f8f8bf16_rowwise_impl_sm89(
at::Tensor XQ, // FP8
at::Tensor WQ, // FP8
@ -771,8 +765,7 @@ template <
typename FastAccum,
typename DtypeA,
typename DtypeB,
typename DtypeBias,
typename DtypeOutput>
typename DtypeBias>
void handle_transposition(
at::Tensor XQ,
at::Tensor WQ,
@ -789,8 +782,7 @@ void handle_transposition(
FastAccum,
DtypeA,
DtypeB,
DtypeBias,
DtypeOutput>(XQ, WQ, x_scale, w_scale, bias, out, swizzle);
DtypeBias>(XQ, WQ, x_scale, w_scale, bias, out, swizzle);
} else {
dispatch_fp8_rowwise_kernel_on_tile_size<
ClusterShape,
@ -799,8 +791,7 @@ void handle_transposition(
FastAccum,
DtypeB,
DtypeA,
DtypeBias,
DtypeOutput>(WQ.t(), XQ.t(), w_scale.t(), x_scale.t(), bias, out.t(), swizzle);
DtypeBias>(WQ.t(), XQ.t(), w_scale.t(), x_scale.t(), bias, out.t(), swizzle);
}
}
@ -1036,19 +1027,11 @@ void dispatch_fp8_rowwise_kernel_on_bias_dtype(
at::Tensor out) {
if (bias.has_value() && bias->dtype() == at::kBFloat16) {
dispatch_fp8_rowwise_kernel_on_input_dtypes<
cutlass::bfloat16_t,
cutlass::bfloat16_t>
(XQ, WQ, x_scale, w_scale, bias, use_fast_accum, out);
} else if (bias.has_value() && bias->dtype() == at::kHalf){
TORCH_CHECK(out.dtype() == at::kHalf, "Output should be Float16 when bias is Float16");
dispatch_fp8_rowwise_kernel_on_input_dtypes<
cutlass::half_t,
cutlass::half_t>
(XQ, WQ, x_scale, w_scale, bias, use_fast_accum, out);
} else {
dispatch_fp8_rowwise_kernel_on_input_dtypes<
float,
cutlass::bfloat16_t>
float>
//Types...>
(XQ, WQ, x_scale, w_scale, bias, use_fast_accum, out);
}
@ -1090,14 +1073,14 @@ void check_inputs(
if (bias.has_value()) {
TORCH_CHECK(bias->device() == b.device());
TORCH_CHECK(bias->dtype() == at::kFloat || bias->dtype() == at::kBFloat16 || bias->dtype() == at::kHalf);
TORCH_CHECK(bias->dtype() == at::kFloat || bias->dtype() == at::kBFloat16);
TORCH_CHECK(bias->dim() == 1);
TORCH_CHECK(bias->size(0) == b.size(1));
TORCH_CHECK(bias->stride(0) == 1);
}
TORCH_CHECK(out.device() == a.device());
TORCH_CHECK(out.dtype() == at::kBFloat16 || out.dtype() == at::kHalf);
TORCH_CHECK(out.dtype() == at::kBFloat16);
TORCH_CHECK(out.dim() == 2);
TORCH_CHECK(out.size(0) == a.size(0));
TORCH_CHECK(out.size(1) == b.size(1));

View File

@ -59,24 +59,6 @@
// forward declare
class cublasCommonArgs;
#ifndef _WIN32
namespace fbgemm_gpu {
// NOTE(slayton58): FBGemm_GPU kernels come from <fbgemm_gpu/torch_ops.h> within the FBGemm repo.
// To update supported ops means a submodule bump, which is.. painful. Instead, we
// can simply forward-declare the methods we want to use.. Works at least as a short-term
// thing, but should still be fixed somewhere/somehow.
at::Tensor f4f4bf16(
at::Tensor,
at::Tensor,
at::Tensor,
at::Tensor,
std::optional<at::Tensor>,
bool use_mx);
} // namespace fbgemm_gpu
#endif
using at::blas::ScalingType;
using at::blas::SwizzleType;
@ -609,7 +591,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
if ((dprops->major < 9 || CUBLAS_VERSION < 120900 || cublasLtGetVersion() < 120900)
// cuBLAS only supports tiled 1D factor layout for 1D block scaling, no 2D block scales
|| (dprops->major >= 10 && (!scale_a.sizes().empty() || !scale_b.sizes().empty()))) {
TORCH_CHECK_VALUE(out.dtype() == kBFloat16 || out.dtype() == kHalf, "Only bf16 and fp16 high precision output types are supported for row-wise scaling.");
TORCH_CHECK_VALUE(out.dtype() == kBFloat16, "Only bf16 high precision output types are supported for row-wise scaling.");
return _scaled_rowwise_rowwise(
mat1,
mat2,
@ -754,7 +736,7 @@ _scaled_rowwise_rowwise(
if (((dprops->major < 9 || CUBLAS_VERSION < 120900 || cublasLtGetVersion() < 120900)
// cuBLAS only supports tiled 1D factor layout for 1D block scaling, no 2D block scales
|| (dprops->major == 10 && (scale_a.sizes().size() || scale_b.sizes().size())))) {
TORCH_CHECK_VALUE(out.dtype() == kBFloat16 || out.dtype() == kHalf, "Only bf16 and fp16 high precision output types are supported for row-wise scaling.");
TORCH_CHECK_VALUE(out.dtype() == kBFloat16, "Only bf16 high precision output types are supported for row-wise scaling.");
at::cuda::detail::f8f8bf16_rowwise(
mat_a,
mat_b,
@ -785,6 +767,33 @@ _scaled_rowwise_rowwise(
return out;
}
// Check the shapes & sizes of scales for deepseek-style (1x128, 128x128) scaling.
// Wraps check_size_stride for easier integration, correctly handles cases where a dimension of the scale == 1,
// and strides become somewhat meaningless
void _check_deepseek_scale_stride(const Tensor& scale, const Tensor& t, const ScalingType scale_type) {
if (scale_type == ScalingType::BlockWise1x128) {
TORCH_CHECK_VALUE(check_size_stride(scale, 0, t.size(0), 1),
"at dim=0 scale should have ", t.size(0), "elements and stride(0) ", 1, "if ", t.size(0), " > 1 - Got: ",
"shape=", scale.sizes(), ", stride=", scale.strides());
auto expected_size = ceil_div<int64_t>(t.size(1), 128);
TORCH_CHECK_VALUE(check_size_stride(scale, 1, expected_size, t.size(0)),
"at dim=1 scale should have ", expected_size, "elements and stride ", t.size(0), "if ", expected_size, " > 1 - Got: ",
"shape=", scale.sizes(), ", stride=", scale.strides());
} else if (scale_type == ScalingType::BlockWise128x128) {
TORCH_CHECK_VALUE(check_size_stride(
scale,
0,
ceil_div<int64_t>(t.size(0), 128),
ceil_div<int64_t>(t.size(1), 128)),
"at dim=0 scale should have ", ceil_div<int64_t>(t.size(0), 128), "elements and stride(0) ", ceil_div<int64_t>(t.size(1), 128), "if ", ceil_div<int64_t>(t.size(0), 128), " > 1 - Got: ",
"shape=", scale.sizes(), ", stride=", scale.strides());
TORCH_CHECK(check_size_stride(
scale, 1, ceil_div<int64_t>(t.size(1), 128), 1),
"at dim=1 scale should have ", ceil_div<int64_t>(t.size(1), 128), "elements and stride(1) ", 1, "if ", ceil_div<int64_t>(t.size(1), 128), " > 1 - Got: ",
"shape=", scale.sizes(), ", stride=", scale.strides());
}
}
void
_check_deepseek_support() {
#ifndef USE_ROCM
@ -797,7 +806,7 @@ _check_deepseek_support() {
}
// Only in cublasLt >= 12.9
TORCH_CHECK_NOT_IMPLEMENTED(
CUBLAS_VERSION >= 120900 && cublasLtGetVersion() >= 120900,
CUBLAS_VERSION < 120900 || cublasLtGetVersion() < 120900,
"DeepSeek style (1x128, 128x128) scaling requires cublasLt >= 12.9"
);
#endif
@ -814,61 +823,23 @@ _scaled_block1x128_block1x128(
#ifndef USE_ROCM
// Restrictions:
// A, B are FP8, scales are fp32, shape K//128
// As: [M x K // 128], stride: [1, M]
// Bs: [N x K // 128], stride: [1, N]
// CUDA: Only Hopper GPUs
_check_deepseek_support();
// check types
TORCH_CHECK_VALUE(
isFloat8Type(mat_a.scalar_type()) &&
isFloat8Type(mat_b.scalar_type()),
"mat_a and mat_b must be fp8 types, got: ", mat_a.scalar_type(), mat_b.scalar_type()
);
const int64_t M = mat_a.sizes()[0];
const int64_t K = mat_a.sizes()[1];
const int64_t N = mat_b.sizes()[1];
// scale_a shape
TORCH_CHECK_VALUE(
scale_a.size(0) == M &&
scale_a.size(1) == ceil_div<int64_t>(K, 128) &&
scale_a.scalar_type() == kFloat,
"scale_a must have shape ", M, " x ", ceil_div<int64_t>(K, 128), " Float elements, got ", scale_a.sizes()
);
// scale_a stride
TORCH_CHECK_VALUE(
scale_a.stride(0) == 1 &&
(
scale_a.stride(1) == M ||
(scale_a.size(1) == 1 && scale_b.stride(1) == 1)
),
"scale_a strides must be (", 1, ", ", M, "); got: ", scale_a.strides()
);
// scale_b shape
TORCH_CHECK_VALUE(
scale_b.size(0) == N &&
scale_b.size(1) == ceil_div<int64_t>(K, 128) &&
scale_b.scalar_type() == kFloat,
"scale_b must have shape ", N, " x ", ceil_div<int64_t>(K, 128), " Float elements, got ", scale_b.sizes()
);
// scale_b stride
TORCH_CHECK_VALUE(
scale_b.stride(0) == 1 &&
(
scale_b.stride(1) == N ||
(
scale_b.size(1) == 1 &&
scale_b.stride(1) == 1
)
),
"scale_b strides must be (", 1, ", ", N, "); got: ", scale_a.strides()
);
TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
mat_a.scalar_type(), mat_b.scalar_type());
TORCH_CHECK_VALUE(scale_a.sizes()[0] == mat_a.sizes()[0] && scale_a.sizes()[1] == mat_a.sizes()[1] / 128 && scale_a.scalar_type() == kFloat,
"scale_a must have shape ", mat_a.sizes()[0], " x ", mat_a.sizes()[1] / 128, " Float elements, got ", scale_a.sizes())
TORCH_CHECK_VALUE(scale_b.sizes()[0] == ceil_div<int64_t>(mat_b.sizes()[0], 128) && scale_b.sizes()[1] == mat_b.sizes()[1] && scale_b.scalar_type() == kFloat,
"scale_b must have shape ", ceil_div<int64_t>(mat_b.sizes()[0], 128), " x ", mat_b.sizes()[1], " Float elements, got ", scale_b.sizes())
auto scaling_choice_a = ScalingType::BlockWise1x128;
auto scaling_choice_b = ScalingType::BlockWise1x128;
// Check scale strides (including stride=1 small cases)
_check_deepseek_scale_stride(scale_a, mat_a, scaling_choice_a);
_check_deepseek_scale_stride(scale_b.t(), mat_b.t(), scaling_choice_b);
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
return out;
@ -890,65 +861,24 @@ _scaled_block128x128_block1x128(
Tensor& out) {
#ifndef USE_ROCM
// Restrictions:
// A, B are FP8, scales are fp32, shape K//128
// CUDA: Only Hopper GPUs
_check_deepseek_support();
// A: [M, K], B: [K, N] are FP8, scales are fp32
// As: [round_up(K // 128, 4), M // 128], stride: [M // 128, 1]
// Bs: [N x K // 128], stride: [1, N]
TORCH_CHECK_VALUE(
isFloat8Type(mat_a.scalar_type()) &&
isFloat8Type(mat_b.scalar_type()),
"mat_a and mat_b must be fp8 types, got: ", mat_a.scalar_type(), mat_b.scalar_type()
);
const int64_t M = mat_a.sizes()[0];
const int64_t K = mat_a.sizes()[1];
const int64_t N = mat_b.sizes()[1];
// scale_a shape
TORCH_CHECK_VALUE(
scale_a.size(0) == round_up<int64_t>(ceil_div<int64_t>(K, 128), 4) &&
scale_a.size(1) == ceil_div<int64_t>(M, 128) &&
scale_a.scalar_type() == kFloat,
"scale_a must have shape ", round_up<int64_t>(ceil_div<int64_t>(K, 128), 4), " x ",
ceil_div<int64_t>(M, 128), " Float elements, got ", scale_a.sizes()
);
// scale_a stride
TORCH_CHECK_VALUE(
scale_a.stride(0) == 1 &&
(
scale_a.stride(1) == round_up<int64_t>(ceil_div<int64_t>(K, 128), 4) ||
(
scale_a.size(1) == 1 &&
scale_a.stride(1) == 1
)
),
"scale_a must have strides (1, ", round_up<int64_t>(ceil_div<int64_t>(K, 128), 4), "); got ", scale_b.strides()
);
// scale_b shape
TORCH_CHECK_VALUE(
scale_b.size(0) == N &&
scale_b.size(1) == ceil_div<int64_t>(K, 128) &&
scale_b.scalar_type() == kFloat,
"scale_b must have shape ", N, " x ", ceil_div<int64_t>(K, 128), " Float elements, got ", scale_b.sizes()
);
// scale_b stride
TORCH_CHECK_VALUE(
scale_b.stride(0) == 1 &&
(
scale_b.stride(1) == N ||
(
scale_b.size(1) == 1 &&
scale_b.stride(1) == 1
)
),
"scale_b must have strides (1, ", N, "); got ", scale_b.strides()
);
TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
mat_a.scalar_type(), mat_b.scalar_type());
TORCH_CHECK_VALUE(scale_a.sizes()[0] == ceil_div<int64_t>(mat_a.sizes()[0], 128) && scale_a.sizes()[1] == ceil_div<int64_t>(mat_a.sizes()[1], 128) && scale_a.scalar_type() == kFloat,
"scale_a must have shape ", ceil_div<int64_t>(mat_a.sizes()[0], 128), " x ", ceil_div<int64_t>(mat_a.sizes()[1], 128), " Float elements, got ", scale_a.sizes())
TORCH_CHECK_VALUE(scale_b.sizes()[0] == ceil_div<int64_t>(mat_b.sizes()[0], 128) && scale_b.sizes()[1] == mat_b.sizes()[1] && scale_b.scalar_type() == kFloat,
"scale_b must have shape ", ceil_div<int64_t>(mat_b.sizes()[0], 128), " x ", mat_b.sizes()[1], " Float elements, got ", scale_b.sizes())
auto scaling_choice_a = ScalingType::BlockWise128x128;
auto scaling_choice_b = ScalingType::BlockWise1x128;
// Check scale strides (including stride=1 small cases)
_check_deepseek_scale_stride(scale_a, mat_a, scaling_choice_a);
_check_deepseek_scale_stride(scale_b.t(), mat_b.t(), scaling_choice_b);
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
return out;
@ -970,62 +900,24 @@ _scaled_block1x128_block128x128(
Tensor& out) {
#ifndef USE_ROCM
// Restrictions:
// A, B are FP8, scales are fp32, A: shape K//128, B: K//128, N//128
// CUDA: Only Hopper GPUs
_check_deepseek_support();
// A: [M, K], B: [K, N] are FP8, scales are fp32
// As: [M x K // 128], stride: [1, M]
// Bs: [round_up(K // 128, 4) x N // 128], stride: [1, N // 128]
TORCH_CHECK_VALUE(
isFloat8Type(mat_a.scalar_type()) &&
isFloat8Type(mat_b.scalar_type()),
"mat_a and mat_b must be fp8 types, got: ", mat_a.scalar_type(), mat_b.scalar_type()
);
int64_t M = mat_a.size(0);
int64_t K = mat_a.size(1);
int64_t N = mat_b.size(1);
// scale_a shape
TORCH_CHECK_VALUE(
scale_a.size(0) == M &&
scale_a.size(1) == ceil_div<int64_t>(K, 128) &&
scale_a.scalar_type() == kFloat,
"scale_a must have shape ", M, " x ", ceil_div<int64_t>(K, 128), " Float elements, got ", scale_a.sizes()
);
// scale_a stride
TORCH_CHECK_VALUE(
scale_a.stride(0) == 1 &&
(
scale_a.stride(1) == M ||
(
scale_a.size(1) == 1 &&
scale_a.stride(1) == 1
)
),
"scale_a must have strides (1, ", M, "); got ", scale_b.strides()
);
// scale_b shape
TORCH_CHECK_VALUE(
scale_b.size(0) == round_up<int64_t>(ceil_div<int64_t>(K, 128), 4) &&
scale_b.size(1) == ceil_div<int64_t>(N, 128) &&
scale_b.scalar_type() == kFloat,
"scale_b must have shape ", round_up<int64_t>(ceil_div<int64_t>(K, 128), 4), " x ", ceil_div<int64_t>(N, 128), " Float elements, got ", scale_b.sizes()
);
// scale_b stride
TORCH_CHECK_VALUE(
scale_b.stride(0) == 1 &&
(
scale_b.stride(1) == round_up<int64_t>(ceil_div<int64_t>(K, 128), 4) ||
(
scale_b.size(1) == 1 &&
scale_b.stride(1) == 1
)
),
"scale_b must have strides (1, ", round_up<int64_t>(ceil_div<int64_t>(K, 128), 4), "); got ", scale_b.strides()
);
TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
mat_a.scalar_type(), mat_b.scalar_type());
TORCH_CHECK_VALUE(scale_a.sizes()[0] == mat_a.sizes()[0] && scale_a.sizes()[1] == mat_a.sizes()[1] / 128 && scale_a.scalar_type() == kFloat,
"scale_a must have shape ", mat_a.sizes()[0], " x ", mat_a.sizes()[1] / 128, " Float elements, got ", scale_a.sizes())
TORCH_CHECK_VALUE(scale_b.sizes()[0] == mat_b.sizes()[0] / 128 && scale_b.sizes()[1] == mat_b.sizes()[1] / 128 && scale_b.scalar_type() == kFloat,
"scale_b must have shape ", mat_b.sizes()[0] / 128, " x ", mat_b.sizes()[1] / 128, " Float elements, got ", scale_b.sizes())
auto scaling_choice_a = ScalingType::BlockWise1x128;
auto scaling_choice_b = ScalingType::BlockWise128x128;
// Check scale strides (including stride=1 small cases)
_check_deepseek_scale_stride(scale_a, mat_a, scaling_choice_a);
_check_deepseek_scale_stride(scale_b.t(), mat_b.t(), scaling_choice_b);
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
return out;
@ -1105,47 +997,26 @@ _scaled_mxfp4_mxfp4(
const std::optional<Tensor>& bias,
const c10::ScalarType out_dtype,
Tensor& out) {
#if defined(_WIN32) || (!defined(USE_ROCM) && !defined(USE_FBGEMM_GENAI))
TORCH_CHECK_NOT_IMPLEMENTED(false, "MXFP4 scaling supported on ROCM and CUDA+FBGEMM_GENAI only");
#else
#ifndef USE_ROCM
TORCH_CHECK_NOT_IMPLEMENTED(false, "MXFP4 scaling supported on ROCM only");
#endif
// Restrictions:
// A, B are FP4, scales are e8m0, A: shape K//32, B: K, N//32
TORCH_CHECK_VALUE(mat_a.scalar_type() == at::kFloat4_e2m1fn_x2 && mat_b.scalar_type() == at::kFloat4_e2m1fn_x2, "mat_a and mat_b must be fp4 types, got: ",
mat_a.scalar_type(), mat_b.scalar_type());
// Packed FP4 format means actual-K = 2 * reported-K -- adjust
auto K_multiplier = 2;
#ifdef USE_ROCM
// AMD
auto scale_a_elems = ceil_div<int64_t>(K_multiplier * mat_a.size(0), 32) * mat_a.size(1);
auto scale_b_elems = ceil_div<int64_t>(K_multiplier * mat_b.size(1), 32) * mat_b.size(0);
#else
// NVIDIA
auto scale_a_elems = round_up<int64_t>(mat_a.size(0), 128) * round_up<int64_t>(ceil_div<int64_t>(K_multiplier * mat_a.size(1), 32), 4);
auto scale_b_elems = round_up<int64_t>(mat_b.size(1), 128) * round_up<int64_t>(ceil_div<int64_t>(K_multiplier * mat_b.size(0), 32), 4);
#endif
auto scale_a_elems = ceil_div<int64_t>(2 * mat_a.size(0), 32) * mat_a.size(1);
auto scale_b_elems = ceil_div<int64_t>(2 * mat_b.size(1), 32) * mat_b.size(0);
TORCH_CHECK_VALUE(scale_a_elems == scale_a.numel(),
"For Blockwise scaling scale_a should have ", scale_a_elems, " elements, got: ", scale_a.numel());
TORCH_CHECK_VALUE(scale_b_elems == scale_b.numel(),
"For Blockwise scaling scale_b should have ", scale_b_elems, " elements, got: ", scale_b.numel());
#ifdef USE_ROCM
// AMD
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::NO_SWIZZLE, "scale_a must not be swizzled (NO_SWIZZLE format)");
TORCH_CHECK_VALUE(swizzle_b == SwizzleType::NO_SWIZZLE, "scale_b must not be swizzled (NO_SWIZZLE format)");
#else
// NVIDIA
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::SWIZZLE_32_4_4, "scale_a must be swizzled to SWIZZLE_32_4_4 format");
TORCH_CHECK_VALUE(swizzle_b == SwizzleType::SWIZZLE_32_4_4, "scale_b must be swizzled to SWIZZLE_32_4_4 format");
#endif
TORCH_CHECK_VALUE(scale_a.is_contiguous() && scale_b.is_contiguous(),
"For Blockwise scaling both scales should be contiguous");
TORCH_CHECK_VALUE(out.scalar_type() == out_dtype, "expected out.scalar_type() to be ", out_dtype, ", but got ", out_dtype);
#ifdef USE_ROCM
// AMD
auto scaling_choice_a = ScalingType::BlockWise1x32;
auto scaling_choice_b = ScalingType::BlockWise1x32;
@ -1160,30 +1031,11 @@ _scaled_mxfp4_mxfp4(
TORCH_CHECK_VALUE(out.scalar_type() == ScalarType::BFloat16 ||
out.scalar_type() == ScalarType::Half,
"Block-wise scaling only supports BFloat16 or Half output types");
#else
TORCH_CHECK_NOT_IMPLEMENTED(false, "Block-wise scaling for Float8_e8m0fnu requires ROCm 7.0 or later");
#endif
return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
#else
// NVIDIA
// NOTE(slayton58): fbgemm_gpu::f4f4bf16 does *not* allow passing an output tensor,
// but we have one we need to use. Two clear options are to copy into
// our output (slow), or use a move-assignment-operator (faster).
// However, the compiler can complain about the explicit move preventing
// copy elision because the return from f4f4bf16 is a temporary object.
// So we don't explicitly move, and trust the compiler here...
// In the longer term this should be fixed on the FBGemm side.
out = fbgemm_gpu::f4f4bf16(
mat_a,
mat_b.transpose(-2, -1),
scale_a,
scale_b,
std::nullopt, /* global_scale */
true /* use_mx */
);
return out;
#endif
#endif
}
Tensor&
@ -1308,20 +1160,17 @@ _scaled_mm_cuda_v2_out(
mat_a.size(0), "x", mat_a.size(1), " and ", mat_b.size(0), "x", mat_b.size(1), ")");
}
// Handle fp4 packed-K dimension
int K_multiplier = (mat_a.scalar_type() == ScalarType::Float4_e2m1fn_x2) ? 2 : 1;
TORCH_CHECK_VALUE(!bias || bias->numel() == mat_b.sizes()[1], "Bias must be size ", mat_b.sizes()[1],
" but got ", bias->numel());
TORCH_CHECK_VALUE(
K_multiplier * mat_a.sizes()[1] % 16 == 0,
mat_a.sizes()[1] % 16 == 0,
"Expected trailing dimension of mat1 to be divisible by 16 ",
"but got mat1 shape: (",
mat_a.sizes()[0],
"x",
K_multiplier * mat_a.sizes()[1],
mat_a.sizes()[1],
").");
TORCH_CHECK_VALUE(K_multiplier * mat_b.sizes()[0] % 16 == 0 && mat_b.sizes()[1] % 16 == 0, "mat2 shape (", mat_b.sizes()[0], "x",
TORCH_CHECK_VALUE(mat_b.sizes()[0] % 16 == 0 && mat_b.sizes()[1] % 16 == 0, "mat2 shape (", mat_b.sizes()[0], "x",
mat_b.sizes()[1], ") must be divisible by 16");
// TODO(slayton): Existing checks, not sure if they should really be here.

View File

@ -1881,8 +1881,6 @@ void geqrf_kernel(const Tensor& input, const Tensor& tau) {
REGISTER_CUDA_DISPATCH(geqrf_stub, &geqrf_kernel)
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg_eigh ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
template <typename scalar_t>
static void apply_magma_eigh(const Tensor& values, const Tensor& vectors, const Tensor& infos, bool upper, bool compute_eigenvectors) {
#if !AT_MAGMA_ENABLED()
@ -1957,6 +1955,8 @@ static void apply_magma_eigh(const Tensor& values, const Tensor& vectors, const
#endif
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg_eigh ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// This is a type dispatch function for 'apply_magma_eigh'
// For small inputs result is computed on CPU
void linalg_eigh_magma(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& infos, bool upper, bool compute_eigenvectors) {
@ -2019,10 +2019,10 @@ This is an in-place routine, content of 'input', 'values', 'vectors' is overwrit
For more information see MAGMA's documentation for GEEV routine.
*/
template <typename scalar_t>
void apply_magma_eig(Tensor& values, Tensor& vectors, Tensor& input, Tensor& infos, bool compute_eigenvectors) {
void apply_linalg_eig(Tensor& values, Tensor& vectors, Tensor& input, Tensor& infos, bool compute_eigenvectors) {
#if !AT_MAGMA_ENABLED()
TORCH_CHECK(false, "Calling torch.linalg.eig with MAGMA requires compiling PyTorch with MAGMA. "
"Either transfer the tensor to the CPU before calling torch.linalg.eig or use cuSolver.");
TORCH_CHECK(false, "Calling torch.linalg.eig on a CUDA tensor requires compiling PyTorch with MAGMA. "
"Either transfer the tensor to the CPU before calling torch.linalg.eig or recompile with MAGMA.");
#else
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.device() == at::kCPU);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(values.device() == at::kCPU);
@ -2076,44 +2076,22 @@ TORCH_CHECK(false, "Calling torch.linalg.eig with MAGMA requires compiling PyTor
#endif
}
// MAGMA wrapper: transfers tensors to CPU, calls apply_magma_eig, then copies results back.
void linalg_eig_magma(Tensor& eigenvalues, Tensor& eigenvectors, Tensor& infos, const Tensor& input, bool compute_eigenvectors){
// MAGMA doesn't have GPU interface for the eigendecomposition, and it forces us to transfer to CPU
auto eigenvalues_cpu = eigenvalues.cpu();
auto eigenvectors_cpu = eigenvectors.cpu();
auto infos_cpu = infos.cpu();
Tensor input_cpu = at::empty(input.sizes(), input.options().device(kCPU));
input_cpu.transpose_(-2, -1);
input_cpu.copy_(input);
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "linalg_eig_out_cuda", [&]{
apply_magma_eig<scalar_t>(eigenvalues_cpu, eigenvectors_cpu, input_cpu, infos_cpu, compute_eigenvectors);
});
eigenvalues.copy_(eigenvalues_cpu);
eigenvectors.copy_(eigenvectors_cpu);
infos.copy_(infos_cpu);
}
// This is a type dispatching helper function for 'apply_linalg_eig'
void linalg_eig_kernel(Tensor& eigenvalues, Tensor& eigenvectors, Tensor& infos, const Tensor& input, bool compute_eigenvectors) {
// This function calculates the non-symmetric eigendecomposition in-place
// tensors should be in batched column major memory format
// the content of eigenvalues, eigenvectors and infos is overwritten by 'linalg_eig_magma' or
// 'linalg_eig_cusolver_xgeev' both geev routines modify the provided input matrix in-place, therefore we need a copy
// the content of eigenvalues, eigenvectors and infos is overwritten by 'apply_linalg_eig'
// apply_linalg_eig modifies the provided input matrix in-place, therefore we need a copy
// MAGMA doesn't have GPU interface for the eigendecomposition and it forces us to transfer 'input' to CPU
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.is_cuda());
#if defined(CUSOLVER_VERSION) && (CUSOLVER_VERSION >= 11702)
auto preferred_backend = at::globalContext().linalgPreferredBackend();
switch (preferred_backend) {
case at::LinalgBackend::Cusolver:
default:
linalg_eig_cusolver_xgeev(eigenvalues, eigenvectors, input, infos, compute_eigenvectors);
return;
case at::LinalgBackend::Magma:
break; // MAGMA path handled below
}
#endif
linalg_eig_magma(eigenvalues, eigenvectors, infos, input, compute_eigenvectors);
Tensor input_working_copy = at::empty(input.sizes(), input.options().device(kCPU));
input_working_copy.transpose_(-2, -1); // make input_working_copy to have Fortran contiguous memory layout
input_working_copy.copy_(input);
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "linalg_eig_out_cuda", [&]{
apply_linalg_eig<scalar_t>(eigenvalues, eigenvectors, input_working_copy, infos, compute_eigenvectors);
});
}
REGISTER_CUDA_DISPATCH(linalg_eig_stub, &linalg_eig_kernel)

View File

@ -1625,126 +1625,6 @@ void linalg_eigh_cusolver(const Tensor& eigenvalues, const Tensor& eigenvectors,
#endif
}
// cuSOLVER Xgeev (requires cuSOLVER >= 11.7.2, i.e. CUDA 12.8+)
#if defined(CUSOLVER_VERSION) && (CUSOLVER_VERSION >= 11702)
template <typename scalar_t>
void apply_xgeev(const Tensor& values, const Tensor& vectors, const Tensor& input, const Tensor& infos, bool compute_eigenvectors) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(values.is_cuda());
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(vectors.is_cuda());
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.is_cuda());
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(infos.is_cuda());
int n = cuda_int_cast(input.size(-1), "n");
int lda = std::max<int>(1, n);
auto batch_size = batchCount(input);
if (n == 0 || batch_size == 0) {
// XGeev crashes on empty input, explicitly handle empty input
auto values_shape = IntArrayRef(input.sizes().data(), input.dim() - 1);
values.resize_(values_shape, MemoryFormat::Contiguous);
values.zero_();
if (compute_eigenvectors) {
vectors.resize_(input.sizes(), MemoryFormat::Contiguous);
vectors.zero_();
} else {
vectors.resize_({0});
}
infos.resize_({std::max<int64_t>(1, batch_size)}, MemoryFormat::Contiguous);
infos.zero_();
return;
}
int64_t vectors_stride = 0;
if (compute_eigenvectors){
vectors_stride = matrixStride(vectors);
}
auto values_stride = values.size(-1);
auto vectors_data = vectors.data_ptr<scalar_t>();
auto values_data = values.data_ptr<scalar_t>();
auto infos_data = infos.data_ptr<int>();
cusolverDnParams_t params = nullptr;
TORCH_CUSOLVER_CHECK(cusolverDnCreateParams(&params));
Tensor A_fortran = input.mT().contiguous();
auto* A_data = A_fortran.data_ptr<scalar_t>();
const auto A_stride = matrixStride(A_fortran);
auto handle = at::cuda::getCurrentCUDASolverDnHandle();
const int ldvl = 1; // ldvl >= 1 if jobvl = CUSOLVER_EIG_MODE_NOVECTOR
cusolverEigMode_t jobvl = CUSOLVER_EIG_MODE_NOVECTOR;
cusolverEigMode_t jobvr;
int ldvr;
if (compute_eigenvectors) {
ldvr = n; // ldvr >= n if jobvr = CUSOLVER_EIG_MODE_VECTOR
jobvr = CUSOLVER_EIG_MODE_VECTOR;
}
else {
ldvr = 1; // ldvr >= 1 if jobvr = CUSOLVER_EIG_MODE_NOVECTOR
jobvr = CUSOLVER_EIG_MODE_NOVECTOR;
}
scalar_t* W = values.data_ptr<scalar_t>();
scalar_t* VL = nullptr;
scalar_t* VR = vectors.data_ptr<scalar_t>();
const scalar_t* A_const = A_data;
const scalar_t* W_const = W;
const scalar_t* VL_const = VL;
const scalar_t* VR_const = VR;
size_t ws_dev = 0, ws_host = 0;
at::cuda::solver::xgeev_bufferSize<scalar_t>(
handle, params,
jobvl, jobvr,
n,
A_const, lda,
W_const,
VL_const, ldvl,
VR_const, ldvr,
&ws_dev, &ws_host);
auto& device_allocator = *at::cuda::getCUDADeviceAllocator();
auto work_device_data = device_allocator.allocate(ws_dev);
// use pinned memory for best performance.
auto& host_allocator = *at::cuda::getPinnedMemoryAllocator();
auto work_host_data = host_allocator.allocate(ws_host);
for (decltype(batch_size) i = 0; i < batch_size; ++i) {
scalar_t* Ai = A_data + i * A_stride;
scalar_t* Wi = values_data + i * values_stride;
scalar_t* VLi = nullptr; // xgeev does not support computing left evs
scalar_t* VRi = compute_eigenvectors ? (vectors_data + i * vectors_stride) : nullptr;
int* info = infos_data + i;
at::cuda::solver::xgeev<scalar_t>(
handle, params,
jobvl, jobvr,
n,
Ai, lda,
Wi,
VLi, ldvl,
VRi, ldvr,
static_cast<scalar_t*>(work_device_data.get()), ws_dev,
static_cast<scalar_t*>(work_host_data.get()), ws_host,
info);
}
TORCH_CUSOLVER_CHECK(cusolverDnDestroyParams(params));
}
void linalg_eig_cusolver_xgeev(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& input, const Tensor& infos, bool compute_eigenvectors) {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(eigenvectors.scalar_type(), "linalg_eig_cuda", [&] {
apply_xgeev<scalar_t>(eigenvalues, eigenvectors, input, infos, compute_eigenvectors);
});
}
#endif // defined(CUSOLVER_VERSION) && (CUSOLVER_VERSION >= 11702)
// The 'apply_' word is used for templated by dtype functions that call an API routine
// underneath. Since the cusolver API has a slightly different structure we do not prepend
// apply_ to this function.

View File

@ -73,11 +73,6 @@ void ormqr_cusolver(const Tensor& input, const Tensor& tau, const Tensor& other,
Tensor& orgqr_helper_cusolver(Tensor& result, const Tensor& tau);
void linalg_eigh_cusolver(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& infos, bool upper, bool compute_eigenvectors);
void linalg_eig_cusolver_xgeev(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& input, const Tensor& infos, bool compute_eigenvectors);
void lu_solve_looped_cusolver(const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType transpose);
void lu_factor_looped_cusolver(const Tensor& self, const Tensor& pivots, const Tensor& infos, bool get_pivots);

View File

@ -1954,336 +1954,6 @@ void xsyevd<c10::complex<double>, double>(
workspaceInBytesOnHost,
info));
}
// cuSOLVER Xgeev bindings (requires cuSOLVER >= 11.7.2, i.e. CUDA 12.8+)
#if defined(CUSOLVER_VERSION) && (CUSOLVER_VERSION >= 11702)
template <>
void xgeev_bufferSize<float>(
cusolverDnHandle_t handle,
cusolverDnParams_t params,
cusolverEigMode_t jobvl,
cusolverEigMode_t jobvr,
int64_t n,
const float* A,
int64_t lda,
const float* W,
const float* VL,
int64_t ldvl,
const float* VR,
int64_t ldvr,
size_t* workspaceInBytesOnDevice,
size_t* workspaceInBytesOnHost) {
TORCH_CUSOLVER_CHECK(cusolverDnXgeev_bufferSize(
handle, params, jobvl, jobvr, n,
CUDA_R_32F,
reinterpret_cast<const void*>(A),
lda,
CUDA_R_32F,
reinterpret_cast<const void*>(W),
CUDA_R_32F,
reinterpret_cast<const void*>(VL),
ldvl,
CUDA_R_32F,
reinterpret_cast<const void*>(VR),
ldvr,
CUDA_R_32F,
workspaceInBytesOnDevice,
workspaceInBytesOnHost));
}
template <>
void xgeev_bufferSize<double>(
cusolverDnHandle_t handle,
cusolverDnParams_t params,
cusolverEigMode_t jobvl,
cusolverEigMode_t jobvr,
int64_t n,
const double* A,
int64_t lda,
const double* W,
const double* VL,
int64_t ldvl,
const double* VR,
int64_t ldvr,
size_t* workspaceInBytesOnDevice,
size_t* workspaceInBytesOnHost) {
TORCH_CUSOLVER_CHECK(cusolverDnXgeev_bufferSize(
handle, params, jobvl, jobvr, n,
CUDA_R_64F,
reinterpret_cast<const void*>(A),
lda,
CUDA_R_64F,
reinterpret_cast<const void*>(W),
CUDA_R_64F,
reinterpret_cast<const void*>(VL),
ldvl,
CUDA_R_64F,
reinterpret_cast<const void*>(VR),
ldvr,
CUDA_R_64F,
workspaceInBytesOnDevice,
workspaceInBytesOnHost));
}
template <>
void xgeev_bufferSize<c10::complex<float>>(
cusolverDnHandle_t handle,
cusolverDnParams_t params,
cusolverEigMode_t jobvl,
cusolverEigMode_t jobvr,
int64_t n,
const c10::complex<float>* A,
int64_t lda,
const c10::complex<float>* W,
const c10::complex<float>* VL,
int64_t ldvl,
const c10::complex<float>* VR,
int64_t ldvr,
size_t* workspaceInBytesOnDevice,
size_t* workspaceInBytesOnHost) {
TORCH_CUSOLVER_CHECK(cusolverDnXgeev_bufferSize(
handle, params, jobvl, jobvr, n,
CUDA_C_32F,
reinterpret_cast<const void*>(A),
lda,
CUDA_C_32F,
reinterpret_cast<const void*>(W),
CUDA_C_32F,
reinterpret_cast<const void*>(VL),
ldvl,
CUDA_C_32F,
reinterpret_cast<const void*>(VR),
ldvr,
CUDA_C_32F,
workspaceInBytesOnDevice,
workspaceInBytesOnHost));
}
template <>
void xgeev_bufferSize<c10::complex<double>>(
cusolverDnHandle_t handle,
cusolverDnParams_t params,
cusolverEigMode_t jobvl,
cusolverEigMode_t jobvr,
int64_t n,
const c10::complex<double>* A,
int64_t lda,
const c10::complex<double>* W,
const c10::complex<double>* VL,
int64_t ldvl,
const c10::complex<double>* VR,
int64_t ldvr,
size_t* workspaceInBytesOnDevice,
size_t* workspaceInBytesOnHost) {
TORCH_CUSOLVER_CHECK(cusolverDnXgeev_bufferSize(
handle, params, jobvl, jobvr, n,
CUDA_C_64F,
reinterpret_cast<const void*>(A),
lda,
CUDA_C_64F,
reinterpret_cast<const void*>(W),
CUDA_C_64F,
reinterpret_cast<const void*>(VL),
ldvl,
CUDA_C_64F,
reinterpret_cast<const void*>(VR),
ldvr,
CUDA_C_64F,
workspaceInBytesOnDevice,
workspaceInBytesOnHost));
}
template <>
void xgeev<float>(
cusolverDnHandle_t handle,
cusolverDnParams_t params,
cusolverEigMode_t jobvl,
cusolverEigMode_t jobvr,
int64_t n,
float* A,
int64_t lda,
float* W,
float* VL,
int64_t ldvl,
float* VR,
int64_t ldvr,
float* bufferOnDevice,
size_t workspaceInBytesOnDevice,
float* bufferOnHost,
size_t workspaceInBytesOnHost,
int* info) {
TORCH_CUSOLVER_CHECK(cusolverDnXgeev(
handle,
params,
jobvl,
jobvr,
n,
CUDA_R_32F,
reinterpret_cast<void*>(A),
lda,
CUDA_R_32F,
reinterpret_cast<void*>(W),
CUDA_R_32F,
reinterpret_cast<void*>(VL),
ldvl,
CUDA_R_32F,
reinterpret_cast<void*>(VR),
ldvr,
CUDA_R_32F,
reinterpret_cast<void*>(bufferOnDevice),
workspaceInBytesOnDevice,
reinterpret_cast<void*>(bufferOnHost),
workspaceInBytesOnHost,
info));
}
template <>
void xgeev<double>(
cusolverDnHandle_t handle,
cusolverDnParams_t params,
cusolverEigMode_t jobvl,
cusolverEigMode_t jobvr,
int64_t n,
double* A,
int64_t lda,
double* W,
double* VL,
int64_t ldvl,
double* VR,
int64_t ldvr,
double* bufferOnDevice,
size_t workspaceInBytesOnDevice,
double* bufferOnHost,
size_t workspaceInBytesOnHost,
int* info) {
TORCH_CUSOLVER_CHECK(cusolverDnXgeev(
handle,
params,
jobvl,
jobvr,
n,
CUDA_R_64F,
reinterpret_cast<void*>(A),
lda,
CUDA_R_64F,
reinterpret_cast<void*>(W),
CUDA_R_64F,
reinterpret_cast<void*>(VL),
ldvl,
CUDA_R_64F,
reinterpret_cast<void*>(VR),
ldvr,
CUDA_R_64F,
reinterpret_cast<void*>(bufferOnDevice),
workspaceInBytesOnDevice,
reinterpret_cast<void*>(bufferOnHost),
workspaceInBytesOnHost,
info));
}
template <>
void xgeev<c10::complex<float>>(
cusolverDnHandle_t handle,
cusolverDnParams_t params,
cusolverEigMode_t jobvl,
cusolverEigMode_t jobvr,
int64_t n,
c10::complex<float>* A,
int64_t lda,
c10::complex<float>* W,
c10::complex<float>* VL,
int64_t ldvl,
c10::complex<float>* VR,
int64_t ldvr,
c10::complex<float>* bufferOnDevice,
size_t workspaceInBytesOnDevice,
c10::complex<float>* bufferOnHost,
size_t workspaceInBytesOnHost,
int* info) {
TORCH_CUSOLVER_CHECK(cusolverDnXgeev(
handle,
params,
jobvl,
jobvr,
n,
CUDA_C_32F,
reinterpret_cast<void*>(A),
lda,
CUDA_C_32F,
reinterpret_cast<void*>(W),
CUDA_C_32F,
reinterpret_cast<void*>(VL),
ldvl,
CUDA_C_32F,
reinterpret_cast<void*>(VR),
ldvr,
CUDA_C_32F,
reinterpret_cast<void*>(bufferOnDevice),
workspaceInBytesOnDevice,
reinterpret_cast<void*>(bufferOnHost),
workspaceInBytesOnHost,
info));
}
template <>
void xgeev<c10::complex<double>>(
cusolverDnHandle_t handle,
cusolverDnParams_t params,
cusolverEigMode_t jobvl,
cusolverEigMode_t jobvr,
int64_t n,
c10::complex<double>* A,
int64_t lda,
c10::complex<double>* W,
c10::complex<double>* VL,
int64_t ldvl,
c10::complex<double>* VR,
int64_t ldvr,
c10::complex<double>* bufferOnDevice,
size_t workspaceInBytesOnDevice,
c10::complex<double>* bufferOnHost,
size_t workspaceInBytesOnHost,
int* info) {
TORCH_CUSOLVER_CHECK(cusolverDnXgeev(
handle,
params,
jobvl,
jobvr,
n,
CUDA_C_64F,
reinterpret_cast<void*>(A),
lda,
CUDA_C_64F,
reinterpret_cast<void*>(W),
CUDA_C_64F,
reinterpret_cast<void*>(VL),
ldvl,
CUDA_C_64F,
reinterpret_cast<void*>(VR),
ldvr,
CUDA_C_64F,
reinterpret_cast<void*>(bufferOnDevice),
workspaceInBytesOnDevice,
reinterpret_cast<void*>(bufferOnHost),
workspaceInBytesOnHost,
info));
}
#endif // defined(CUSOLVER_VERSION) && (CUSOLVER_VERSION >= 11702)
#endif // USE_CUSOLVER_64_BIT
#ifdef USE_CUSOLVER_64_BIT_XSYEV_BATCHED

View File

@ -674,66 +674,6 @@ template <>
void xsyevd<c10::complex<double>, double>(
CUDASOLVER_XSYEVD_ARGTYPES(c10::complex<double>, double));
// cuSOLVER Xgeev (non-Hermitian eigen decomposition, CUDA >= 12.8)
#if defined(CUSOLVER_VERSION) && (CUSOLVER_VERSION >= 11702)
#define CUDASOLVER_XGEEV_BUFFERSIZE_ARGTYPES(scalar_t) \
cusolverDnHandle_t handle, cusolverDnParams_t params, \
cusolverEigMode_t jobvl, cusolverEigMode_t jobvr, int64_t n, \
const scalar_t* A, int64_t lda, const scalar_t* W, \
const scalar_t* VL, int64_t ldvl, const scalar_t* VR, int64_t ldvr, \
size_t* workspaceInBytesOnDevice, size_t* workspaceInBytesOnHost
template <class scalar_t>
void xgeev_bufferSize(
CUDASOLVER_XGEEV_BUFFERSIZE_ARGTYPES(scalar_t)) {
static_assert(false&&sizeof(scalar_t),
"at::cuda::solver::xgeev_bufferSize: not implemented");
}
template <>
void xgeev_bufferSize<float>(CUDASOLVER_XGEEV_BUFFERSIZE_ARGTYPES(float));
template <>
void xgeev_bufferSize<double>(CUDASOLVER_XGEEV_BUFFERSIZE_ARGTYPES(double));
template <>
void xgeev_bufferSize<c10::complex<float>>(
CUDASOLVER_XGEEV_BUFFERSIZE_ARGTYPES(c10::complex<float>));
template <>
void xgeev_bufferSize<c10::complex<double>>(
CUDASOLVER_XGEEV_BUFFERSIZE_ARGTYPES(c10::complex<double>));
#define CUDASOLVER_XGEEV_ARGTYPES(scalar_t) \
cusolverDnHandle_t handle, cusolverDnParams_t params, \
cusolverEigMode_t jobvl, cusolverEigMode_t jobvr, int64_t n, scalar_t *A, \
int64_t lda, scalar_t *W, scalar_t *VL, int64_t ldvl, scalar_t *VR, int64_t ldvr,\
scalar_t *bufferOnDevice, size_t workspaceInBytesOnDevice, scalar_t *bufferOnHost,\
size_t workspaceInBytesOnHost, int *info
template <class scalar_t>
void xgeev(CUDASOLVER_XGEEV_ARGTYPES(scalar_t)) {
static_assert(false&&sizeof(scalar_t),
"at::cuda::solver::xgeev: not implemented");
}
template <>
void xgeev<float>(CUDASOLVER_XGEEV_ARGTYPES(float));
template <>
void xgeev<double>(CUDASOLVER_XGEEV_ARGTYPES(double));
template <>
void xgeev<c10::complex<float>>(CUDASOLVER_XGEEV_ARGTYPES(c10::complex<float>));
template <>
void xgeev<c10::complex<double>>(CUDASOLVER_XGEEV_ARGTYPES(c10::complex<double>));
#endif // defined(CUSOLVER_VERSION) && (CUSOLVER_VERSION >= 11702)
#endif // USE_CUSOLVER_64_BIT
#ifdef USE_CUSOLVER_64_BIT_XSYEV_BATCHED

View File

@ -157,10 +157,10 @@ bool onednn_strides_check(const Tensor& src) {
return true;
dnnl_dims_t blocks = {0};
std::array<int, DNNL_MAX_NDIMS> perm = {0};
int perm[DNNL_MAX_NDIMS] = {0};
for (int d = 0; d < md_ndims; ++d) {
// no strides check needed for empty tensor
if ((*md_padded_dims)[d] == 0)
if (md_padded_dims[d] == nullptr)
return true;
// no strides verification for runtime dims
@ -178,15 +178,14 @@ bool onednn_strides_check(const Tensor& src) {
// A custom comparator to yield linear order on perm
auto idx_sorter = [&](const int a, const int b) -> bool {
if (strides[a] == strides[b] &&
(*md_padded_dims)[a] == (*md_padded_dims)[b])
if (strides[a] == strides[b] && md_padded_dims[a] == md_padded_dims[b])
return a < b;
else if (strides[a] == strides[b])
return (*md_padded_dims)[a] < (*md_padded_dims)[b];
return md_padded_dims[a] < md_padded_dims[b];
else
return strides[a] < strides[b];
};
std::sort(perm.begin(), perm.begin() + md_ndims, idx_sorter);
std::sort(perm, perm + md_ndims, idx_sorter);
auto min_stride = block_size;
for (int idx = 0; idx < md_ndims; ++idx) {
@ -200,10 +199,9 @@ bool onednn_strides_check(const Tensor& src) {
return false;
// update min_stride for next iteration
const auto padded_dim = (*md_padded_dims)[d];
const auto padded_dim = *md_padded_dims[d];
min_stride = block_size * strides[d] * (padded_dim / blocks[d]);
}
return true;
}

View File

@ -370,7 +370,7 @@ static void nllnd_loss_backward_impl(Tensor& grad_input_arg,
onValue:-1.0f
offValue:0.0f
name:nil];
oneHotTensor = castMPSTensor(mpsGraph, oneHotTensor, [inputTensor dataType]);
oneHotTensor = castMPSTensor(mpsGraph, oneHotTensor, inputTensor.dataType);
if (isWeightsArrayValid) {
oneHotTensor = [mpsGraph multiplicationWithPrimaryTensor:oneHotTensor
secondaryTensor:weightTensor
@ -705,7 +705,6 @@ static void smooth_l1_loss_template(const Tensor& input,
TORCH_CHECK(beta >= 0, "smooth_l1_loss does not support negative values for beta.");
TORCH_CHECK(input.is_mps());
TORCH_CHECK(target.is_mps());
TORCH_CHECK_NOT_IMPLEMENTED(input.scalar_type() != kLong, "MPS doesn't know how to do square_i64");
if ((input.numel() == 0) || (target.numel() == 0)) {
reduction == Reduction::Mean ? output.fill_(std::numeric_limits<float>::quiet_NaN()) : output.zero_();
return;
@ -772,7 +771,7 @@ static void smooth_l1_loss_backward_impl(const Tensor& grad_output,
MPSGraphTensor* targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, target);
MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
MPSGraphTensor* betaTensor = [mpsGraph constantWithScalar:beta dataType:[inputTensor dataType]];
MPSGraphTensor* betaTensor = [mpsGraph constantWithScalar:beta dataType:MPSDataTypeFloat32];
// xn - yn
MPSGraphTensor* diffTensor = [mpsGraph subtractionWithPrimaryTensor:inputTensor
secondaryTensor:targetTensor
@ -798,8 +797,7 @@ static void smooth_l1_loss_backward_impl(const Tensor& grad_output,
name:@"lossTensor"];
MPSGraphTensor* outputTensor = lossTensor;
if (reduction == Reduction::Mean) {
MPSGraphTensor* numelTensor = [mpsGraph constantWithScalar:(double)input.numel()
dataType:[lossTensor dataType]];
MPSGraphTensor* numelTensor = [mpsGraph constantWithScalar:(double)input.numel() dataType:MPSDataTypeFloat32];
outputTensor = [mpsGraph divisionWithPrimaryTensor:lossTensor secondaryTensor:numelTensor name:nil];
}
MPSGraphTensor* gradInputTensor = [mpsGraph multiplicationWithPrimaryTensor:outputTensor

View File

@ -84,9 +84,6 @@ std::tuple<Tensor&, Tensor&, Tensor&> batch_norm_mps_out(const Tensor& self,
Tensor& output,
Tensor& save_mean,
Tensor& save_var) {
TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kLong, "Long batch norm is not supported with MPS");
TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(self.scalar_type()),
"Batch norm for complex is not supported for MPS");
using namespace at::native::mps;
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
@ -921,7 +918,6 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_mps(const Tensor& input,
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
const int axis = input_ndim - normalized_ndim;
MPSStream* stream = getCurrentMPSStream();
TORCH_CHECK_NOT_IMPLEMENTED(input.scalar_type() != kLong, "Not implemented for long on MPS");
@autoreleasepool {
mps::dispatch_sync_with_rethrow(stream->queue(), ^() {
// which kernel variant to use based on the normalized axis N size

View File

@ -4800,12 +4800,6 @@
CompositeExplicitAutograd: rand_like
autogen: rand_like.out
- func: rand_like.generator(Tensor self, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
tags: nondeterministic_seeded
dispatch:
CompositeExplicitAutograd: rand_like
autogen: rand_like.generator_out
- func: randint(SymInt high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
tags: nondeterministic_seeded
dispatch:
@ -4854,14 +4848,6 @@
CompositeExplicitAutograd: randint_like
autogen: randint_like.out
- func: randint_like.generator(Tensor self, SymInt high, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
tags: nondeterministic_seeded
dispatch:
# NB: Although this composite mutates on the inside, it is
# non-differentiable so NonFunctional doesn't apply
CompositeExplicitAutograd: randint_like
autogen: randint_like.generator_out
- func: randint_like.Tensor(Tensor self, Tensor high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
tags: nondeterministic_seeded
dispatch:
@ -4870,14 +4856,6 @@
CompositeExplicitAutograd: randint_like
autogen: randint_like.Tensor_out
- func: randint_like.Tensor_generator(Tensor self, Tensor high, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
tags: nondeterministic_seeded
dispatch:
# NB: Although this composite mutates on the inside, it is
# non-differentiable so NonFunctional doesn't apply
CompositeExplicitAutograd: randint_like
autogen: randint_like.Tensor_generator_out
- func: randint_like.low_dtype(Tensor self, SymInt low, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
tags: nondeterministic_seeded
dispatch:
@ -4886,14 +4864,6 @@
CompositeExplicitAutograd: randint_like
autogen: randint_like.low_dtype_out
- func: randint_like.low_generator_dtype(Tensor self, SymInt low, SymInt high, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
tags: nondeterministic_seeded
dispatch:
# NB: Although this composite mutates on the inside, it is
# non-differentiable so NonFunctional doesn't apply
CompositeExplicitAutograd: randint_like
autogen: randint_like.low_generator_dtype_out
- func: randn(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
tags: [core, nondeterministic_seeded]
dispatch:
@ -4934,14 +4904,6 @@
CompositeExplicitAutograd, CompositeImplicitAutogradNestedTensor: randn_like
autogen: randn_like.out
- func: randn_like.generator(Tensor self, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
tags: nondeterministic_seeded
dispatch:
# NB: Although this composite mutates on the inside, it is
# non-differentiable so NonFunctional doesn't apply
CompositeExplicitAutograd, CompositeImplicitAutogradNestedTensor: randn_like
autogen: randn_like.generator_out
- func: randperm(SymInt n, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
tags: [core, nondeterministic_seeded]
dispatch:

View File

@ -202,7 +202,6 @@ supported:
- select_backward
- _trilinear
- linalg_pinv.atol_rtol_tensor
- svd
- logsumexp.out
symint:
- empty.memory_format

View File

@ -1,63 +0,0 @@
#include <ATen/xpu/PeerToPeerAccess.h>
#include <ATen/xpu/XPUContext.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <c10/xpu/XPUCachingAllocator.h>
namespace at::xpu {
// p2pAccessEnabled_ is a flattened 2D matrix of size [num_devices x
// num_devices].
// Each element represents whether device[i] can access device[j]:
// 1 -> access allowed
// 0 -> access not allowed
// -1 -> unknown (not yet queried)
static std::vector<int8_t> p2pAccessEnabled_;
namespace detail {
// Initializes the peer-to-peer (P2P) access capability cache.
void init_p2p_access_cache(c10::DeviceIndex num_devices) {
// By default, each device can always access itself (diagonal entries = 1).
// For simplicity, all entries are initialized to -1 except the diagonal.
static bool once [[maybe_unused]] = [num_devices]() {
p2pAccessEnabled_.clear();
p2pAccessEnabled_.resize(num_devices * num_devices, -1);
for (const auto i : c10::irange(num_devices)) {
p2pAccessEnabled_[i * num_devices + i] = 1;
}
return true;
}();
}
} // namespace detail
bool get_p2p_access(c10::DeviceIndex dev, c10::DeviceIndex dev_to_access) {
at::globalContext().lazyInitDevice(c10::DeviceType::XPU);
check_device_index(dev);
check_device_index(dev_to_access);
auto& cache =
p2pAccessEnabled_[dev * c10::xpu::device_count() + dev_to_access];
if (cache != -1) {
return static_cast<bool>(cache);
}
// Query the hardware to determine if P2P access is supported
cache = static_cast<int8_t>(
c10::xpu::get_raw_device(dev).ext_oneapi_can_access_peer(
c10::xpu::get_raw_device(dev_to_access),
sycl::ext::oneapi::peer_access::access_supported));
if (cache) {
XPUCachingAllocator::enablePeerAccess(dev, dev_to_access);
}
return static_cast<bool>(cache);
}
} // namespace at::xpu

View File

@ -1,15 +0,0 @@
#pragma once
#include <c10/core/Device.h>
#include <c10/macros/Macros.h>
namespace at::xpu {
namespace detail {
void init_p2p_access_cache(c10::DeviceIndex num_devices);
} // namespace detail
TORCH_XPU_API bool get_p2p_access(
c10::DeviceIndex dev,
c10::DeviceIndex dev_to_access);
} // namespace at::xpu

View File

@ -1,4 +1,3 @@
#include <ATen/xpu/PeerToPeerAccess.h>
#include <ATen/xpu/PinnedMemoryAllocator.h>
#include <ATen/xpu/XPUContext.h>
#include <ATen/xpu/XPUDevice.h>
@ -13,7 +12,6 @@ void XPUHooks::init() const {
C10_LOG_API_USAGE_ONCE("aten.init.xpu");
const auto device_count = c10::xpu::device_count_ensure_non_zero();
c10::xpu::XPUCachingAllocator::init(device_count);
at::xpu::detail::init_p2p_access_cache(device_count);
}
bool XPUHooks::hasXPU() const {

View File

@ -11,11 +11,6 @@ def remove_cuda(config_list):
return [config for config in config_list if cuda_config not in config]
def remove_cpu(config_list):
cpu_config = {"device": "cpu"}
return [config for config in config_list if cpu_config not in config]
# Configs for conv-1d ops
conv_1d_configs_short = op_bench.config_list(
attr_names=["IC", "OC", "kernel", "stride", "N", "L"],
@ -132,18 +127,6 @@ conv_3d_configs_short = op_bench.config_list(
},
tags=["short"],
)
conv_3d_configs_long = op_bench.cross_product_configs(
IC=[16, 32],
OC=[32, 64],
kernel=[3, 5],
stride=[1, 2],
N=[1],
D=[128],
H=[128],
W=[128],
device=["cpu", "cuda"],
tags=["long"],
)
linear_configs_short = op_bench.config_list(
attr_names=["N", "IN", "OUT"],

View File

@ -38,10 +38,6 @@ class ConvTranspose1dBenchmark(op_bench.TorchBenchmarkBase):
op_bench.generate_pt_test(
configs.conv_1d_configs_short + configs.conv_1d_configs_long, Conv1dBenchmark
)
op_bench.generate_pt_gradient_test(
configs.remove_cpu(configs.conv_1d_configs_short + configs.conv_1d_configs_long),
Conv1dBenchmark,
)
if not torch.backends.mkldnn.is_acl_available():
@ -107,20 +103,6 @@ op_bench.generate_pt_test(
configs.conv_2d_pw_configs_short + configs.conv_2d_pw_configs_long,
Conv2dPointwiseBenchmark,
)
op_bench.generate_pt_gradient_test(
configs.remove_cpu(configs.conv_2d_configs_short + configs.conv_2d_configs_long),
Conv2dBenchmark,
)
op_bench.generate_pt_gradient_test(
configs.remove_cpu(configs.conv_2d_configs_short + configs.conv_2d_configs_long),
ConvTranspose2dBenchmark,
)
op_bench.generate_pt_gradient_test(
configs.remove_cpu(
configs.conv_2d_pw_configs_short + configs.conv_2d_pw_configs_long
),
Conv2dPointwiseBenchmark,
)
"""
@ -152,12 +134,6 @@ class ConvTranspose3dBenchmark(op_bench.TorchBenchmarkBase):
op_bench.generate_pt_test(configs.conv_3d_configs_short, Conv3dBenchmark)
op_bench.generate_pt_test(configs.conv_3d_configs_short, ConvTranspose3dBenchmark)
op_bench.generate_pt_gradient_test(
configs.remove_cpu(configs.conv_3d_configs_long), Conv3dBenchmark
)
op_bench.generate_pt_gradient_test(
configs.remove_cpu(configs.conv_3d_configs_long), ConvTranspose3dBenchmark
)
if __name__ == "__main__":

View File

@ -929,7 +929,6 @@ libtorch_python_core_sources = [
"torch/csrc/dynamo/guards.cpp",
"torch/csrc/dynamo/utils.cpp",
"torch/csrc/dynamo/init.cpp",
"torch/csrc/dynamo/stackref_bridge.c",
"torch/csrc/functorch/init.cpp",
"torch/csrc/fx/node.cpp",
"torch/csrc/mps/Module.cpp",

View File

@ -52,7 +52,9 @@ constexpr DispatchKeySet math_dispatch_keyset = backend_dispatch_keyset |
// where we would like to support composite implicit kernels but not
// explicit kernels therefore we manually add the key to the
// math_dispatch_keyset
DispatchKeySet{DispatchKey::NestedTensor};
DispatchKeySet{DispatchKey::NestedTensor} |
// Functionalize should always reuse CompositeImplicit decomps.
DispatchKeySet{DispatchKey::Functionalize};
constexpr DispatchKeySet nested_dispatch_keyset =
DispatchKeySet(

View File

@ -21,20 +21,13 @@ using stream_set = ska::flat_hash_set<xpu::XPUStream>;
struct Block;
typedef bool (*Comparison)(const Block*, const Block*);
bool BlockComparatorSize(const Block* a, const Block* b);
bool BlockComparatorAddress(const Block* a, const Block* b);
struct BlockPool {
BlockPool(bool small)
: blocks(BlockComparatorSize),
unmapped(BlockComparatorAddress),
is_small(small) {}
BlockPool(bool small) : blocks(BlockComparatorSize), is_small(small) {}
std::set<Block*, Comparison> blocks;
std::set<Block*, Comparison> unmapped;
const bool is_small;
};
struct ExpandableSegment;
struct Block {
DeviceIndex device;
sycl::queue* queue{nullptr}; // underlying queue of the allocation stream
@ -44,11 +37,9 @@ struct Block {
BlockPool* pool{nullptr}; // owning memory pool
void* ptr{nullptr}; // memory address
bool allocated{false}; // in-use flag
bool mapped{true}; // True if this Block is backed by physical pages
Block* prev{nullptr}; // prev block if split from a larger allocation
Block* next{nullptr}; // next block if split from a larger allocation
int event_count{0}; // number of outstanding XPU events
ExpandableSegment* expandable_segment{nullptr}; // owning expandable segment
Block(
DeviceIndex device,
@ -75,20 +66,6 @@ struct Block {
bool is_split() const {
return (prev != nullptr) || (next != nullptr);
}
// Inserts this block between two existing blocks with [before, this, after].
void splice(Block* before, Block* after) {
if (before) {
TORCH_INTERNAL_ASSERT(before->next == after);
before->next = this;
}
prev = before;
if (after) {
TORCH_INTERNAL_ASSERT(after->prev == before);
after->prev = this;
}
next = after;
}
};
bool BlockComparatorSize(const Block* a, const Block* b) {
@ -103,221 +80,6 @@ bool BlockComparatorSize(const Block* a, const Block* b) {
reinterpret_cast<uintptr_t>(b->ptr);
}
bool BlockComparatorAddress(const Block* a, const Block* b) {
if (a->queue != b->queue) {
return reinterpret_cast<uintptr_t>(a->queue) <
reinterpret_cast<uintptr_t>(b->queue);
}
return reinterpret_cast<uintptr_t>(a->ptr) <
reinterpret_cast<uintptr_t>(b->ptr);
}
// Represents a contiguous virtual memory segment mapped for allocation.
struct SegmentRange {
SegmentRange(void* addr, size_t bytes)
: ptr(static_cast<char*>(addr)), size(bytes) {}
char* ptr; // Starting address of the mapped range.
size_t size; // Size in bytes of the mapped range.
};
struct ExpandableSegment {
ExpandableSegment(
c10::DeviceIndex device,
std::optional<sycl::queue*> queue,
size_t segment_size,
std::vector<c10::DeviceIndex> peers)
: device_(device),
queue_(queue),
// 2MB for small pool, 20MB for large pool
segment_size_(segment_size),
peers_(std::move(peers)) {
const auto device_total =
c10::xpu::get_raw_device(device)
.get_info<sycl::info::device::global_mem_size>();
// The extra 1/8 allows flexibility for remapping or moving pages within the
// segment when unmapping earlier regions.
constexpr float kVirtualMemOversubscriptFactor = 1.125f; // 1 + 1/8
max_handles_ = numSegments(device_total * kVirtualMemOversubscriptFactor);
ptr_ = sycl::ext::oneapi::experimental::reserve_virtual_mem(
segment_size_ * max_handles_, xpu::get_device_context());
}
C10_DISABLE_COPY_AND_ASSIGN(ExpandableSegment);
ExpandableSegment(ExpandableSegment&&) = delete;
ExpandableSegment& operator=(ExpandableSegment&&) = delete;
// Maps a virtual memory range to physical memory.
SegmentRange map(SegmentRange range) {
auto begin = segmentLeft(range.ptr);
auto end = segmentRight(range.ptr + range.size);
TORCH_INTERNAL_ASSERT(ptr() + begin * segment_size_ == range.ptr);
if (begin == end) {
return rangeFromHandles(begin, end);
}
// Ensure handles_ vector is large enough to hold all segments.
if (end > handles_.size()) {
handles_.resize(end, std::nullopt);
}
// Allocate and map physical memory for each segment.
for (const auto i : c10::irange(begin, end)) {
TORCH_INTERNAL_ASSERT(!handles_.at(i));
try {
// Allocate physical memory for each segment. Construct the physical_mem
// in-place to avoid copies.
handles_.at(i).emplace(
xpu::get_raw_device(device_),
xpu::get_device_context(),
segment_size_);
// Map the allocated physical memory into the virtual address space.
handles_.at(i).value().map(
ptr_ + i * segment_size_,
segment_size_,
sycl::ext::oneapi::experimental::address_access_mode::read_write);
} catch (const sycl::exception& e) {
// Allocation failure: typically sycl::errc::memory_allocation.
// Mapping failure: typically sycl::errc::runtime (e.g., OOM due to
// over-subscription).
// Note: constructing physical_mem may over-subscribe device memory but
// not immediately trigger OOM. The actual OOM can occur during map().
// Roll back all segments allocated or mapped in this operation.
handles_.at(i) = std::nullopt;
for (const auto j : c10::irange(begin, i)) {
sycl::ext::oneapi::experimental::unmap(
reinterpret_cast<void*>(ptr_ + segment_size_ * j),
segment_size_,
xpu::get_device_context());
handles_.at(j) = std::nullopt;
}
trimHandles();
return rangeFromHandles(begin, begin);
}
}
return rangeFromHandles(begin, end);
}
// Unmap a virtual memory range from physical memory.
SegmentRange unmap(SegmentRange range) {
auto begin = segmentRight(range.ptr);
auto end = segmentLeft(range.ptr + range.size);
if (begin >= end) {
return SegmentRange{range.ptr, 0};
}
unmapHandles(begin, end);
return rangeFromHandles(begin, end);
}
// Returns the base pointer of the virtual memory segment.
char* ptr() const {
// NOLINTNEXTLINE(performance-no-int-to-ptr)
return reinterpret_cast<char*>(ptr_);
}
// Returns the total size of the virtual memory segment.
size_t size() const {
return max_handles_ * segment_size_;
}
~ExpandableSegment() {
forEachAllocatedRange(
[&](size_t begin, size_t end) { unmapHandles(begin, end); });
sycl::ext::oneapi::experimental::free_virtual_mem(
ptr_, segment_size_ * max_handles_, xpu::get_device_context());
}
private:
// Unmaps the physical memory handles in the range [begin, end) from the
// segment.
void unmapHandles(size_t begin, size_t end) {
// Currently, we don't support IPC shared memory with expandable segments.
TORCH_INTERNAL_ASSERT(queue_);
// As explained in Note [Safe to Free Blocks on BlockPool], additional
// synchronization is unnecessary here because the memory is already safe to
// release.
for (const auto i : c10::irange(begin, end)) {
// Note: physical_mem's destructor does NOT automatically unmap any mapped
// ranges. Users must explicitly call unmap on all ranges before
// destroying the physical_mem object.
sycl::ext::oneapi::experimental::unmap(
reinterpret_cast<void*>(ptr_ + segment_size_ * i),
segment_size_,
xpu::get_device_context());
// Here physical_mem object is being destructed.
handles_.at(i) = std::nullopt;
}
trimHandles();
}
// Remove trailing unused handles from the end of handles_.
void trimHandles() {
while (!handles_.empty() && !handles_.back()) {
handles_.pop_back();
}
}
// Iterates over all contiguous ranges of allocated segments in `handles_`,
// and invokes the provided function `fn(start, end)` for each range.
// Each range is defined as a half-open interval [start, end).
void forEachAllocatedRange(const std::function<void(size_t, size_t)>& fn) {
size_t start = 0;
for (const auto i : c10::irange(handles_.size())) {
if (handles_.at(i) && (i == 0 || !handles_.at(i - 1))) {
start = i;
}
if (handles_.at(i) && (i + 1 == handles_.size() || !handles_.at(i + 1))) {
fn(start, i + 1);
}
}
}
// Returns the number of full segments required to cover `size` bytes.
// Rounds up to ensure partial segments are counted.
size_t numSegments(size_t size) const {
return (size + segment_size_ - 1) / segment_size_;
}
// Returns the index of the segment that contains the pointer `p`,
// relative to the base pointer `ptr_`. This is the *inclusive* lower bound
// of the segment that includes `p`.
size_t segmentLeft(char* p) const {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(p >= ptr() && p < ptr() + size());
size_t offset = p - ptr();
return offset / segment_size_;
}
// Returns the index of the segment just *past* the one containing pointer
// `p`, relative to the base pointer `ptr_`. This is the *exclusive* upper
// bound, useful for [begin, end) style ranges.
// If `p` lies exactly on a segment boundary, this is equal to segmentLeft(p).
// Otherwise, it rounds up and returns segmentLeft(p) + 1.
size_t segmentRight(char* p) const {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(p >= ptr() && p < ptr() + size());
size_t offset = p - ptr();
return numSegments(offset);
}
// Constructs a SegmentRange spanning indices [start, end).
SegmentRange rangeFromHandles(size_t begin, size_t end) {
return SegmentRange(
ptr() + segment_size_ * begin, segment_size_ * (end - begin));
}
c10::DeviceIndex device_{-1};
std::optional<sycl::queue*> queue_;
// Virtual memory address used for reservation.
uintptr_t ptr_{0};
// Size of each segment in bytes.
size_t segment_size_{0};
// Maximum number of segments that can be allocated in this segment.
size_t max_handles_{0};
// Physical memory handles for the segments.
std::vector<std::optional<sycl::ext::oneapi::experimental::physical_mem>>
handles_{};
// Peer devices on which this memory could be accessible, reserved.
std::vector<c10::DeviceIndex> peers_{};
};
struct AllocParams {
AllocParams(
DeviceIndex device,
@ -363,12 +125,10 @@ class DeviceCachingAllocator {
DeviceIndex device_index;
size_t allowed_memory_maximum = 0;
bool set_fraction = false;
std::vector<ExpandableSegment*> expandable_segments;
std::vector<c10::DeviceIndex> devices_with_peer_access; // reserved
size_t try_merge_blocks(Block* dst, Block* src, BlockPool& pool) {
if (!src || src->allocated || src->event_count > 0 ||
!src->stream_uses.empty() || dst->mapped != src->mapped) {
!src->stream_uses.empty()) {
return 0;
}
@ -387,8 +147,7 @@ class DeviceCachingAllocator {
}
const size_t subsumed_size = src->size;
dst->size += subsumed_size;
auto erased =
src->mapped ? pool.blocks.erase(src) : pool.unmapped.erase(src);
auto erased = pool.blocks.erase(src);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(erased == 1);
delete src;
@ -471,175 +230,12 @@ class DeviceCachingAllocator {
}
}
// Finds the first (lowest-address) block in any segment that has sufficient
// contiguous free virtual address space to satisfy `size`. The available
// space may span multiple adjacent blocks, which can include both free and
// unmapped segments.
Block* find_expandable_block(
c10::DeviceIndex device,
sycl::queue* queue,
BlockPool* pool,
size_t size) {
Block key(device, queue, 0);
auto allocatable = [](Block* b) {
return b && !b->allocated && b->event_count == 0 &&
b->stream_uses.empty();
};
auto has_available_address_space = [&](Block* b) {
size_t bytes = 0;
while (bytes < size && allocatable(b)) {
bytes += b->size;
b = b->next;
}
return bytes >= size;
};
for (auto it = pool->unmapped.lower_bound(&key);
it != pool->unmapped.end() && (*it)->queue == queue;
++it) {
Block* c = *it;
// The unmapped block might have a free mapped block right before it.
// By starting from the previous block, we can use both:
// [Free Mapped Block] + [Unmapped Block] = More contiguous space
if (allocatable(c->prev)) {
c = c->prev;
}
if (has_available_address_space(c)) {
return c;
}
}
auto segment_size = pool->is_small ? kSmallBuffer : kLargeBuffer;
expandable_segments.emplace_back(new ExpandableSegment(
device, queue, segment_size, devices_with_peer_access));
ExpandableSegment* es = expandable_segments.back();
Block* candidate = new Block(device, queue, es->size(), pool, es->ptr());
candidate->mapped = false;
candidate->expandable_segment = es;
pool->unmapped.insert(candidate);
return candidate;
}
bool map_block(Block* to_map, size_t size) {
TORCH_INTERNAL_ASSERT(!to_map->mapped && size <= to_map->size);
auto mapped_range =
to_map->expandable_segment->map(SegmentRange{to_map->ptr, size});
// Failed to map the memory
if (mapped_range.size == 0) {
return false;
}
TORCH_INTERNAL_ASSERT(
mapped_range.ptr == to_map->ptr && mapped_range.size >= size);
BlockPool& pool = *to_map->pool;
pool.unmapped.erase(to_map);
to_map->mapped = true;
if (mapped_range.size < to_map->size) {
// to_map -> remaining -> to_map->next(?)
Block* remaining = new Block(
to_map->device,
to_map->queue,
to_map->size - mapped_range.size,
&pool,
static_cast<char*>(to_map->ptr) + mapped_range.size);
remaining->mapped = false;
remaining->expandable_segment = to_map->expandable_segment;
remaining->splice(to_map, to_map->next);
pool.unmapped.insert(remaining);
to_map->size = mapped_range.size;
}
try_merge_blocks(to_map, to_map->prev, pool);
try_merge_blocks(to_map, to_map->next, pool);
pool.blocks.insert(to_map);
StatTypes stat_types = get_stat_types_for_pool(*to_map->pool);
for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
stats.reserved_bytes[stat_type].increase(mapped_range.size);
});
return true;
}
Block* try_allocate_expandable_block(
c10::DeviceIndex device,
sycl::queue* queue,
BlockPool* pool,
size_t size) {
// Candidate points to the start of a chain of contiguous blocks with
// sufficient virtual address space (>= size). The chain may consist of:
// Case 1: [Unmapped Block] -> null
// Case 2: [Unmapped Block] -> [Free Mapped Block]
// Case 3: [Free Mapped Block] -> [Unmapped Block]
Block* candidate = find_expandable_block(device, queue, pool, size);
// Map first block if unmapped (Case 1 & 2), use std::min to avoid
// over-mapping.
if (!candidate->mapped &&
!map_block(candidate, std::min(candidate->size, size))) {
return nullptr;
}
TORCH_INTERNAL_ASSERT(candidate->mapped);
// Map additional blocks until we have enough continuous space (Case 3).
// Each map_block() call merges newly mapped blocks with adjacent free
// blocks
while (candidate->size < size) {
auto remaining = size - candidate->size;
auto new_candidate = candidate->next;
// Map only what we need from the `new_candidate` block.
if (!map_block(new_candidate, std::min(remaining, new_candidate->size))) {
return nullptr;
}
candidate = new_candidate;
}
// Remove from the free pool; block will be marked as `allocated` in
// alloc_found_block()
pool->blocks.erase(candidate);
return candidate;
}
bool get_free_block(AllocParams& p) {
BlockPool& pool = *p.pool;
auto it = pool.blocks.lower_bound(&p.search_key);
if (it == pool.blocks.end() || (*it)->queue != p.queue()) {
return false;
}
if ((*it)->expandable_segment) {
if (AcceleratorAllocatorConfig::use_expandable_segments()) {
// When expandable segments are enabled, consider both the current block
// and any immediately adjacent unmapped region as a single expandable
// area. For "best fit" allocation, we use the total expandable size
// instead of just the block's current size, so that blocks which can
// grow into a larger contiguous range are preferred.
auto expandable_size = [](Block* b) {
// b->next may belong to pool.unmapped (reserved but not mapped)
return b->size + (b->next && !b->next->mapped ? b->next->size : 0);
};
auto next = it;
next++;
// Looks for the best fit block with expandable size.
while ((*it)->expandable_segment && next != pool.blocks.end() &&
(*next)->queue == p.queue() &&
expandable_size(*next) < expandable_size(*it)) {
it = next++;
}
} else {
// Expandable segments were previously enabled, but are now disabled
// (e.g. to avoid IPC issues). Skip any expandable blocks and only
// find from regular non-expandable segments.
do {
it++;
} while (it != pool.blocks.end() && (*it)->expandable_segment &&
(*it)->queue == p.queue());
if (it == pool.blocks.end() || (*it)->queue != p.queue()) {
return false;
}
}
}
p.block = *it;
pool.blocks.erase(it);
return true;
@ -656,10 +252,6 @@ class DeviceCachingAllocator {
size >
allowed_memory_maximum) {
return false;
} else if (AcceleratorAllocatorConfig::use_expandable_segments()) {
p.block =
try_allocate_expandable_block(device, p.queue(), p.pool, p.size());
return bool(p.block);
}
void* ptr = sycl::aligned_alloc_device(
kDeviceAlignment,
@ -673,7 +265,6 @@ class DeviceCachingAllocator {
for_each_selected_stat_type(p.stat_types, [&](size_t stat_type) {
stats.reserved_bytes[stat_type].increase(size);
});
TORCH_INTERNAL_ASSERT(p.block != nullptr && p.block->ptr != nullptr);
return true;
}
@ -692,27 +283,6 @@ class DeviceCachingAllocator {
xpu_events.clear();
}
void release_expandable_segment(Block* block) {
// See Note [Safe to Free Blocks on BlockPool], additional synchronization
// is unnecessary here because this function is only called by
// release_cached_blocks().
TORCH_INTERNAL_ASSERT(
block->size == block->expandable_segment->size(),
"block disagrees with segment");
TORCH_INTERNAL_ASSERT(!block->mapped);
auto it = std::find(
expandable_segments.begin(),
expandable_segments.end(),
block->expandable_segment);
TORCH_INTERNAL_ASSERT(it != expandable_segments.end());
expandable_segments.erase(it);
block->pool->unmapped.erase(block);
delete block->expandable_segment;
delete block;
}
void release_block(Block* block) {
/*
* Note [Safe to Free Blocks on BlockPool]
@ -723,7 +293,6 @@ class DeviceCachingAllocator {
* We have to do a device-level synchronization before free these blocks to
* guarantee that all kernels can access to the blocks have finished.
*/
TORCH_INTERNAL_ASSERT(!block->expandable_segment);
sycl::free(block->ptr, xpu::get_device_context());
auto* pool = block->pool;
pool->blocks.erase(block);
@ -736,78 +305,13 @@ class DeviceCachingAllocator {
delete block;
}
void unmap_block(Block* block) {
auto unmapped =
block->expandable_segment->unmap(SegmentRange{block->ptr, block->size});
if (unmapped.size == 0) {
return;
}
block->pool->blocks.erase(block);
ptrdiff_t before_size = unmapped.ptr - static_cast<char*>(block->ptr);
if (before_size > 0) {
// If the actual unmapped region starts after block->ptr due to alignment,
// the region before unmapped.ptr is still mapped.
// [Prev Block?] -> [Before Block] -> [Unmapped Block]
Block* before_free = new Block(
block->device, block->queue, before_size, block->pool, block->ptr);
before_free->expandable_segment = block->expandable_segment;
before_free->splice(block->prev, block);
block->pool->blocks.insert(before_free);
}
auto after_size = block->size - (before_size + unmapped.size);
if (after_size > 0) {
// If the actual unmapped region ends before block->ptr + block->size,
// the region after (unmapped.ptr + unmapped.size) is still mapped.
// [Unmapped Block] -> [After Block] -> [Next Block?]
Block* after_free = new Block(
block->device,
block->queue,
after_size,
block->pool,
unmapped.ptr + unmapped.size);
after_free->expandable_segment = block->expandable_segment;
after_free->splice(block, block->next);
block->pool->blocks.insert(after_free);
}
// [Before Mapped Block?] -> [Unmapped Block] -> [After Mapped Block?]
block->ptr = unmapped.ptr;
block->size = unmapped.size;
block->mapped = false;
try_merge_blocks(block, block->prev, *block->pool);
try_merge_blocks(block, block->next, *block->pool);
block->pool->unmapped.insert(block);
StatTypes stat_types = get_stat_types_for_pool(*block->pool);
for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
stats.reserved_bytes[stat_type].decrease(unmapped.size);
});
}
void release_blocks(BlockPool& pool) {
std::vector<Block*> to_unmap;
// Frees all non-split blocks in the given pool.
auto it = pool.blocks.begin();
while (it != pool.blocks.end()) {
Block* block = *it;
++it;
if (block->expandable_segment) {
// unmap_block() modifies the free pool, so collect items to free first
// to avoid iterator invalidation.
to_unmap.push_back(block);
} else if (!block->prev && !block->next) {
release_block(block);
}
}
for (Block* block : to_unmap) {
unmap_block(block);
// After unmap_block(), expandable segment blocks with no neighbors are
// also released.
if (!block->prev && !block->next) {
release_expandable_segment(block);
release_block(block);
}
}
}
@ -824,8 +328,7 @@ class DeviceCachingAllocator {
bool should_split(const Block* block, size_t size) {
size_t remaining = block->size - size;
if (block->pool->is_small ||
AcceleratorAllocatorConfig::use_expandable_segments()) {
if (block->pool->is_small) {
return remaining >= kMinBlockSize;
} else {
return remaining > kSmallSize;
@ -858,7 +361,6 @@ class DeviceCachingAllocator {
remaining = block;
block = new Block(device, queue, size, pool, block->ptr);
block->expandable_segment = remaining->expandable_segment;
block->prev = remaining->prev;
if (block->prev) {
block->prev->next = block;
@ -1097,15 +599,6 @@ class XPUAllocator : public DeviceAllocator {
return block;
}
void assertValidDevice(DeviceIndex device) {
const auto device_num = device_allocators.size();
TORCH_CHECK(
0 <= device && device < static_cast<int64_t>(device_num),
"Invalid device argument ",
device,
": did you call init?");
}
public:
std::vector<std::unique_ptr<DeviceCachingAllocator>> device_allocators;
@ -1218,6 +711,15 @@ class XPUAllocator : public DeviceAllocator {
xpu::getCurrentXPUStream().queue().memcpy(dest, src, count);
}
void assertValidDevice(DeviceIndex device) {
const auto device_num = device_allocators.size();
TORCH_CHECK(
0 <= device && device < static_cast<int64_t>(device_num),
"Invalid device argument ",
device,
": did you call init?");
}
DeviceStats getDeviceStats(DeviceIndex device) override {
assertValidDevice(device);
return device_allocators[device]->getStats();
@ -1233,13 +735,6 @@ class XPUAllocator : public DeviceAllocator {
device_allocators[device]->resetAccumulatedStats();
}
void enablePeerAccess(c10::DeviceIndex dev, c10::DeviceIndex dev_to_access) {
assertValidDevice(dev);
assertValidDevice(dev_to_access);
c10::xpu::get_raw_device(dev).ext_oneapi_enable_peer_access(
c10::xpu::get_raw_device(dev_to_access));
}
double getMemoryFraction(DeviceIndex device) {
assertValidDevice(device);
return device_allocators[device]->getMemoryFraction();
@ -1298,10 +793,6 @@ void recordStream(const DataPtr& dataPtr, XPUStream stream) {
return allocator.recordStream(dataPtr, stream);
}
void enablePeerAccess(c10::DeviceIndex dev, c10::DeviceIndex dev_to_access) {
return allocator.enablePeerAccess(dev, dev_to_access);
}
double getMemoryFraction(DeviceIndex device) {
return allocator.getMemoryFraction(device);
}

View File

@ -25,10 +25,6 @@ C10_XPU_API void raw_delete(void* ptr);
C10_XPU_API void recordStream(const DataPtr& dataPtr, XPUStream stream);
C10_XPU_API void enablePeerAccess(
c10::DeviceIndex dev,
c10::DeviceIndex dev_to_access);
C10_XPU_API double getMemoryFraction(DeviceIndex device);
C10_XPU_API void setMemoryFraction(double fraction, DeviceIndex device);

View File

@ -206,43 +206,19 @@ templates_path = [
os.path.join(os.path.dirname(pytorch_sphinx_theme2.__file__), "templates"),
]
# TODO: document these and remove them from here.
# Fixes the duplicated
autosummary_filename_map = {
"torch.nn.utils.prune.identity": "torch.nn.utils.prune.identity_function",
"torch.nn.utils.prune.Identity": "torch.nn.utils.prune.Identity_class",
"torch.optim.adamw.adamw": "torch.optim.adamw.adamw_function",
"torch.optim.adamw.AdamW": "torch.optim.adamw.AdamW_class",
"torch.optim.asgd.asgd": "torch.optim.asgd.asgd_function",
"torch.optim.asgd.ASGD": "torch.optim.asgd.ASGD_class",
"torch.optim.nadam.nadam": "torch.optim.nadam.nadam_function",
"torch.optim.nadam.NAdam": "torch.optim.nadam.NAdam_class",
"torch.optim.radam.radam": "torch.optim.radam.radam_function",
"torch.optim.radam.RAdam": "torch.optim.radam.RAdam_class",
"torch.optim.rmsprop.rmsprop": "torch.optim.rmsprop.rmsprop_function",
"torch.optim.rmsprop.RMSprop": "torch.optim.rmsprop.RMSprop_class",
"torch.optim.rprop.rprop": "torch.optim.rprop.rprop_function",
"torch.optim.rprop.Rprop": "torch.optim.rprop.Rprop_class",
"torch.optim.sgd.sgd": "torch.optim.sgd.sgd_function",
"torch.optim.sgd.SGD": "torch.optim.sgd.SGD_class",
"torch.optim.adadelta.adadelta": "torch.optim.adadelta.adadelta_function",
"torch.optim.adadelta.Adadelta": "torch.optim.adadelta.Adadelta_class",
"torch.optim.adagrad.adagrad": "torch.optim.adagrad.adagrad_function",
"torch.optim.adagrad.Adagrad": "torch.optim.adagrad.Adagrad_class",
"torch.optim.adam.adam": "torch.optim.adam.adam_function",
"torch.optim.adam.Adam": "torch.optim.adam.Adam_class",
"torch.optim.adamax.adamax": "torch.optim.adamax.adamax_function",
"torch.optim.adamax.Adamax": "torch.optim.adamax.Adamax_class",
"torch.mtia.stream": "torch.mtia.stream_function",
"torch.mtia.Stream": "torch.mtia.Stream_class",
"torch.cpu.stream": "torch.cpu.stream_function",
"torch.cpu.Stream": "torch.cpu.Stream_class",
"torch.cuda.stream": "torch.cuda.stream_function",
"torch.cuda.Stream": "torch.cuda.Stream_class",
"torch.xpu.stream": "torch.xpu.stream_function",
"torch.xpu.Stream": "torch.xpu.Stream_class",
}
coverage_ignore_functions = [
# torch
"typename",
# torch.cuda._sanitizer
"zip_arguments",
"zip_by_key",
# torch.distributed.autograd
"is_available",
# torch.distributed.checkpoint.state_dict
"gc_context",
# torch.distributed.elastic.events
"record_rdzv_event",
# torch.distributed.elastic.metrics
"initialize_metrics",
# torch.distributed.elastic.rendezvous.registry
@ -3219,11 +3195,6 @@ autodoc_type_aliases = {
# Enable overriding of function signatures in the first line of the docstring.
autodoc_docstring_signature = True
# Exclude inherited IntEnum methods that have RST formatting issues in their docstrings
autodoc_default_options = {
"exclude-members": "from_bytes, to_bytes",
}
# -- katex javascript in header
#
# def setup(app):

View File

@ -253,6 +253,7 @@ regular full-precision tensor.
.. autosummary::
:toctree: generated
:nosignatures:
:template: classtemplate.rst
view
as_strided

View File

@ -29,7 +29,6 @@ project-excludes = [
"torch/_inductor/runtime/triton_heuristics.py",
"torch/_inductor/runtime/triton_helpers.py",
"torch/_inductor/runtime/halide_helpers.py",
"torch/utils/tensorboard/summary.py",
# formatting issues, will turn on after adjusting where suppressions can be
# in import statements
"tools/flight_recorder/components/types.py",
@ -47,7 +46,6 @@ project-excludes = [
"torch/distributed/elastic/metrics/__init__.py",
"torch/_inductor/fx_passes/bucketing.py",
# ====
"torch/onnx/_internal/exporter/_torchlib/ops/nn.py",
"torch/include/**",
"torch/csrc/**",
"torch/distributed/elastic/agent/server/api.py",

View File

@ -630,37 +630,6 @@ def mirror_files_into_torchgen() -> None:
raise RuntimeError("Check the file paths in `mirror_files_into_torchgen()`")
def mirror_inductor_external_kernels() -> None:
"""
Copy external kernels into Inductor so they are importable.
"""
paths = [
(
CWD / "torch/_inductor/kernel/vendored_templates/cutedsl_grouped_gemm.py",
CWD
/ "third_party/cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py",
),
]
for new_path, orig_path in paths:
# Create the dirs involved in new_path if they don't exist
if not new_path.exists():
new_path.parent.mkdir(parents=True, exist_ok=True)
# Copy the files from the orig location to the new location
if orig_path.is_file():
shutil.copyfile(orig_path, new_path)
continue
if orig_path.is_dir():
if new_path.exists():
# copytree fails if the tree exists already, so remove it.
shutil.rmtree(new_path)
shutil.copytree(orig_path, new_path)
continue
raise RuntimeError(
"Check the file paths in `mirror_inductor_external_kernels()`"
)
# ATTENTION: THIS IS AI SLOP
def extract_variant_from_version(version: str) -> str:
"""Extract variant from version string, defaulting to 'cpu'."""
@ -1647,8 +1616,6 @@ def main() -> None:
if RUN_BUILD_DEPS:
build_deps()
mirror_inductor_external_kernels()
(
ext_modules,
cmdclass,
@ -1682,7 +1649,6 @@ def main() -> None:
"_inductor/codegen/aoti_runtime/*.cpp",
"_inductor/script.ld",
"_inductor/kernel/flex/templates/*.jinja",
"_inductor/kernel/templates/*.jinja",
"_export/serde/*.yaml",
"_export/serde/*.thrift",
"share/cmake/ATen/*.cmake",

View File

@ -246,7 +246,7 @@ class TestSDPA(NNTestCase):
max_k,
philox_seed,
philox_offset,
_debug_attn_mask,
debug_attn_mask,
) = torch.ops.aten._scaled_dot_product_fused_attention_overrideable(
q_privateuse1, k_privateuse1, v_privateuse1, attn_bias=attn_mask_privateuse1
)
@ -256,7 +256,7 @@ class TestSDPA(NNTestCase):
)
rand_upward_privateuse1 = rand_upward.to("openreg")
grad_input_mask = [True, True, True, True]
_grad_q, _grad_k, _grad_v, _grad_attn_mask = (
grad_q, grad_k, grad_v, grad_attn_mask = (
torch.ops.aten._scaled_dot_product_fused_attention_overrideable_backward(
rand_upward_privateuse1,
q_privateuse1,

View File

@ -392,11 +392,11 @@ class ComposabilityTest(MultiProcessTestCase):
replicate_size = self.world_size // (pp_size)
device_mesh = init_device_mesh(
device_type,
mesh_shape=(replicate_size, pp_size),
mesh_dim_names=("replicate", "pp"),
mesh_shape=(replicate_size, 1, pp_size),
mesh_dim_names=("replicate", "shard", "pp"),
)
torch.manual_seed(42)
dp_mesh = device_mesh["replicate"]
dp_mesh = device_mesh["replicate", "shard"]
pp_mesh = device_mesh["pp"]
pp_group = device_mesh["pp"].get_group()
@ -416,13 +416,15 @@ class ComposabilityTest(MultiProcessTestCase):
param_dtype=MixedPrecisionParam,
reduce_dtype=torch.float32,
)
replicate_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
replicate_config = {"mp_policy": mp_policy}
for layer_id in range(len(partial_model)):
replicate(
partial_model[layer_id],
device_mesh=dp_mesh,
**replicate_config,
reshard_after_forward=False,
)
dp_model = replicate(partial_model, **replicate_config)
dp_model = replicate(partial_model, device_mesh=dp_mesh, **replicate_config)
return dp_model
# Apply same precision to reference model (without replicate)
@ -580,11 +582,11 @@ class ComposabilityTest(MultiProcessTestCase):
replicate_size = self.world_size // (pp_size)
device_mesh = init_device_mesh(
device_type,
mesh_shape=(replicate_size, pp_size),
mesh_dim_names=("replicate", "pp"),
mesh_shape=(replicate_size, 1, pp_size),
mesh_dim_names=("replicate", "shard", "pp"),
)
torch.manual_seed(42)
dp_mesh = device_mesh["replicate"]
dp_mesh = device_mesh["replicate", "shard"]
pp_mesh = device_mesh["pp"]
pp_group = device_mesh["pp"].get_group()
dp_group = device_mesh["replicate"].get_group()
@ -646,9 +648,10 @@ class ComposabilityTest(MultiProcessTestCase):
for layer_id in range(len(partial_model)):
replicate(
partial_model[layer_id],
mesh=dp_mesh,
device_mesh=dp_mesh,
reshard_after_forward=False,
)
dp_model = replicate(partial_model, mesh=dp_mesh)
dp_model = replicate(partial_model, device_mesh=dp_mesh)
return dp_model
def pipelined_models_parameters(start_layer, model):

View File

@ -3,7 +3,7 @@
import copy
import dataclasses
import functools
from typing import Optional
from typing import Optional, Union
import torch
import torch.distributed as dist
@ -14,6 +14,7 @@ from torch.distributed.fsdp import MixedPrecisionPolicy
from torch.distributed.fsdp._fully_shard._fsdp_collectives import (
_get_gradient_divide_factors,
)
from torch.distributed.tensor import Shard
from torch.testing._internal.common_distributed import (
requires_nccl_version,
SaveForwardInputsModel,
@ -45,20 +46,35 @@ class TestReplicateMixedPrecisionTraining(FSDPTest):
def _init_models_and_optims(
self,
reshard_after_forward: Union[bool, int],
param_dtype: Optional[torch.dtype],
reduce_dtype: Optional[torch.dtype],
use_shard_placement_fn,
):
torch.manual_seed(42)
model = nn.Sequential(*[MLP(16, torch.device("cpu")) for _ in range(3)])
ref_model = copy.deepcopy(model).to(device_type)
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
def _shard_placement_fn(param: nn.Parameter) -> Optional[Shard]:
largest_dim = -1
largest_dim_size = -1
for dim, dim_size in enumerate(param.shape):
if dim_size > largest_dim_size:
largest_dim = dim
largest_dim_size = dim_size
assert largest_dim >= 0, f"{param.shape}"
return Shard(largest_dim)
mp_policy = MixedPrecisionPolicy(
param_dtype=param_dtype, reduce_dtype=reduce_dtype
)
shard_placement_fn = _shard_placement_fn if use_shard_placement_fn else None
replicate_fn = functools.partial(
replicate,
reshard_after_forward=reshard_after_forward,
mp_policy=mp_policy,
shard_placement_fn=shard_placement_fn,
)
for mlp in model:
replicate_fn(mlp)
@ -66,13 +82,27 @@ class TestReplicateMixedPrecisionTraining(FSDPTest):
optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=True)
return ref_model, ref_optim, model, optim
def _get_use_shard_placement_fn_vals_for_bf16_reduce(self):
use_shard_placement_fn_vals = [False]
if self.world_size == 2:
# For world size >2, gradient elements get reduced in different
# orders for the baseline vs. dim-1 sharding, leading to numeric
# differences for bf16 reduction, so only test world size 2.
use_shard_placement_fn_vals.append(True)
return use_shard_placement_fn_vals
@skipIfRocmVersionLessThan((7, 0))
@skip_if_lt_x_gpu(2)
@requires_nccl_version((2, 10), "Need NCCL 2.10+ for bf16 collectives")
def test_compute_dtype(self):
use_shard_placement_fn_vals = (
self._get_use_shard_placement_fn_vals_for_bf16_reduce()
)
self.run_subtests(
{
"param_dtype": [torch.bfloat16, torch.float16],
"reshard_after_forward": [False, True],
"use_shard_placement_fn": use_shard_placement_fn_vals,
},
self._test_compute_dtype,
)
@ -80,10 +110,14 @@ class TestReplicateMixedPrecisionTraining(FSDPTest):
def _test_compute_dtype(
self,
param_dtype: torch.dtype,
reshard_after_forward: Union[bool, int],
use_shard_placement_fn: bool,
):
ref_model, ref_optim, model, optim = self._init_models_and_optims(
reshard_after_forward,
param_dtype=param_dtype,
reduce_dtype=None,
use_shard_placement_fn=use_shard_placement_fn,
)
ref_model_bf16 = copy.deepcopy(ref_model).to(param_dtype)
orig_reduce_scatter = dist.reduce_scatter_tensor
@ -141,14 +175,39 @@ class TestReplicateMixedPrecisionTraining(FSDPTest):
@skip_if_lt_x_gpu(2)
@requires_nccl_version((2, 10), "Need NCCL 2.10+ for bf16 collectives")
def test_reduce_dtype(self):
self._test_reduce_dtype_fp32_reduce()
self._test_reduce_dtype_bf16_reduce()
self.run_subtests(
{
"reshard_after_forward": [False, True],
"use_shard_placement_fn": [False, True],
},
self._test_reduce_dtype_fp32_reduce,
)
use_shard_placement_fn_vals = (
self._get_use_shard_placement_fn_vals_for_bf16_reduce()
)
self.run_subtests(
{
"reshard_after_forward": [False, True],
"use_shard_placement_fn": use_shard_placement_fn_vals,
},
self._test_reduce_dtype_bf16_reduce,
)
def _test_reduce_dtype_fp32_reduce(self):
def _test_reduce_dtype_fp32_reduce(
self, reshard_after_forward: Union[bool, int], use_shard_placement_fn: bool
):
if (
self.world_size > 2
and isinstance(reshard_after_forward, int)
and use_shard_placement_fn
):
return
param_dtype, reduce_dtype = torch.bfloat16, torch.float32
ref_model, ref_optim, model, optim = self._init_models_and_optims(
reshard_after_forward,
param_dtype=param_dtype,
reduce_dtype=reduce_dtype,
use_shard_placement_fn=use_shard_placement_fn,
)
ref_model_bf16 = copy.deepcopy(ref_model).to(param_dtype)
orig_reduce_scatter = dist.reduce_scatter_tensor
@ -190,12 +249,14 @@ class TestReplicateMixedPrecisionTraining(FSDPTest):
check_sharded_parity(self, ref_model, model)
def _test_reduce_dtype_bf16_reduce(
self,
self, reshard_after_forward: Union[bool, int], use_shard_placement_fn: bool
):
param_dtype, reduce_dtype = torch.float32, torch.bfloat16
ref_model, ref_optim, model, optim = self._init_models_and_optims(
reshard_after_forward,
param_dtype=param_dtype,
reduce_dtype=reduce_dtype,
use_shard_placement_fn=use_shard_placement_fn,
)
group = dist.distributed_c10d._get_default_group()
orig_reduce_scatter = dist.reduce_scatter_tensor
@ -260,8 +321,12 @@ class TestReplicateMixedPrecisionTraining(FSDPTest):
ref_model_compute = copy.deepcopy(ref_model).to(param_dtype)
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
for mlp in model:
replicate(mlp, mp_policy=mp_policy)
replicate(model, mp_policy=mp_policy)
replicate(
mlp, reshard_after_forward=reshard_after_forward, mp_policy=mp_policy
)
replicate(
model, reshard_after_forward=reshard_after_forward, mp_policy=mp_policy
)
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
orig_reduce_scatter = dist.reduce_scatter_tensor

View File

@ -108,70 +108,84 @@ class TestReplicateRegisteredParams(FSDPTestMultiThread):
"""Tests the parameter registration after forward."""
device = torch.device(device_type.type, 0)
# Single Replicate group
torch.manual_seed(42)
model = MLP(3, device)
# Since seed is per process, not per thread, we broadcast to ensure
# the same parameters across ranks
for param in model.parameters():
dist.broadcast(param, src=0)
ref_model = copy.deepcopy(model)
replicate(model) # root only
inp = torch.randn((2, 3), device=device_type.type)
self._assert_dtensor_params(model.parameters())
self._assert_same_params(model.parameters(), ref_model.parameters())
model(inp)
self._assert_tensor_params(model.parameters())
self._assert_same_params(model.parameters(), ref_model.parameters())
model.reshard() # however, we can manually reshard
self._assert_dtensor_params(model.parameters())
self._assert_same_params(model.parameters(), ref_model.parameters())
for reshard_after_forward in (True, False, None):
torch.manual_seed(42)
model = MLP(3, device)
# Since seed is per process, not per thread, we broadcast to ensure
# the same parameters across ranks
for param in model.parameters():
dist.broadcast(param, src=0)
ref_model = copy.deepcopy(model)
replicate(model, reshard_after_forward=reshard_after_forward) # root only
inp = torch.randn((2, 3), device=device_type.type)
self._assert_dtensor_params(model.parameters())
self._assert_same_params(model.parameters(), ref_model.parameters())
model(inp)
if reshard_after_forward:
self._assert_dtensor_params(model.parameters())
else:
self._assert_tensor_params(model.parameters())
self._assert_same_params(model.parameters(), ref_model.parameters())
model.reshard() # however, we can manually reshard
self._assert_dtensor_params(model.parameters())
self._assert_same_params(model.parameters(), ref_model.parameters())
# Multiple Replicate groups
torch.manual_seed(42)
model = nn.Sequential(MLP(3, device), MLP(3, device))
for param in model.parameters():
dist.broadcast(param, src=0)
ref_model = copy.deepcopy(model)
replicate(model[0].in_proj)
replicate(model[0].out_proj)
replicate(model)
for reshard_after_forward in (True, False, None):
torch.manual_seed(42)
model = nn.Sequential(MLP(3, device), MLP(3, device))
for param in model.parameters():
dist.broadcast(param, src=0)
ref_model = copy.deepcopy(model)
replicate(model[0].in_proj, reshard_after_forward=reshard_after_forward)
replicate(model[0].out_proj, reshard_after_forward=reshard_after_forward)
replicate(model, reshard_after_forward=reshard_after_forward)
self._assert_dtensor_params(model.parameters())
self._assert_same_params(model.parameters(), ref_model.parameters())
model(inp)
non_root_params = list(model[0].in_proj.parameters()) + list(
model[0].out_proj.parameters()
)
root_params = list(set(model.parameters()) - set(non_root_params))
self._assert_tensor_params(non_root_params)
self._assert_tensor_params(root_params)
self._assert_same_params(model.parameters(), ref_model.parameters())
for module in model.modules():
if isinstance(module, FSDPModule):
module.reshard() # however, we can manually reshard
self._assert_dtensor_params(model.parameters())
self._assert_same_params(model.parameters(), ref_model.parameters())
self._assert_dtensor_params(model.parameters())
self._assert_same_params(model.parameters(), ref_model.parameters())
model(inp)
non_root_params = list(model[0].in_proj.parameters()) + list(
model[0].out_proj.parameters()
)
root_params = list(set(model.parameters()) - set(non_root_params))
if reshard_after_forward is None:
self._assert_dtensor_params(non_root_params)
self._assert_tensor_params(root_params)
elif reshard_after_forward:
self._assert_dtensor_params(non_root_params)
self._assert_dtensor_params(root_params)
else:
self._assert_tensor_params(non_root_params)
self._assert_tensor_params(root_params)
self._assert_same_params(model.parameters(), ref_model.parameters())
for module in model.modules():
if isinstance(module, FSDPModule):
module.reshard() # however, we can manually reshard
self._assert_dtensor_params(model.parameters())
self._assert_same_params(model.parameters(), ref_model.parameters())
@skip_if_lt_x_gpu(1)
def test_param_registration_after_backward(self):
"""Tests the parameter registration after backward."""
device = torch.device(device_type.type, 0)
# Single Replicate group
model = MLP(8, device)
replicate(model) # root only
inp = torch.randn((2, 8), device=device_type.type)
self._assert_dtensor_params(model.parameters())
model(inp).sum().backward()
self._assert_dtensor_params(model.parameters())
for reshard_after_forward in (True, False):
model = MLP(8, device)
replicate(model, reshard_after_forward=reshard_after_forward) # root only
inp = torch.randn((2, 8), device=device_type.type)
self._assert_dtensor_params(model.parameters())
model(inp).sum().backward()
self._assert_dtensor_params(model.parameters())
# Multiple Replicate groups
model = MLP(8, device)
replicate(model.in_proj)
replicate(model.out_proj)
replicate(model)
self._assert_dtensor_params(model.parameters())
model(inp).sum().backward()
self._assert_dtensor_params(model.parameters())
for reshard_after_forward in (True, False):
model = MLP(8, device)
replicate(model.in_proj, reshard_after_forward=reshard_after_forward)
replicate(model.out_proj, reshard_after_forward=reshard_after_forward)
replicate(model, reshard_after_forward=reshard_after_forward)
self._assert_dtensor_params(model.parameters())
model(inp).sum().backward()
self._assert_dtensor_params(model.parameters())
def _assert_tensor_params(self, params: Iterable[nn.Parameter]):
# need to iterate over the list multiple times
@ -273,11 +287,14 @@ class TestReplicate1DTrainingCore(FSDPTest):
[(7, 15), (15, 3)],
[(16, 17), (17, 8)],
],
"use_shard_placement_fn": [False],
},
self._test_train_parity_single_group,
)
def _test_train_parity_single_group(self, lin_shapes: list[tuple[int, int]]):
def _test_train_parity_single_group(
self, lin_shapes: list[tuple[int, int]], use_shard_placement_fn: bool
):
torch.manual_seed(42)
model = nn.Sequential(
nn.Linear(*lin_shapes[0]), nn.ReLU(), nn.Linear(*lin_shapes[1])
@ -316,6 +333,7 @@ class TestReplicate1DTrainingCore(FSDPTest):
"""
self.run_subtests(
{
"reshard_after_forward": [True, False],
"test_device_type": [device_type.type],
"offload_policy": [OffloadPolicy()],
"delay_after_forward": [False, True],
@ -336,6 +354,7 @@ class TestReplicate1DTrainingCore(FSDPTest):
"""
self.run_subtests(
{
"reshard_after_forward": [True], # save CI time
"offload_policy": [
CPUOffloadPolicy(pin_memory=True),
CPUOffloadPolicy(pin_memory=False),
@ -352,6 +371,7 @@ class TestReplicate1DTrainingCore(FSDPTest):
def _test_train_parity_multi_group(
self,
reshard_after_forward: Union[bool, int],
offload_policy: OffloadPolicy,
test_device_type: str,
delay_after_forward: bool,
@ -385,12 +405,13 @@ class TestReplicate1DTrainingCore(FSDPTest):
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
mesh = init_device_mesh(
test_device_type,
(self.world_size,),
mesh_dim_names=("replicate",),
(self.world_size, 1),
mesh_dim_names=("replicate", "shard"),
)
fully_shard_fn = functools.partial(
replicate,
mesh=mesh,
device_mesh=mesh,
reshard_after_forward=reshard_after_forward,
offload_policy=offload_policy,
)
for module in model.modules():
@ -506,10 +527,12 @@ class TestReplicate1DTrainingCore(FSDPTest):
Tests parity when running a module that participates multiple
times in forward.
"""
self.run_subtests(
{"reshard_after_forward": [True, False]},
self._test_multi_forward_module,
)
self._test_multi_forward_module()
def _test_multi_forward_module(self):
def _test_multi_forward_module(self, reshard_after_forward: Union[bool, int]):
class MultiForwardModule(nn.Module):
def __init__(self, device: torch.device):
super().__init__()
@ -664,6 +687,7 @@ class TestReplicateTrainingCompose(FSDPTest):
"""
self.run_subtests(
{
"reshard_after_forward": [True, False],
"checkpoint_impl": ["composable", "utils", "wrapper"],
"module_grouping": ["block", "mem_eff", "mem_eff_weight_tied"],
"test_device_type": [device_type.type],
@ -673,6 +697,7 @@ class TestReplicateTrainingCompose(FSDPTest):
def _test_train_parity_with_activation_checkpointing(
self,
reshard_after_forward: Union[bool, int],
checkpoint_impl: str,
module_grouping: str,
test_device_type: str,
@ -715,11 +740,12 @@ class TestReplicateTrainingCompose(FSDPTest):
# Apply Replicate
device_mesh = init_device_mesh(
test_device_type,
(self.world_size,),
mesh_dim_names=("replicate",),
(self.world_size, 1),
mesh_dim_names=("replicate", "shard"),
)
fsdp_kwargs = {
"mesh": device_mesh,
"reshard_after_forward": reshard_after_forward,
"device_mesh": device_mesh,
}
if module_grouping == "mem_eff":
assert model_args.n_layers == 3
@ -783,6 +809,7 @@ class TestReplicateSharedParams(FSDPTest):
def test_train_parity_with_shared_params(self):
self.run_subtests(
{
"reshard_after_forward": [False, True],
"use_activation_checkpointing": [False, True],
},
self._test_train_shared_params,
@ -790,6 +817,7 @@ class TestReplicateSharedParams(FSDPTest):
def _test_train_shared_params(
self,
reshard_after_forward: bool,
use_activation_checkpointing: bool,
):
torch.manual_seed(42)
@ -802,8 +830,8 @@ class TestReplicateSharedParams(FSDPTest):
if isinstance(module, TransformerBlock):
if use_activation_checkpointing:
checkpoint(module)
replicate(module)
replicate(model)
replicate(module, reshard_after_forward=reshard_after_forward)
replicate(model, reshard_after_forward=reshard_after_forward)
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
torch.manual_seed(42 + self.rank + 1)
@ -840,11 +868,11 @@ class TestReplicateGradientAccumulation(FSDPTest):
with/without resharding after backward.
"""
replicate_size = self.world_size
shard_size, replicate_size = 1, self.world_size
meshes = init_device_mesh(
device_type.type,
(replicate_size,),
mesh_dim_names=("replicate",),
(replicate_size, shard_size),
mesh_dim_names=("replicate", "shard"),
)
self.run_subtests(
{
@ -900,7 +928,8 @@ class TestReplicateGradientAccumulation(FSDPTest):
ref_model = copy.deepcopy(model).to(device_type)
replicate_fn = functools.partial(
replicate,
mesh=mesh,
device_mesh=mesh,
reshard_after_forward=reshard_after_forward,
offload_policy=offload_policy,
)
for mlp in model[1:]:
@ -1011,8 +1040,8 @@ class TestReplicateGradientAccumulation(FSDPTest):
ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2)
for module in model.modules():
if isinstance(module, TransformerBlock):
replicate(module)
replicate(model)
replicate(module, reshard_after_forward=False)
replicate(model, reshard_after_forward=False)
optim = torch.optim.AdamW(model.parameters(), lr=1e-2)
num_microbatches = 3
@ -1116,8 +1145,8 @@ class TestReplicateTPTraining(FSDPTest):
def init_global_mesh(self) -> DeviceMesh:
return init_device_mesh(
device_type.type,
(2, 2),
mesh_dim_names=("dp_replicate", "tp"),
(2, 1, 2),
mesh_dim_names=("dp_replicate", "dp_shard", "tp"),
)
@skip_if_lt_x_gpu(8)
@ -1125,6 +1154,7 @@ class TestReplicateTPTraining(FSDPTest):
global_mesh = self.init_global_mesh()
self.run_subtests(
{
"reshard_after_forward": [False, True],
"use_activation_checkpointing": [False, True],
"mlp_dim": [3, 5, 16, 17],
"foreach": [False],
@ -1135,11 +1165,12 @@ class TestReplicateTPTraining(FSDPTest):
def _test_replicate_tp(
self,
global_mesh: DeviceMesh,
reshard_after_forward: bool,
use_activation_checkpointing: bool,
mlp_dim: int,
foreach: bool,
):
dp_mesh, tp_mesh = global_mesh["dp_replicate"], global_mesh["tp"]
dp_mesh, tp_mesh = global_mesh["dp_replicate", "dp_shard"], global_mesh["tp"]
dp_pg = dp_mesh._flatten().get_group() # used for `replicate()`
torch.manual_seed(42)
@ -1166,8 +1197,8 @@ class TestReplicateTPTraining(FSDPTest):
continue
if use_activation_checkpointing:
checkpoint(module)
replicate(module, mesh=dp_mesh)
replicate(model, mesh=dp_mesh)
replicate(module, device_mesh=dp_mesh)
replicate(model, device_mesh=dp_mesh)
# Checking parameters match orig model is critical to validate .full_tensor correctly replicates the
# strided-sharded layers.
@ -1198,9 +1229,11 @@ class TestReplicateTPTraining(FSDPTest):
for _, p in model.named_parameters():
self.assertIsInstance(p, DTensor)
self.assertEqual(p.device_mesh.ndim, 2)
self.assertEqual(len(p.placements), 2)
self.assertEqual(p.device_mesh.mesh_dim_names, ("dp_replicate", "tp"))
self.assertEqual(p.device_mesh.ndim, 3)
self.assertEqual(len(p.placements), 3)
self.assertEqual(
p.device_mesh.mesh_dim_names, ("dp_replicate", "dp_shard", "tp")
)
if __name__ == "__main__":

View File

@ -120,7 +120,7 @@ class ReplicateTest(MultiProcessTestCase):
if i % 2 == 0:
self.assertTrue("replicate" in _get_registry(layer))
for parameter in layer.parameters():
self.assertEqual(parameter.placements, (Replicate(),))
self.assertEqual(parameter.placements, (Replicate(), Shard(dim=0)))
elif i % 2 == 1:
self.assertTrue("fully_shard" in _get_registry(layer))
for parameter in layer.parameters():
@ -197,14 +197,14 @@ class ReplicateTest(MultiProcessTestCase):
]
global_mesh = self.init_replicate_tp_mesh()
replicate_mesh = global_mesh["replicate"]
replicate_mesh = global_mesh["replicate", "shard"]
for layer in layers:
replicate(layer, mesh=replicate_mesh)
replicate(layer, device_mesh=replicate_mesh)
for parameter in layer.parameters():
self.assertEqual(parameter.device_mesh.shape, (2,))
self.assertEqual(parameter.placements, (Replicate(),))
self.assertEqual(parameter.device_mesh.shape, (2, 1))
self.assertEqual(parameter.placements, (Replicate(), Shard(dim=0)))
@skip_if_lt_x_gpu(2)
def test_train_replicate_fsdp(self):
@ -263,6 +263,7 @@ class ReplicateTest(MultiProcessTestCase):
run_subtests(
self,
{
"reshard_after_forward": [False, True],
"use_activation_checkpointing": [False, True],
"mlp_dim": [3, 16, 17],
},
@ -272,6 +273,7 @@ class ReplicateTest(MultiProcessTestCase):
def _test_train_parity_2d_mlp(
self,
global_mesh: DeviceMesh,
reshard_after_forward: bool,
use_activation_checkpointing: bool,
mlp_dim: int,
):
@ -285,12 +287,13 @@ class ReplicateTest(MultiProcessTestCase):
torch.manual_seed(42)
model = MLPStack(mlp_dim)
ref_model = copy.deepcopy(model).cuda()
replicate(ref_model, mesh=replicate_mesh)
replicate(ref_model, device_mesh=replicate_shard_mesh)
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=False)
model.parallelize(
tp_mesh,
replicate_shard_mesh,
use_activation_checkpointing,
reshard_after_forward=reshard_after_forward,
)
optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=False)

View File

@ -1,26 +1,16 @@
# Owner(s): ["oncall: distributed checkpointing"]
import os
import sys
from unittest.mock import patch
import torch
import torch.testing._internal.common_utils as common
from torch import distributed as dist
from torch.distributed.checkpoint._async_process_executor import (
_ProcessBasedAsyncCheckpointExecutor,
_ProcessGroupInitInfo,
)
from torch.distributed.checkpoint.api import CheckpointException
from torch.distributed.checkpoint.storage import StorageWriter
from torch.distributed.elastic.utils.distributed import get_free_port
from torch.testing._internal.common_distributed import skip_if_win32
from torch.testing._internal.common_utils import (
retry_on_connect_failures,
run_tests,
TEST_WITH_DEV_DBG_ASAN,
TestCase,
)
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
with_comms,
@ -120,184 +110,47 @@ class TestAsyncProcessExecutor(DTensorTestBase):
"epoch": 5,
}
with patch.dict(os.environ, {}, clear=False):
os.environ.pop("DCP_USE_PREFIX_STORE", None)
# 1. Simulate a failure in creating PG in background process.
with patch(
"torch.distributed.checkpoint._async_process_executor.get_free_port",
return_value=-1,
):
with self.assertRaises(ValueError) as _:
proc_executor = _ProcessBasedAsyncCheckpointExecutor()
fut = proc_executor.execute_save(
staging_future_or_state_dict=test_state_dict,
)
fut.result()
# 2. Attempt save with failing storage writer
with patch(
"torch.distributed.checkpoint._async_process_executor.get_free_port",
return_value=get_free_port(),
) as mock_get_free_port:
# 1. Simulate a failure in creating PG in background process.
with patch(
"torch.distributed.checkpoint._async_process_executor.get_free_port",
return_value=-1,
):
with self.assertRaises(ValueError) as _:
proc_executor = _ProcessBasedAsyncCheckpointExecutor()
fut = proc_executor.execute_save(
staging_future_or_state_dict=test_state_dict,
storage_writer=TestStorageWriter(behavior="fail_once"),
)
self.assertIn(
"fail_once policy triggered failure", str(fut.exception())
)
# Verify new process was created for this attempt
if dist.get_rank() == 0:
mock_get_free_port.assert_called_once()
fut.result()
# 3. Second save attempt with successful storage writer - process should still be alive
with patch(
"torch.distributed.checkpoint._async_process_executor.get_free_port",
) as mock_get_free_port:
proc_executor = _ProcessBasedAsyncCheckpointExecutor()
fut = proc_executor.execute_save(
staging_future_or_state_dict=test_state_dict,
storage_writer=TestStorageWriter(behavior="success"),
)
result = fut.result()
# Verify process is still alive
mock_get_free_port.assert_not_called()
# Verify successful save
self.assertIsNotNone(result)
# 2. Attempt save with failing storage writer
with patch(
"torch.distributed.checkpoint._async_process_executor.get_free_port",
return_value=get_free_port(),
) as mock_get_free_port:
proc_executor = _ProcessBasedAsyncCheckpointExecutor()
fut = proc_executor.execute_save(
staging_future_or_state_dict=test_state_dict,
storage_writer=TestStorageWriter(behavior="fail_once"),
)
self.assertIn("fail_once policy triggered failure", str(fut.exception()))
# Verify new process was created for this attempt
if dist.get_rank() == 0:
mock_get_free_port.assert_called_once()
class TestAsyncProcessExecutorPrefixStore(TestCase):
@skip_if_win32()
@retry_on_connect_failures
def test_checkpoint_save_with_prefix_store_enabled(self) -> None:
"""Test that checkpoint save works when DCP_USE_PREFIX_STORE is enabled."""
test_state_dict = {
"model": {"weight": torch.randn(4, 4), "bias": torch.randn(4)},
"optimizer": {"param_groups": [{"lr": 0.01}]},
"epoch": 5,
}
master_addr = "localhost"
master_port = str(common.find_free_port())
with patch.dict(
os.environ,
{
"DCP_USE_PREFIX_STORE": "1",
"MASTER_ADDR": master_addr,
"MASTER_PORT": master_port,
},
):
with patch(
"torch.distributed.checkpoint._async_process_executor.get_free_port"
) as mock_get_free_port:
dist.init_process_group(
backend=dist.Backend.GLOO,
rank=0,
world_size=1,
)
proc_executor = _ProcessBasedAsyncCheckpointExecutor()
fut = proc_executor.execute_save(
staging_future_or_state_dict=test_state_dict,
storage_writer=TestStorageWriter(behavior="success"),
)
result = fut.result()
self.assertIsNotNone(result)
mock_get_free_port.assert_not_called()
class TestProcessGroupInitInfo(DTensorTestBase):
"""Test suite for _ProcessGroupInitInfo."""
@with_comms
def test_process_group_init_info_with_default_pg(self) -> None:
"""Test that ProcessGroupInitInfo correctly initializes."""
with patch.dict(os.environ, {}, clear=False):
os.environ.pop("DCP_USE_PREFIX_STORE", None)
pg_init_info = _ProcessGroupInitInfo()
self.assertEqual(pg_init_info.global_rank, dist.get_rank())
self.assertEqual(pg_init_info.world_size, dist.get_world_size())
self.assertIsNotNone(pg_init_info.tcp_store_master_addr)
self.assertGreater(pg_init_info.tcp_store_master_port, 0)
self.assertEqual(pg_init_info.use_prefix_store, False)
@with_comms
def test_process_group_init_info_with_prefix_store_env_var(self) -> None:
"""Test that ProcessGroupInitInfo handles DCP_USE_PREFIX_STORE environment variable."""
# Flag enabled, addr/port correctly defined
with patch.dict(
os.environ,
{
"DCP_USE_PREFIX_STORE": "1",
"MASTER_ADDR": "localhost",
"MASTER_PORT": "12345",
},
):
pg_init_info = _ProcessGroupInitInfo()
self.assertTrue(pg_init_info.use_prefix_store)
# Missing port
with patch.dict(
os.environ, {"DCP_USE_PREFIX_STORE": "1", "MASTER_ADDR": "localhost"}
):
with self.assertRaises(CheckpointException):
pg_init_info = _ProcessGroupInitInfo()
# Missing addr
with patch.dict(
os.environ, {"DCP_USE_PREFIX_STORE": "1", "MASTER_PORT": "12345"}
):
with self.assertRaises(CheckpointException):
pg_init_info = _ProcessGroupInitInfo()
# Invalid port
with patch.dict(
os.environ,
{
"DCP_USE_PREFIX_STORE": "1",
"MASTER_ADDR": "localhost",
"MASTER_PORT": "a",
},
):
with self.assertRaises(CheckpointException):
pg_init_info = _ProcessGroupInitInfo()
@with_comms
def test_process_group_init_info_without_prefix_store_env_var(self) -> None:
"""Test that ProcessGroupInitInfo defaults to not using prefix store."""
# Env var set to 0
with patch.dict(os.environ, {"DCP_USE_PREFIX_STORE": "0"}):
pg_init_info = _ProcessGroupInitInfo()
self.assertFalse(pg_init_info.use_prefix_store)
# Missing env var
with patch.dict(os.environ, {}, clear=False):
os.environ.pop("DCP_USE_PREFIX_STORE", None)
pg_init_info = _ProcessGroupInitInfo()
self.assertFalse(pg_init_info.use_prefix_store)
# Invalid env var
with patch.dict(os.environ, {"DCP_USE_PREFIX_STORE": "2"}):
pg_init_info = _ProcessGroupInitInfo()
self.assertFalse(pg_init_info.use_prefix_store)
with patch.dict(os.environ, {"DCP_USE_PREFIX_STORE": "true"}):
pg_init_info = _ProcessGroupInitInfo()
self.assertFalse(pg_init_info.use_prefix_store)
with patch.dict(os.environ, {"DCP_USE_PREFIX_STORE": "false"}):
pg_init_info = _ProcessGroupInitInfo()
self.assertFalse(pg_init_info.use_prefix_store)
with patch.dict(os.environ, {"DCP_USE_PREFIX_STORE": ""}):
pg_init_info = _ProcessGroupInitInfo()
self.assertFalse(pg_init_info.use_prefix_store)
# 3. Second save attempt with successful storage writer - process should still be alive
with patch(
"torch.distributed.checkpoint._async_process_executor.get_free_port",
) as mock_get_free_port:
proc_executor = _ProcessBasedAsyncCheckpointExecutor()
fut = proc_executor.execute_save(
staging_future_or_state_dict=test_state_dict,
storage_writer=TestStorageWriter(behavior="success"),
)
result = fut.result()
# Verify process is still alive
mock_get_free_port.assert_not_called()
# Verify successful save
self.assertIsNotNone(result)
if __name__ == "__main__":

View File

@ -415,15 +415,6 @@ class TestDTensorDebugMode(TestCase):
aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])""",
)
with DebugMode(record_stack_trace=True) as debug_mode:
out = mod(inp).sum()
out.backward()
sum_op = [
op for op in debug_mode.operators if str(op.op) == "aten.sum.dim_IntList"
][-1]
self.assertTrue("self.l2(self.l1(x))" in sum_op.fwd_stack_trace)
instantiate_parametrized_tests(TestDTensorDebugMode)

View File

@ -1019,28 +1019,6 @@ class DTensorMeshTest(DTensorTestBase):
except ValueError:
self.fail("Unexpected ValueError raised with run_check=False")
@with_comms
def test_as_strided_identity(self):
# Test calling as_strided with the same size/stride/offset as input tensor
# This should be a no-op but currently fails
device_mesh = self.build_device_mesh()
placements = [Shard(0)]
local_tensor = torch.randn(3, 4, device=self.device_type)
dtensor = DTensor.from_local(local_tensor, device_mesh, placements)
# Get the current size, stride, and storage_offset
size = dtensor.size()
stride = dtensor.stride()
storage_offset = dtensor.storage_offset()
# Call as_strided with the exact same parameters
result = dtensor.as_strided(size, stride, storage_offset)
# The result should be identical to the input
self.assertEqual(result.size(), dtensor.size())
self.assertEqual(result.stride(), dtensor.stride())
self.assertEqual(result.to_local(), dtensor.to_local())
DTensorMeshTestWithLocalTensor = create_local_tensor_test_class(
DTensorMeshTest,

View File

@ -1,5 +1,5 @@
diff --git a/test/dynamo/cpython/3_13/test_heapq.py b/test/dynamo/cpython/3_13/test_heapq.py
index 1aa8e4e2897..bc177c2943e 100644
index 1aa8e4e2897..94315fa68b4 100644
--- a/test/dynamo/cpython/3_13/test_heapq.py
+++ b/test/dynamo/cpython/3_13/test_heapq.py
@@ -1,3 +1,23 @@
@ -35,7 +35,7 @@ index 1aa8e4e2897..bc177c2943e 100644
def test_py_functions(self):
for fname in func_names:
self.assertEqual(getattr(py_heapq, fname).__module__, 'heapq')
@@ -27,24 +47,12 @@ class TestModules(TestCase):
@@ -27,24 +47,7 @@ class TestModules(TestCase):
self.assertEqual(getattr(c_heapq, fname).__module__, '_heapq')
@ -46,15 +46,12 @@ index 1aa8e4e2897..bc177c2943e 100644
- # However, doctest can't easily find all docstrings in the module (loading
- # it through import_fresh_module seems to confuse it), so we specifically
- # create a finder which returns the doctests from the merge method.
+@torch._dynamo.disable
+def randrange(*args):
+ return random.randrange(*args)
-
- class HeapqMergeDocTestFinder:
- def find(self, *args, **kwargs):
- dtf = doctest.DocTestFinder()
- return dtf.find(py_heapq.merge)
-
- tests.addTests(doctest.DocTestSuite(py_heapq,
- test_finder=HeapqMergeDocTestFinder()))
- return tests
@ -64,155 +61,7 @@ index 1aa8e4e2897..bc177c2943e 100644
def test_push_pop(self):
# 1) Push 256 random numbers and pop them off, verifying all's OK.
@@ -52,7 +60,8 @@ class TestHeap:
data = []
self.check_invariant(heap)
for i in range(256):
- item = random.random()
+ with torch._dynamo.error_on_graph_break(False):
+ item = random.random()
data.append(item)
self.module.heappush(heap, item)
self.check_invariant(heap)
@@ -83,14 +92,16 @@ class TestHeap:
def test_heapify(self):
for size in list(range(30)) + [20000]:
- heap = [random.random() for dummy in range(size)]
+ with torch._dynamo.error_on_graph_break(False):
+ heap = [random.random() for dummy in range(size)]
self.module.heapify(heap)
self.check_invariant(heap)
self.assertRaises(TypeError, self.module.heapify, None)
def test_naive_nbest(self):
- data = [random.randrange(2000) for i in range(1000)]
+ with torch._dynamo.error_on_graph_break(False):
+ data = [randrange(2000) for i in range(1000)]
heap = []
for item in data:
self.module.heappush(heap, item)
@@ -113,7 +124,8 @@ class TestHeap:
# heap instead of a min heap, it could go faster still via
# heapify'ing all of data (linear time), then doing 10 heappops
# (10 log-time steps).
- data = [random.randrange(2000) for i in range(1000)]
+ with torch._dynamo.error_on_graph_break(False):
+ data = [randrange(2000) for i in range(1000)]
heap = data[:10]
self.module.heapify(heap)
for item in data[10:]:
@@ -126,7 +138,8 @@ class TestHeap:
self.assertRaises(IndexError, self.module.heapreplace, [], None)
def test_nbest_with_pushpop(self):
- data = [random.randrange(2000) for i in range(1000)]
+ with torch._dynamo.error_on_graph_break(False):
+ data = [randrange(2000) for i in range(1000)]
heap = data[:10]
self.module.heapify(heap)
for item in data[10:]:
@@ -163,8 +176,9 @@ class TestHeap:
def test_heapsort(self):
# Exercise everything with repeated heapsort checks
for trial in range(100):
- size = random.randrange(50)
- data = [random.randrange(25) for i in range(size)]
+ with torch._dynamo.error_on_graph_break(False):
+ size = randrange(50)
+ data = [randrange(25) for i in range(size)]
if trial & 1: # Half of the time, use heapify
heap = data[:]
self.module.heapify(heap)
@@ -177,12 +191,13 @@ class TestHeap:
def test_merge(self):
inputs = []
- for i in range(random.randrange(25)):
- row = []
- for j in range(random.randrange(100)):
- tup = random.choice('ABC'), random.randrange(-500, 500)
- row.append(tup)
- inputs.append(row)
+ with torch._dynamo.error_on_graph_break(False):
+ for i in range(randrange(25)):
+ row = []
+ for j in range(randrange(100)):
+ tup = random.choice('ABC'), randrange(-500, 500)
+ row.append(tup)
+ inputs.append(row)
for key in [None, itemgetter(0), itemgetter(1), itemgetter(1, 0)]:
for reverse in [False, True]:
@@ -209,12 +224,14 @@ class TestHeap:
list(self.module.merge(iterable(), iterable()))
def test_merge_stability(self):
- class Int(int):
- pass
+ with torch._dynamo.error_on_graph_break(False):
+ class Int(int):
+ pass
inputs = [[], [], [], []]
for i in range(20000):
- stream = random.randrange(4)
- x = random.randrange(500)
+ with torch._dynamo.error_on_graph_break(False):
+ stream = randrange(4)
+ x = randrange(500)
obj = Int(x)
obj.pair = (x, stream)
inputs[stream].append(obj)
@@ -224,7 +241,8 @@ class TestHeap:
self.assertEqual(result, sorted(result))
def test_nsmallest(self):
- data = [(random.randrange(2000), i) for i in range(1000)]
+ with torch._dynamo.error_on_graph_break(False):
+ data = [(randrange(2000), i) for i in range(1000)]
for f in (None, lambda x: x[0] * 547 % 2000):
for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
self.assertEqual(list(self.module.nsmallest(n, data)),
@@ -233,7 +251,8 @@ class TestHeap:
sorted(data, key=f)[:n])
def test_nlargest(self):
- data = [(random.randrange(2000), i) for i in range(1000)]
+ with torch._dynamo.error_on_graph_break(False):
+ data = [(randrange(2000), i) for i in range(1000)]
for f in (None, lambda x: x[0] * 547 % 2000):
for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
self.assertEqual(list(self.module.nlargest(n, data)),
@@ -248,28 +267,29 @@ class TestHeap:
data = [comp(x) for x in data]
self.module.heapify(data)
return [self.module.heappop(data).x for i in range(len(data))]
- class LT:
- def __init__(self, x):
- self.x = x
- def __lt__(self, other):
- return self.x > other.x
- class LE:
- def __init__(self, x):
- self.x = x
- def __le__(self, other):
- return self.x >= other.x
- data = [random.random() for i in range(100)]
+ with torch._dynamo.error_on_graph_break(False):
+ class LT:
+ def __init__(self, x):
+ self.x = x
+ def __lt__(self, other):
+ return self.x > other.x
+ class LE:
+ def __init__(self, x):
+ self.x = x
+ def __le__(self, other):
+ return self.x >= other.x
+ data = [random.random() for i in range(100)]
target = sorted(data, reverse=True)
self.assertEqual(hsort(data, LT), target)
@@ -264,12 +267,12 @@ class TestHeap:
self.assertRaises(TypeError, data, LE)
@ -227,7 +76,7 @@ index 1aa8e4e2897..bc177c2943e 100644
module = c_heapq
@@ -374,7 +394,7 @@ class SideEffectLT:
@@ -374,7 +377,7 @@ class SideEffectLT:
return self.value < other.value
@ -236,48 +85,7 @@ index 1aa8e4e2897..bc177c2943e 100644
def test_non_sequence(self):
for f in (self.module.heapify, self.module.heappop):
@@ -435,10 +455,11 @@ class TestErrorHandling:
def test_comparison_operator_modifiying_heap(self):
# See bpo-39421: Strong references need to be taken
# when comparing objects as they can alter the heap
- class EvilClass(int):
- def __lt__(self, o):
- heap.clear()
- return NotImplemented
+ with torch._dynamo.error_on_graph_break(False):
+ class EvilClass(int):
+ def __lt__(self, o):
+ heap.clear()
+ return NotImplemented
heap = []
self.module.heappush(heap, EvilClass(0))
@@ -446,15 +467,16 @@ class TestErrorHandling:
def test_comparison_operator_modifiying_heap_two_heaps(self):
- class h(int):
- def __lt__(self, o):
- list2.clear()
- return NotImplemented
+ with torch._dynamo.error_on_graph_break(False):
+ class h(int):
+ def __lt__(self, o):
+ list2.clear()
+ return NotImplemented
- class g(int):
- def __lt__(self, o):
- list1.clear()
- return NotImplemented
+ class g(int):
+ def __lt__(self, o):
+ list1.clear()
+ return NotImplemented
list1, list2 = [], []
@@ -464,13 +486,13 @@ class TestErrorHandling:
@@ -464,13 +467,13 @@ class TestErrorHandling:
self.assertRaises((IndexError, RuntimeError), self.module.heappush, list1, g(1))
self.assertRaises((IndexError, RuntimeError), self.module.heappush, list2, h(1))

View File

@ -47,11 +47,6 @@ class TestModules(__TestCase):
self.assertEqual(getattr(c_heapq, fname).__module__, '_heapq')
@torch._dynamo.disable
def randrange(*args):
return random.randrange(*args)
class _TestHeap:
def test_push_pop(self):
@ -60,8 +55,7 @@ class _TestHeap:
data = []
self.check_invariant(heap)
for i in range(256):
with torch._dynamo.error_on_graph_break(False):
item = random.random()
item = random.random()
data.append(item)
self.module.heappush(heap, item)
self.check_invariant(heap)
@ -92,16 +86,14 @@ class _TestHeap:
def test_heapify(self):
for size in list(range(30)) + [20000]:
with torch._dynamo.error_on_graph_break(False):
heap = [random.random() for dummy in range(size)]
heap = [random.random() for dummy in range(size)]
self.module.heapify(heap)
self.check_invariant(heap)
self.assertRaises(TypeError, self.module.heapify, None)
def test_naive_nbest(self):
with torch._dynamo.error_on_graph_break(False):
data = [randrange(2000) for i in range(1000)]
data = [random.randrange(2000) for i in range(1000)]
heap = []
for item in data:
self.module.heappush(heap, item)
@ -124,8 +116,7 @@ class _TestHeap:
# heap instead of a min heap, it could go faster still via
# heapify'ing all of data (linear time), then doing 10 heappops
# (10 log-time steps).
with torch._dynamo.error_on_graph_break(False):
data = [randrange(2000) for i in range(1000)]
data = [random.randrange(2000) for i in range(1000)]
heap = data[:10]
self.module.heapify(heap)
for item in data[10:]:
@ -138,8 +129,7 @@ class _TestHeap:
self.assertRaises(IndexError, self.module.heapreplace, [], None)
def test_nbest_with_pushpop(self):
with torch._dynamo.error_on_graph_break(False):
data = [randrange(2000) for i in range(1000)]
data = [random.randrange(2000) for i in range(1000)]
heap = data[:10]
self.module.heapify(heap)
for item in data[10:]:
@ -176,9 +166,8 @@ class _TestHeap:
def test_heapsort(self):
# Exercise everything with repeated heapsort checks
for trial in range(100):
with torch._dynamo.error_on_graph_break(False):
size = randrange(50)
data = [randrange(25) for i in range(size)]
size = random.randrange(50)
data = [random.randrange(25) for i in range(size)]
if trial & 1: # Half of the time, use heapify
heap = data[:]
self.module.heapify(heap)
@ -191,13 +180,12 @@ class _TestHeap:
def test_merge(self):
inputs = []
with torch._dynamo.error_on_graph_break(False):
for i in range(randrange(25)):
row = []
for j in range(randrange(100)):
tup = random.choice('ABC'), randrange(-500, 500)
row.append(tup)
inputs.append(row)
for i in range(random.randrange(25)):
row = []
for j in range(random.randrange(100)):
tup = random.choice('ABC'), random.randrange(-500, 500)
row.append(tup)
inputs.append(row)
for key in [None, itemgetter(0), itemgetter(1), itemgetter(1, 0)]:
for reverse in [False, True]:
@ -224,14 +212,12 @@ class _TestHeap:
list(self.module.merge(iterable(), iterable()))
def test_merge_stability(self):
with torch._dynamo.error_on_graph_break(False):
class Int(int):
pass
class Int(int):
pass
inputs = [[], [], [], []]
for i in range(20000):
with torch._dynamo.error_on_graph_break(False):
stream = randrange(4)
x = randrange(500)
stream = random.randrange(4)
x = random.randrange(500)
obj = Int(x)
obj.pair = (x, stream)
inputs[stream].append(obj)
@ -241,8 +227,7 @@ class _TestHeap:
self.assertEqual(result, sorted(result))
def test_nsmallest(self):
with torch._dynamo.error_on_graph_break(False):
data = [(randrange(2000), i) for i in range(1000)]
data = [(random.randrange(2000), i) for i in range(1000)]
for f in (None, lambda x: x[0] * 547 % 2000):
for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
self.assertEqual(list(self.module.nsmallest(n, data)),
@ -251,8 +236,7 @@ class _TestHeap:
sorted(data, key=f)[:n])
def test_nlargest(self):
with torch._dynamo.error_on_graph_break(False):
data = [(randrange(2000), i) for i in range(1000)]
data = [(random.randrange(2000), i) for i in range(1000)]
for f in (None, lambda x: x[0] * 547 % 2000):
for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
self.assertEqual(list(self.module.nlargest(n, data)),
@ -267,18 +251,17 @@ class _TestHeap:
data = [comp(x) for x in data]
self.module.heapify(data)
return [self.module.heappop(data).x for i in range(len(data))]
with torch._dynamo.error_on_graph_break(False):
class LT:
def __init__(self, x):
self.x = x
def __lt__(self, other):
return self.x > other.x
class LE:
def __init__(self, x):
self.x = x
def __le__(self, other):
return self.x >= other.x
data = [random.random() for i in range(100)]
class LT:
def __init__(self, x):
self.x = x
def __lt__(self, other):
return self.x > other.x
class LE:
def __init__(self, x):
self.x = x
def __le__(self, other):
return self.x >= other.x
data = [random.random() for i in range(100)]
target = sorted(data, reverse=True)
self.assertEqual(hsort(data, LT), target)
self.assertRaises(TypeError, data, LE)
@ -455,11 +438,10 @@ class _TestErrorHandling:
def test_comparison_operator_modifiying_heap(self):
# See bpo-39421: Strong references need to be taken
# when comparing objects as they can alter the heap
with torch._dynamo.error_on_graph_break(False):
class EvilClass(int):
def __lt__(self, o):
heap.clear()
return NotImplemented
class EvilClass(int):
def __lt__(self, o):
heap.clear()
return NotImplemented
heap = []
self.module.heappush(heap, EvilClass(0))
@ -467,16 +449,15 @@ class _TestErrorHandling:
def test_comparison_operator_modifiying_heap_two_heaps(self):
with torch._dynamo.error_on_graph_break(False):
class h(int):
def __lt__(self, o):
list2.clear()
return NotImplemented
class h(int):
def __lt__(self, o):
list2.clear()
return NotImplemented
class g(int):
def __lt__(self, o):
list1.clear()
return NotImplemented
class g(int):
def __lt__(self, o):
list1.clear()
return NotImplemented
list1, list2 = [], []

View File

@ -427,29 +427,17 @@ from user code:
optree.tree_flatten_with_path(d)
return torch.sin(x)
def post_munge(s):
s = re.sub(
r"optree\.\S*\.flatten_with_path",
"optree.<path>.flatten_with_path",
s,
)
return re.sub(
r"qualname: \S*flatten_with_path",
"qualname: <path>.flatten_with_path",
s,
)
fn(torch.randn(4))
self.assertEqual(len(counters["graph_break"]), 1)
first_graph_break = next(iter(counters["graph_break"].keys()))
self.assertExpectedInline(
post_munge(first_graph_break),
first_graph_break,
"""\
Attempted to call function marked as skipped
Explanation: Dynamo cannot trace optree C/C++ function optree.<path>.flatten_with_path.
Explanation: Dynamo cannot trace optree C/C++ function optree._C.PyCapsule.flatten_with_path.
Hint: Consider using torch.utils._pytree - https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py
Developer debug context: module: optree._C, qualname: <path>.flatten_with_path, skip reason: <missing reason>
Developer debug context: module: optree._C, qualname: PyCapsule.flatten_with_path, skip reason: <missing reason>
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0007.html""",
)

View File

@ -5241,63 +5241,6 @@ class DefaultsTests(torch._dynamo.test_case.TestCase):
x = torch.randn(1)
self.assertEqual(opt_mod(x), x + 1)
def test_full_with_tensor_fill_value(self):
"""Test that torch.full works correctly with dynamic tensor fill_value"""
# Test with tensor fill_value (the bug case)
def func_tensor(x):
return torch.full((2,), x, dtype=torch.float64)
func_compiled = torch.compile(func_tensor)
# Test with different values
x1 = torch.tensor(5.0, dtype=torch.float64)
x2 = torch.tensor(10.0, dtype=torch.float64)
result1 = func_compiled(x1)
expected1 = torch.full((2,), x1, dtype=torch.float64)
self.assertEqual(result1, expected1)
# This is where the bug occurred - second call reused first value
result2 = func_compiled(x2)
expected2 = torch.full((2,), x2, dtype=torch.float64)
self.assertEqual(result2, expected2)
# Test with different dtypes
for dtype in [torch.float32, torch.float64, torch.int32, torch.int64]:
def func_typed(x):
return torch.full((3,), x, dtype=dtype)
func_typed_compiled = torch.compile(func_typed)
x_typed = torch.tensor(7, dtype=dtype)
result = func_typed_compiled(x_typed)
expected = torch.full((3,), x_typed, dtype=dtype)
self.assertEqual(result, expected)
# Test with non-tensor fill_value (scalar) to ensure we didn't break existing behavior
def func_scalar(size):
return torch.full((size,), 42.0, dtype=torch.float32)
func_scalar_compiled = torch.compile(func_scalar)
result_scalar = func_scalar_compiled(5)
expected_scalar = torch.full((5,), 42.0, dtype=torch.float32)
self.assertEqual(result_scalar, expected_scalar)
# Test with different scalar values
def func_scalar_param():
# Test multiple calls with different hardcoded scalar values
a = torch.full((2,), 3.14, dtype=torch.float32)
b = torch.full((2,), 2.71, dtype=torch.float32)
return a, b
func_scalar_param_compiled = torch.compile(func_scalar_param)
result_a, result_b = func_scalar_param_compiled()
self.assertEqual(result_a, torch.full((2,), 3.14, dtype=torch.float32))
self.assertEqual(result_b, torch.full((2,), 2.71, dtype=torch.float32))
instantiate_parametrized_tests(FunctionTests)
instantiate_parametrized_tests(DefaultsTests)

View File

@ -69,7 +69,6 @@ from torch.fx.experimental.symbolic_shapes import (
constrain_unify,
ConstraintViolationError,
expect_true,
guard_or_false,
guard_size_oblivious,
ShapeEnv,
)
@ -101,6 +100,7 @@ from torch.testing._internal.common_utils import (
wrapDeterministicFlagAPITest,
)
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing._internal.logging_utils import logs_to_string
pytree_modules = {
@ -13636,74 +13636,6 @@ instantiate_device_type_tests(
)
class DynamoOpPromotionTests(torch._dynamo.test_case.TestCase):
@unittest.skipIf(not TEST_CUDA, "This test requires a CUDA device")
def test_symbool_tensor_mul(self):
def symbool_mul_fn(x_bool, sentinel):
result = x_bool * sentinel
return result
x_true = torch.tensor([True], device="cuda")
x_false = torch.tensor([False], device="cuda")
sentinel = torch.tensor(2.0, requires_grad=True, device="cuda")
eager_result_true = symbool_mul_fn(x_true, sentinel)
eager_result_false = symbool_mul_fn(x_false, sentinel)
compiled_fn = torch.compile(symbool_mul_fn, fullgraph=True, dynamic=True)
compiled_result_true = compiled_fn(x_true, sentinel)
compiled_result_false = compiled_fn(x_false, sentinel)
self.assertEqual(eager_result_true, compiled_result_true)
self.assertEqual(eager_result_false, compiled_result_false)
self.assertEqual(compiled_result_true.item(), 2.0)
self.assertEqual(compiled_result_false.item(), 0.0)
@unittest.skipIf(not TEST_CUDA, "This test requires a CUDA device")
def test_symbool_guard_or_false(self):
def symbool_guard_fn(a_bool_tensor, b):
u0 = a_bool_tensor.item()
# Make sure guard_or_false still handles SymBool produced by .item()
if guard_or_false(u0):
return b * 10
else:
return b * 100
compiled_guard_fn = torch.compile(
symbool_guard_fn, backend="eager", dynamic=True
)
a_true = torch.tensor(True, device="cuda")
a_false = torch.tensor(False, device="cuda")
b = torch.randn(6, device="cuda")
eager_res_true = symbool_guard_fn(a_true, b)
compiled_res_true = compiled_guard_fn(a_true, b)
self.assertEqual(eager_res_true, compiled_res_true)
eager_res_false = symbool_guard_fn(a_false, b)
compiled_res_false = compiled_guard_fn(a_false, b)
self.assertEqual(eager_res_false, compiled_res_false)
self.assertEqual(compiled_res_true, b * 10)
self.assertEqual(compiled_res_false, b * 100)
@unittest.skipIf(not TEST_CUDA, "This test requires a CUDA device")
def test_symbool_tensor_mul_does_not_fail(self):
def fuzzed_program(arg_0, sentinel):
var_node_2 = arg_0
var_node_1 = torch.squeeze(var_node_2)
var_node_0 = var_node_1.item()
result = var_node_0 * sentinel
if result.is_complex():
result = result.real
return result
sentinel = torch.tensor(1.0, requires_grad=True, device="cuda")
arg_0 = torch.tensor([True], dtype=torch.bool, device="cuda")
args = (arg_0,) + (sentinel,)
try:
compiled_program = torch.compile(
fuzzed_program, fullgraph=True, dynamic=True
)
compiled_program(*args)
except Exception as e:
self.fail(f"torch.compile failed with error: {e}")
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -48,6 +48,7 @@ from torch._dynamo.testing import (
CompileCounter,
CompileCounterWithBackend,
EagerAndRecordGraphs,
expectedFailureDynamic,
rand_strided,
same,
skipIfNotPy312,
@ -1000,18 +1001,6 @@ class ReproTests(torch._dynamo.test_case.TestCase):
self.exit_stack.close()
super().tearDown()
def test_compiled_module_truthiness(self):
# Test with empty ModuleList
original_empty = nn.ModuleList()
compiled_empty = torch.compile(original_empty)
self.assertEqual(bool(original_empty), bool(compiled_empty))
self.assertFalse(bool(compiled_empty))
# Test with non-empty ModuleList
original_filled = nn.ModuleList([nn.Linear(10, 5)])
compiled_filled = torch.compile(original_filled)
self.assertEqual(bool(original_filled), bool(compiled_filled))
self.assertTrue(bool(compiled_filled))
def guard_manager_clone_hook_fn(self, guard_manager_wrapper, f_locals, builder):
root = guard_manager_wrapper.root
cloned_root = root.clone_manager(lambda x: True)
@ -7455,6 +7444,93 @@ def forward(self, s77 : torch.SymInt, s27 : torch.SymInt, L_x_ : torch.Tensor):
msg,
)
@expectedFailureDynamic
def test_dynamo_default_lru_cache_behavior(self):
@torch.compile(backend="eager")
def fn(x):
return x + 10
torch._dynamo.reset()
assert not torch._C._dynamo.eval_frame._debug_get_cache_entry_list(
fn._torchdynamo_orig_callable.__code__
)
# Step 1: Compile a static shapes graph
x = torch.randn(10, 10)
fn(x)
a = torch._C._dynamo.eval_frame._debug_get_cache_entry_list(
fn._torchdynamo_orig_callable.__code__
)
self.assertEqual(len(a), 1)
static_shapes_cache_entry = a[0]
# Step 2: Compile a dynamic shapes graph
y = torch.randn(20, 20)
fn(y)
b = torch._C._dynamo.eval_frame._debug_get_cache_entry_list(
fn._torchdynamo_orig_callable.__code__
)
self.assertEqual(len(b), 2)
self.assertEqual(b[1], static_shapes_cache_entry)
dynamic_shapes_cache_entry = b[0]
# Step 3: Run with Step 1's inputs
# LRU cache will match against dynamic shape graph first
fn(x)
c = torch._C._dynamo.eval_frame._debug_get_cache_entry_list(
fn._torchdynamo_orig_callable.__code__
)
self.assertEqual(len(c), 2)
self.assertEqual(c[0], dynamic_shapes_cache_entry)
self.assertEqual(c[1], static_shapes_cache_entry)
@expectedFailureDynamic
def test_dynamo_disable_lru_cache_behavior(self):
@torch.compile(backend="eager")
def fn(x):
return x + 10
def run():
torch._dynamo.reset()
assert not torch._C._dynamo.eval_frame._debug_get_cache_entry_list(
fn._torchdynamo_orig_callable.__code__
)
# Step 1: Compile a static shapes graph
x = torch.randn(10, 10)
fn(x)
a = torch._C._dynamo.eval_frame._debug_get_cache_entry_list(
fn._torchdynamo_orig_callable.__code__
)
self.assertEqual(len(a), 1)
static_shapes_cache_entry = a[0]
# Step 2: Compile a dynamic shapes graph
y = torch.randn(20, 20)
fn(y)
b = torch._C._dynamo.eval_frame._debug_get_cache_entry_list(
fn._torchdynamo_orig_callable.__code__
)
self.assertEqual(len(b), 2)
self.assertEqual(b[0], static_shapes_cache_entry)
dynamic_shapes_cache_entry = b[1]
# Step 3: Run with Step 1's inputs
# LRU cache is disabled, we should still have static entry first
fn(x)
c = torch._C._dynamo.eval_frame._debug_get_cache_entry_list(
fn._torchdynamo_orig_callable.__code__
)
self.assertEqual(len(c), 2)
self.assertEqual(c[0], static_shapes_cache_entry)
self.assertEqual(c[1], dynamic_shapes_cache_entry)
try:
torch._C._dynamo.eval_frame._set_lru_cache(False)
run()
finally:
torch._C._dynamo.eval_frame._set_lru_cache(True)
class ReproTestsDevice(torch._dynamo.test_case.TestCase):
def test_sub_alpha_scalar_repro(self, device):

View File

@ -1089,8 +1089,6 @@ aten::rand.names
aten::rand.names_out
aten::rand.out
aten::rand_like
aten::rand_like.generator
aten::rand_like.generator_out
aten::rand_like.out
aten::randint
aten::randint.generator
@ -1102,15 +1100,9 @@ aten::randint.low_out
aten::randint.out
aten::randint_like
aten::randint_like.Tensor
aten::randint_like.Tensor_generator
aten::randint_like.Tensor_generator_out
aten::randint_like.Tensor_out
aten::randint_like.generator
aten::randint_like.generator_out
aten::randint_like.low_dtype
aten::randint_like.low_dtype_out
aten::randint_like.low_generator_dtype
aten::randint_like.low_generator_dtype_out
aten::randint_like.out
aten::randn.generator
aten::randn.generator_with_names
@ -1118,8 +1110,6 @@ aten::randn.generator_with_names_out
aten::randn.names
aten::randn.names_out
aten::randn_like
aten::randn_like.generator
aten::randn_like.generator_out
aten::randn_like.out
aten::random
aten::random.from

View File

@ -522,83 +522,6 @@ def forward(self, args_0):
)
self.assertEqual(ep(*inps), MyModel()(*inps))
def test_dynamo_graph_capture_full_tracing_context(self) -> None:
class Foo(torch.nn.Module):
def forward(self, x):
return x + x.shape[0]
foo = Foo()
def make_inputs(b: int):
ret = (torch.randn(b, 3),)
torch._dynamo.mark_dynamic(ret[0], 0)
return ret
trace_inputs = make_inputs(2)
gm = dynamo_graph_capture_for_export(foo)(*trace_inputs)
test_inputs = make_inputs(3)
self.assertEqual(gm(*test_inputs), foo(*test_inputs))
self.assertIsNotNone(gm.meta["tracing_context"].fake_mode)
self.assertEqual(len(gm.meta["tracing_context"].tensor_to_context), 1)
def test_dynamo_graph_capture_dict_keys_getitem(self):
class Module(torch.nn.Module):
def forward(self, x):
return x * 2
foo = Module()
class BlockMask:
def __init__(self, d):
self.d = d
block_mask = BlockMask(torch.randn(4))
def pre_hook_function(m, input):
block_mask.d = input[0] + 1
return input # Return a tuple of modified inputs
foo.register_forward_pre_hook(pre_hook_function)
def make_inputs():
return (torch.randn(4),)
trace_inputs = make_inputs()
gm = dynamo_graph_capture_for_export(foo)(*trace_inputs)
test_inputs = make_inputs()
self.assertExpectedInline(
gm.code.strip("\r\n "),
"""\
def forward(self, args_0):
_tree_leaf_0, _tree_leaf_1, = pytree.tree_leaves((self, args_0,))
L_args_0_ , = self._in_shuffle_graph(_tree_leaf_0, _tree_leaf_1)
l_args_0_ = L_args_0_
add = l_args_0_ + 1
mul = l_args_0_ * 2; l_args_0_ = None
return pytree.tree_unflatten(self._out_shuffle_graph(_tree_leaf_0, _tree_leaf_1, mul, add), self._out_spec)""",
)
self.assertEqual(gm(*test_inputs), foo(*test_inputs))
def test_dynamo_graph_capture_with_tensor_constant(self):
outer = torch.randn(2, 3)
class MyModel(torch.nn.Module):
def forward(self, x):
z = x + outer
return z
foo = MyModel()
def make_inputs():
return (torch.randn(2, 3),)
trace_inputs = make_inputs()
gm = dynamo_graph_capture_for_export(foo)(*trace_inputs)
test_inputs = make_inputs()
self.assertEqual(gm(*test_inputs), foo(*test_inputs))
self.assertEqual(len(list(gm.buffers())), len(list(foo.buffers())))
self.assertEqual(len(list(gm.parameters())), len(list(foo.parameters())))
@unittest.skipIf(not TEST_CUDA, "CUDA not available")
def test_dynamo_graph_capture_fx_graph_annotate_overlap_pass(self):
class DummyOp(torch.autograd.Function):

View File

@ -7356,6 +7356,7 @@ metadata incorrectly.
aot_eager = torch.compile(backend="aot_eager")(fn)(x)
self.assertEqual(eager, aot_eager, atol=0, rtol=0)
@unittest.expectedFailure
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
def test_rms_norm(self):
# Only CUDA rms norm fails to be decomposed

View File

@ -751,29 +751,6 @@ class TestConstFold(TestCase):
)
self.assertIsNone(mod_folded.const_subgraph_module)
def test_const_fold_partial_graph(self):
"""
If a model graph is partially const folded,
the non-const subgraph should be inlined back and erased.
"""
class TestModule(torch.nn.Module):
def __init__(self, p):
super().__init__()
self.p = p
def forward(self, x):
probs = torch.empty_permuted(x.shape, [0, 1])
mask = torch.bernoulli(probs, 1 - self.p)
return x * mask / (1 - self.p)
ep = torch.export.export(TestModule(0.4), (torch.randn(5, 10),))
mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(
ep.module(), device_for_folded_attrs="cpu"
)
self._verify_const_fold_mod(mod_folded)
if __name__ == "__main__":
raise_on_run_directly("test/test_fx.py")

View File

@ -20,14 +20,8 @@ from torch.testing._internal.common_device_type import (
dtypes,
instantiate_device_type_tests,
skipIf,
skipXPUIf,
)
from torch.testing._internal.common_utils import (
parametrize,
run_tests,
TEST_WITH_SLOW,
TestCase,
)
from torch.testing._internal.common_utils import parametrize, run_tests, TestCase
from torch.testing._internal.inductor_utils import IS_BIG_GPU
@ -388,11 +382,7 @@ class TestAnalysis(TestCase):
verify_triton(comp_omni)
@skipIf(
(not torch.xpu.is_available()) and (not SM80OrLater),
"Requires XPU or CUDA SM80",
)
@skipXPUIf(TEST_WITH_SLOW, "Skip because test too slow on XPU")
@skipIf(not SM80OrLater, "Requires SM80")
@dtypes(torch.float, torch.float16)
@parametrize(
"maxat",
@ -477,7 +467,6 @@ class TestAnalysis(TestCase):
"aten::cudnn_convolution",
"aten::convolution",
"aten::_convolution",
"aten::convolution_overrideable",
)
)
or "conv" in name

View File

@ -4,7 +4,6 @@ import os
import tempfile
from threading import Event
import torch._inductor.config as config
from torch._inductor.compile_worker.subproc_pool import (
raise_testexc,
SubprocException,
@ -17,12 +16,9 @@ from torch.testing._internal.inductor_utils import HAS_CPU
class TestCompileWorker(TestCase):
def make_pool(self, size):
return SubprocPool(size)
@skipIfWindows(msg="pass_fds not supported on Windows.")
def test_basic_jobs(self):
pool = self.make_pool(2)
pool = SubprocPool(2)
try:
a = pool.submit(operator.add, 100, 1)
b = pool.submit(operator.sub, 100, 1)
@ -33,7 +29,7 @@ class TestCompileWorker(TestCase):
@skipIfWindows(msg="pass_fds not supported on Windows.")
def test_exception(self):
pool = self.make_pool(2)
pool = SubprocPool(2)
try:
a = pool.submit(raise_testexc)
with self.assertRaisesRegex(
@ -46,7 +42,7 @@ class TestCompileWorker(TestCase):
@skipIfWindows(msg="pass_fds not supported on Windows.")
def test_crash(self):
pool = self.make_pool(2)
pool = SubprocPool(2)
try:
with self.assertRaises(Exception):
a = pool.submit(os._exit, 1)
@ -62,7 +58,7 @@ class TestCompileWorker(TestCase):
@skipIfWindows(msg="pass_fds not supported on Windows.")
def test_quiesce(self):
pool = self.make_pool(2)
pool = SubprocPool(2)
try:
a = pool.submit(operator.add, 100, 1)
pool.quiesce()
@ -79,7 +75,7 @@ class TestCompileWorker(TestCase):
os.environ["ROLE_RANK"] = "0"
with tempfile.NamedTemporaryFile(delete=True) as temp_log:
os.environ["TORCHINDUCTOR_WORKER_LOGPATH"] = temp_log.name
pool = self.make_pool(2)
pool = SubprocPool(2)
try:
pool.submit(operator.add, 100, 1)
self.assertEqual(os.path.exists(temp_log.name), True)
@ -87,12 +83,6 @@ class TestCompileWorker(TestCase):
pool.shutdown()
@config.patch("quiesce_async_compile_time", 0.1)
class TestCompileWorkerWithTimer(TestCompileWorker):
def make_pool(self, size):
return SubprocPool(size, quiesce=True)
class TestTimer(TestCase):
def test_basics(self):
done = Event()

View File

@ -1,154 +0,0 @@
# Owner(s): ["module: inductor"]
import unittest
import torch
from torch import Tensor
from torch._inductor import config
from torch._inductor.codegen.cuda.cuda_env import is_datacenter_blackwell_arch
from torch._inductor.test_case import run_tests, TestCase as InductorTestCase
from torch._inductor.utils import ensure_cute_available
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
)
@unittest.skipIf(
not (ensure_cute_available() and is_datacenter_blackwell_arch()),
"CuTeDSL library or Blackwell device not available",
)
@instantiate_parametrized_tests
class TestCuTeDSLGroupedGemm(InductorTestCase):
def _get_inputs(
self,
group_size: int,
M_hint: int,
K: int,
N: int,
device: str,
dtype: torch.dtype,
alignment: int = 16,
) -> tuple[Tensor, Tensor, Tensor]:
# --- Random, tile-aligned M sizes ---
M_sizes = (
torch.randint(1, (M_hint // alignment) + 1, (group_size,), dtype=torch.int)
* alignment
)
M_total = torch.sum(M_sizes).item()
# --- Construct input tensors ---
A = torch.randn(int(M_total), K, dtype=dtype, device=device) * 0.1
B = torch.randn((group_size, K, N), dtype=dtype, device=device) * 0.01
# --- Build offsets (no leading zero, strictly increasing) ---
offsets = torch.cumsum(M_sizes, dim=0).to(dtype=torch.int32, device=device)
return (A, B, offsets)
@parametrize("group_size", (2, 8))
@parametrize("M_hint", (256, 1024))
@parametrize("K", (64, 128))
@parametrize("N", (128, 256))
def test_grouped_gemm_basic(self, group_size: int, M_hint: int, K: int, N: int):
device = "cuda"
dtype = torch.bfloat16
A, B, offsets = self._get_inputs(group_size, M_hint, K, N, device, dtype)
def grouped_gemm_fn(A_packed, B_batched, offs):
return torch._grouped_mm(A_packed, B_batched, offs=offs)
# Eager execution
c_eager = grouped_gemm_fn(A, B, offsets)
# Test with Cute backend
with config.patch(
{
"max_autotune": True,
"max_autotune_gemm_backends": "CUTEDSL",
"test_configs.autotune_choice_name_regex": "cutedsl",
"autotune_fallback_to_aten": False,
}
):
grouped_gemm_compiled = torch.compile(
grouped_gemm_fn, backend="inductor", dynamic=False
)
c_compiled = grouped_gemm_compiled(A, B, offsets)
self.assertEqual(c_eager.dtype, dtype)
self.assertEqual(c_compiled.dtype, dtype)
torch.testing.assert_close(c_eager, c_compiled)
@parametrize("layout_A", ("contiguous", "offset", "padded", "view"))
@parametrize("layout_B", ("contiguous", "broadcasted"))
def test_grouped_gemm_assorted_layouts(
self,
layout_A: str,
layout_B: str,
):
device = "cuda"
dtype = torch.bfloat16
G, K, N = 8, 64, 128
M_sizes = [128] * G
sum_M = sum(M_sizes)
offsets = torch.tensor(
[sum(M_sizes[: i + 1]) for i in range(G)], dtype=torch.int32, device=device
)
A_base = torch.randn(sum_M, K, device=device, dtype=dtype)
A = A_base
if layout_A == "offset":
# allocate bigger buffer than needed, use nonzero storage offset
storage = torch.randn(sum_M * K + 512, device=device, dtype=dtype)
offset = 128 # skip first 128 elements
A = torch.as_strided(storage[offset:], (sum_M, K), (K, 1))
elif layout_A == "padded":
# simulate row pitch > K (row_stride = K + pad)
row_pitch = K + 8
storage = torch.randn(sum_M * row_pitch, device=device, dtype=dtype)
A = torch.as_strided(storage, (sum_M, K), (row_pitch, 1))
elif layout_A == "view":
A_storage = torch.randn(sum_M * K, device=device, dtype=dtype)
A = A_storage.view(sum_M, K)
assert A._base is not None
assert A.shape == (sum_M, K)
B = torch.randn((G, K, N), dtype=dtype, device=device) * 0.01
if layout_B == "broadcasted":
# Broadcast B across groups (zero stride along G)
B = B[0].expand(G, K, N)
assert B.stride(0) == 0
def grouped_gemm_fn(A_packed, B_batched, offs):
return torch._grouped_mm(A_packed, B_batched, offs=offs)
# --- eager ---
c_eager = grouped_gemm_fn(A, B, offsets)
# --- compiled (CUTE backend) ---
with config.patch(
{
"max_autotune": True,
"max_autotune_gemm_backends": "CUTEDSL",
"test_configs.autotune_choice_name_regex": "cutedsl",
"autotune_fallback_to_aten": False,
}
):
grouped_gemm_compiled = torch.compile(
grouped_gemm_fn, backend="inductor", dynamic=False
)
c_compiled = grouped_gemm_compiled(A, B, offsets)
self.assertEqual(c_eager.dtype, dtype)
self.assertEqual(c_compiled.dtype, dtype)
torch.testing.assert_close(c_eager, c_compiled)
if __name__ == "__main__":
run_tests()

View File

@ -15,8 +15,9 @@ from torch.testing._internal.common_utils import (
is_navi3_arch,
parametrize,
patch_test_members,
TEST_XPU,
)
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU_AND_TRITON
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA_AND_TRITON
from torch.testing._internal.triton_utils import requires_gpu
@ -60,6 +61,11 @@ class TestDecomposeAddMM(torch.nn.Module):
@requires_gpu
@unittest.skipIf(
TEST_XPU,
"Intel GPU has not enabled decompose_mem_bound_mm PASS in "
"torch/_inductor/fx_passes/decompose_mem_bound_mm.py",
)
@torch._inductor.config.patch(
post_grad_fusion_options={
"decompose_mm_pass": {},
@ -138,7 +144,7 @@ class TestDecomposeMemMM(TestCase):
self.compare_pred(module, traced, input)
expected_val = 1 if should_decompose and HAS_GPU_AND_TRITON else 0
expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0
self.assertEqual(
counters["inductor"]["decompose_bmm"],
expected_val,
@ -149,7 +155,7 @@ class TestDecomposeMemMM(TestCase):
self.compare_parameters(module, traced)
self.compare_gradients(module, traced)
expected_val = 3 if should_decompose and HAS_GPU_AND_TRITON else 0
expected_val = 3 if should_decompose and HAS_CUDA_AND_TRITON else 0
self.assertEqual(
counters["inductor"]["decompose_bmm"],
expected_val,
@ -198,7 +204,7 @@ class TestDecomposeMemMM(TestCase):
self.compare_pred(module, traced, input)
expected_val = 1 if should_decompose and HAS_GPU_AND_TRITON else 0
expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0
if has_bias:
self.assertEqual(
counters["inductor"]["decompose_addmm"],
@ -253,7 +259,7 @@ class TestDecomposeMemMM(TestCase):
self.compare_pred(module, traced, input)
expected_val = 1 if should_decompose and HAS_GPU_AND_TRITON else 0
expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0
if has_bias:
self.assertEqual(
counters["inductor"]["decompose_addmm"],
@ -298,7 +304,7 @@ class TestDecomposeMemMM(TestCase):
self.compare_pred(module, traced, input)
expected_val = 1 if should_decompose and HAS_GPU_AND_TRITON else 0
expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0
self.assertEqual(
counters["inductor"]["decompose_mm"],
expected_val,
@ -310,7 +316,7 @@ class TestDecomposeMemMM(TestCase):
self.compare_parameters(module, traced)
self.compare_gradients(module, traced)
expected_val = 1 if should_decompose and HAS_GPU_AND_TRITON else 0
expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0
self.assertEqual(
counters["inductor"]["decompose_mm"] - decompose_mm_fwd,
expected_val,
@ -368,7 +374,7 @@ class TestDecomposeMemMM(TestCase):
self.compare_pred(module, traced, input)
expected_val = 1 if should_decompose and HAS_GPU_AND_TRITON else 0
expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0
self.assertEqual(
counters["inductor"]["decompose_mm"],
expected_val,
@ -380,7 +386,7 @@ class TestDecomposeMemMM(TestCase):
self.compare_parameters(module, traced)
self.compare_gradients(module, traced)
expected_val = 1 if should_decompose and HAS_GPU_AND_TRITON else 0
expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0
self.assertEqual(
counters["inductor"]["decompose_mm"] - decompose_mm_fwd,
expected_val,
@ -404,7 +410,7 @@ class TestDecomposeMemMM(TestCase):
self.compare_pred(module, traced, input)
expected_val = 1 if should_decompose and HAS_GPU_AND_TRITON else 0
expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0
if has_bias:
self.assertEqual(
counters["inductor"]["decompose_addmm"],
@ -418,7 +424,7 @@ class TestDecomposeMemMM(TestCase):
self.compare_gradients(module, traced)
expected_val = 0
if HAS_GPU_AND_TRITON:
if HAS_CUDA_AND_TRITON:
expected_val = 1 if has_bias else 2
self.assertEqual(
@ -441,8 +447,12 @@ class TestDecomposeMemMM(TestCase):
_, code = run_and_get_code(foo, input1, input2)
# two kernels generated
FileCheck().check_count(".run(", 2, exactly=True).run(code[0])
if GPU_TYPE == "xpu":
# only 1 kernel generated on the XPU stack
FileCheck().check_count(".run(", 1, exactly=True).run(code[0])
else:
# two kernels generated
FileCheck().check_count(".run(", 2, exactly=True).run(code[0])
def test_check_device(self):
m = 5
@ -452,7 +462,7 @@ class TestDecomposeMemMM(TestCase):
input1 = torch.randn(m, k, device=GPU_TYPE)
input2 = torch.randn(k, n, device=GPU_TYPE)
self.assertTrue(check_device(input1, input2, device=GPU_TYPE))
self.assertTrue(check_device(input1, input2))
self.assertFalse(check_device(input1, input2, device="cpu"))
input1 = torch.randn(m, k)

View File

@ -794,16 +794,14 @@ class TestFP8Lowering(TestCase):
_get_torch_cuda_version() < (12, 9),
"cuBLAS blockwise scaling added in CUDA 12.9",
)
@parametrize("shape", ((16, 256, 256), (1024, 512, 1024)))
@parametrize("use_fast_accum", (False, True))
@parametrize(
"scaling_block_sizes", ((1, 128, 128, 128), (1, 128, 1, 128))
) # (BlockWise1x128, BlockWise128x128), (BlockWise1x128, BlockWise1x128)
def test_main_loop_scaling(
"shape", ((16, 256, 256), (1024, 512, 1024))
) # TODO (jananisriram): add scaling recipe overrides for shapes like (16, 256, 64) and (256, 16, 64)
@parametrize("use_fast_accum", (False, True))
def test_blockwise1x128_blockwise128x128_scaling(
self,
shape: tuple[int, int, int],
use_fast_accum: bool,
scaling_block_sizes: tuple[int, int, int, int],
):
# Only bf16 output type is supported for non-tensorwise scaling, not fp32
dtype: torch.dtype = torch.bfloat16
@ -816,28 +814,20 @@ class TestFP8Lowering(TestCase):
w = torch.randn(N, K, dtype=dtype, device=device)
bias = None
am, ak, bn, bk = scaling_block_sizes
# quantize weight (prior to inference)
w_fp8, w_inverse_scale = _quantize_blockwise(
w, dtype_float8, block_outer=bn, block_inner=bk
w, dtype_float8, block_outer=128, block_inner=128
)
w_t_fp8 = w_fp8.t()
if (bn, bk) == (1, 128):
w_inverse_scale = (
w_inverse_scale.t().contiguous().t().t()
) # 1x128 blocks need scales to be outer-dim-major
else:
w_inverse_scale = w_inverse_scale.t() # scale_b should be (1, N)
w_inverse_scale = w_inverse_scale.t() # scale_b should be (1, N)
# quantize input x
x_fp8, x_inverse_scale = _quantize_blockwise(
x, dtype_float8, block_outer=am, block_inner=ak
x, dtype_float8, block_outer=1, block_inner=128
)
if (am, ak) == (1, 128):
x_inverse_scale = (
x_inverse_scale.t().contiguous().t()
) # 1x128 blocks need scales to be outer-dim-major
x_inverse_scale = (
x_inverse_scale.t().contiguous().t()
) # 1x128 blocks need scales to be outer-dim-major
def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias):
y = torch._scaled_mm(
@ -882,15 +872,9 @@ class TestFP8Lowering(TestCase):
FileCheck().check(
f"SCALE_RECIPE_A : tl.constexpr = {ScalingType.BlockWise1x128.value}"
).run(code[0])
if (bn, bk) == (1, 128):
check_scale_recipe_b = ScalingType.BlockWise1x128.value
else:
check_scale_recipe_b = ScalingType.BlockWise128x128.value
FileCheck().check(
f"SCALE_RECIPE_B : tl.constexpr = {check_scale_recipe_b}"
f"SCALE_RECIPE_B : tl.constexpr = {ScalingType.BlockWise128x128.value}"
).run(code[0])
self.assertEqual(y_eager.dtype, dtype)
self.assertEqual(y_compiled.dtype, dtype)
torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.05)

View File

@ -806,6 +806,8 @@ class AOTFxirTestCase(InductorTestCase):
def check(
self, model, inp, dynamic_shapes=None, strict=False
) -> torch.fx.GraphModule:
if self.device == "xpu":
raise unittest.SkipTest("The feature AOTFxir not currently ready for XPU")
with torch.no_grad():
ep = torch.export.export(
model, inp, dynamic_shapes=dynamic_shapes, strict=strict

Some files were not shown because too many files have changed in this diff Show More