mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-05 08:24:57 +08:00
Compare commits
71 Commits
gh/mikayla
...
cpp-docs-d
| Author | SHA1 | Date | |
|---|---|---|---|
| df1268c311 | |||
| 84f9f1541d | |||
| 27c0c126bf | |||
| 670873155a | |||
| 923737c510 | |||
| 13d5b14a73 | |||
| a35a42b21c | |||
| 15956bc1e8 | |||
| b319ea1111 | |||
| ce4c68a5f6 | |||
| c6da4a59a3 | |||
| 53f75cd5ba | |||
| 527b1109a8 | |||
| 3144713325 | |||
| eefa16342c | |||
| d02f68f484 | |||
| 68eb55c4b2 | |||
| 8d4b8ab430 | |||
| afd50bdd29 | |||
| 56dfd4c74b | |||
| 24db5c4451 | |||
| cc8bfd1206 | |||
| c45b156605 | |||
| 8fff7e36b4 | |||
| 82fa2aa269 | |||
| 09e0285608 | |||
| d980d8dc79 | |||
| c7d00de115 | |||
| d3cf90ada5 | |||
| 0e1a88904f | |||
| 3232caa078 | |||
| a6c6acea9d | |||
| 55be1cc739 | |||
| 344cebda52 | |||
| ba72c6b981 | |||
| 888efcc453 | |||
| 24aa9a2ef7 | |||
| f70faf2b9a | |||
| 167e64ba1a | |||
| 875b18d53c | |||
| eec3749c44 | |||
| 40133fe966 | |||
| f288433d3e | |||
| 864633fca0 | |||
| c21868b435 | |||
| a0a8eca01a | |||
| 0958f307d9 | |||
| 7551507c41 | |||
| f92834d477 | |||
| e1fc01bef8 | |||
| 22a745737a | |||
| ee708ea96c | |||
| 64819e3701 | |||
| 79ff2c66c8 | |||
| 665a411351 | |||
| 5c89bdb461 | |||
| 7b64ad906c | |||
| d944279def | |||
| 5048e4701d | |||
| 616314cfd5 | |||
| 2b7e4c3ef2 | |||
| 6c98657239 | |||
| 86b2d82e84 | |||
| eea8ff2d34 | |||
| 11f73d78c8 | |||
| 7d1b976146 | |||
| 27cfdd9e77 | |||
| 01d8d8584b | |||
| b8855e7b0b | |||
| 6725ee89c8 | |||
| 3a38ec78e1 |
@ -1,15 +1,11 @@
|
||||
sphinx==5.3.0
|
||||
sphinx==7.2.6
|
||||
#Description: This is used to generate PyTorch docs
|
||||
#Pinned versions: 5.3.0
|
||||
#Pinned versions: 7.2.6
|
||||
|
||||
standard-imghdr==3.13.0; python_version >= "3.13"
|
||||
#Description: This is needed by Sphinx, so it needs to be added here.
|
||||
# The reasons are as follows:
|
||||
# 1) This module has been removed from the Python standard library since Python 3.13(https://peps.python.org/pep-0594/#imghdr);
|
||||
# 2) The current version of Sphinx (5.3.0) is not compatible with Python 3.13.
|
||||
# Once Sphinx is upgraded to a version compatible with Python 3.13 or later, we can remove this dependency.
|
||||
pytorch_sphinx_theme2==0.2.0
|
||||
#Description: This is needed to generate PyTorch docs
|
||||
#Pinned versions: 0.2.0
|
||||
|
||||
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git@71e55749be14ceb56e7f8211a9fb649866b87ad4#egg=pytorch_sphinx_theme2
|
||||
# TODO: sphinxcontrib.katex 0.9.0 adds a local KaTeX server to speed up pre-rendering
|
||||
# but it doesn't seem to work and hangs around idly. The initial thought that it is probably
|
||||
# something related to Docker setup. We can investigate this later.
|
||||
@ -36,17 +32,17 @@ tensorboard==2.18.0 ; python_version >= "3.13"
|
||||
#Description: This is used to generate PyTorch docs
|
||||
#Pinned versions: 2.13.0
|
||||
|
||||
breathe==4.34.0
|
||||
breathe==4.36.0
|
||||
#Description: This is used to generate PyTorch C++ docs
|
||||
#Pinned versions: 4.34.0
|
||||
#Pinned versions: 4.36.0
|
||||
|
||||
exhale==0.2.3
|
||||
exhale==0.3.7
|
||||
#Description: This is used to generate PyTorch C++ docs
|
||||
#Pinned versions: 0.2.3
|
||||
#Pinned versions: 0.3.7
|
||||
|
||||
docutils==0.16
|
||||
docutils==0.20
|
||||
#Description: This is used to generate PyTorch C++ docs
|
||||
#Pinned versions: 0.16
|
||||
#Pinned versions: 0.20
|
||||
|
||||
bs4==0.0.1
|
||||
#Description: This is used to generate PyTorch C++ docs
|
||||
@ -56,13 +52,13 @@ IPython==8.12.0
|
||||
#Description: This is used to generate PyTorch functorch docs
|
||||
#Pinned versions: 8.12.0
|
||||
|
||||
myst-nb==0.17.2
|
||||
myst-nb==1.3.0
|
||||
#Description: This is used to generate PyTorch functorch and torch.compile docs.
|
||||
#Pinned versions: 0.17.2
|
||||
#Pinned versions: 1.3.0
|
||||
|
||||
# The following are required to build torch.distributed.elastic.rendezvous.etcd* docs
|
||||
python-etcd==0.4.5
|
||||
sphinx-copybutton==0.5.0
|
||||
sphinx-design==0.4.0
|
||||
sphinx-design==0.6.1
|
||||
sphinxcontrib-mermaid==1.0.0
|
||||
myst-parser==0.18.1
|
||||
myst-parser==4.0.1
|
||||
|
||||
@ -89,23 +89,39 @@ if [ "$is_main_doc" = true ]; then
|
||||
|
||||
make coverage
|
||||
# Now we have the coverage report, we need to make sure it is empty.
|
||||
# Count the number of lines in the file and turn that number into a variable
|
||||
# $lines. The `cut -f1 ...` is to only parse the number, not the filename
|
||||
# Skip the report header by subtracting 2: the header will be output even if
|
||||
# there are no undocumented items.
|
||||
# Sphinx 7.2.6+ format: python.txt contains a statistics table with a TOTAL row
|
||||
# showing the undocumented count in the third column.
|
||||
# Example: | TOTAL | 99.83% | 2 |
|
||||
#
|
||||
# Also: see docs/source/conf.py for "coverage_ignore*" items, which should
|
||||
# be documented then removed from there.
|
||||
lines=$(wc -l build/coverage/python.txt 2>/dev/null |cut -f1 -d' ')
|
||||
undocumented=$((lines - 2))
|
||||
if [ $undocumented -lt 0 ]; then
|
||||
|
||||
# Extract undocumented count from TOTAL row in Sphinx 7.2.6 statistics table
|
||||
# The table format is: | Module | Coverage | Undocumented |
|
||||
# Extract the third column (undocumented count) from the TOTAL row
|
||||
undocumented=$(grep "| TOTAL" build/coverage/python.txt | awk -F'|' '{print $4}' | tr -d ' ')
|
||||
|
||||
if [ -z "$undocumented" ] || ! [[ "$undocumented" =~ ^[0-9]+$ ]]; then
|
||||
echo coverage output not found
|
||||
exit 1
|
||||
elif [ $undocumented -gt 0 ]; then
|
||||
echo undocumented objects found:
|
||||
cat build/coverage/python.txt
|
||||
elif [ "$undocumented" -gt 0 ]; then
|
||||
echo ""
|
||||
echo "====================="
|
||||
echo "UNDOCUMENTED OBJECTS:"
|
||||
echo "====================="
|
||||
echo ""
|
||||
# Find the line number of the TOTAL row and print only what comes after it
|
||||
total_line=$(grep -n "| TOTAL" build/coverage/python.txt | cut -d: -f1)
|
||||
if [ -n "$total_line" ]; then
|
||||
# Print only the detailed list (skip the statistics table)
|
||||
tail -n +$((total_line + 2)) build/coverage/python.txt
|
||||
else
|
||||
# Fallback to showing entire file if TOTAL line not found
|
||||
cat build/coverage/python.txt
|
||||
fi
|
||||
echo ""
|
||||
echo "Make sure you've updated relevant .rsts in docs/source!"
|
||||
echo "You can reproduce locally by running 'cd docs && make coverage && cat build/coverage/python.txt'"
|
||||
echo "You can reproduce locally by running 'cd docs && make coverage && tail -n +\$((grep -n \"| TOTAL\" build/coverage/python.txt | cut -d: -f1) + 2)) build/coverage/python.txt'"
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
|
||||
@ -337,7 +337,7 @@ test_python() {
|
||||
|
||||
test_python_smoke() {
|
||||
# Smoke tests for H100/B200
|
||||
time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
|
||||
time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune inductor/test_cutedsl_grouped_mm $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
|
||||
assert_git_not_dirty
|
||||
}
|
||||
|
||||
@ -1653,7 +1653,7 @@ test_operator_microbenchmark() {
|
||||
|
||||
cd "${TEST_DIR}"/benchmarks/operator_benchmark
|
||||
|
||||
for OP_BENCHMARK_TESTS in matmul mm addmm bmm; do
|
||||
for OP_BENCHMARK_TESTS in matmul mm addmm bmm conv; do
|
||||
$TASKSET python -m pt.${OP_BENCHMARK_TESTS}_test --tag-filter long \
|
||||
--output-json-for-dashboard "${TEST_REPORTS_DIR}/operator_microbenchmark_${OP_BENCHMARK_TESTS}_compile.json" \
|
||||
--benchmark-name "PyTorch operator microbenchmark" --use-compile
|
||||
|
||||
@ -28,7 +28,7 @@ CUDA_ARCHES_FULL_VERSION = {
|
||||
"12.6": "12.6.3",
|
||||
"12.8": "12.8.1",
|
||||
"12.9": "12.9.1",
|
||||
"13.0": "13.0.2",
|
||||
"13.0": "13.0.0",
|
||||
}
|
||||
CUDA_ARCHES_CUDNN_VERSION = {
|
||||
"12.6": "9",
|
||||
|
||||
1
.github/workflows/docker-release.yml
vendored
1
.github/workflows/docker-release.yml
vendored
@ -8,6 +8,7 @@ on:
|
||||
- docker.Makefile
|
||||
- .github/workflows/docker-release.yml
|
||||
- .github/scripts/generate_docker_release_matrix.py
|
||||
- .github/scripts/generate_binary_build_matrix.py
|
||||
push:
|
||||
branches:
|
||||
- nightly
|
||||
|
||||
8
.github/workflows/inductor-unittest.yml
vendored
8
.github/workflows/inductor-unittest.yml
vendored
@ -115,10 +115,10 @@ jobs:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "inductor_amx", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
|
||||
{ config: "inductor_amx", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
|
||||
{ config: "inductor_avx2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" },
|
||||
{ config: "inductor_avx2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" },
|
||||
{ config: "inductor_amx", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
|
||||
{ config: "inductor_amx", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
|
||||
{ config: "inductor_avx2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.avx2" },
|
||||
{ config: "inductor_avx2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.avx2" },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
|
||||
14
.github/workflows/inductor.yml
vendored
14
.github/workflows/inductor.yml
vendored
@ -84,13 +84,13 @@ jobs:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
|
||||
{ config: "cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
|
||||
{ config: "dynamic_cpu_inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
|
||||
{ config: "dynamic_cpu_inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
|
||||
{ config: "dynamic_cpu_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
|
||||
{ config: "dynamic_cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
|
||||
{ config: "dynamic_cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
|
||||
{ config: "cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
|
||||
{ config: "cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
|
||||
{ config: "dynamic_cpu_inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
|
||||
{ config: "dynamic_cpu_inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
|
||||
{ config: "dynamic_cpu_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
|
||||
{ config: "dynamic_cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
|
||||
{ config: "dynamic_cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
|
||||
{ config: "inductor_torchbench_cpu_smoketest_perf", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.24xl.spr-metal" },
|
||||
]}
|
||||
build-additional-packages: "vision audio torchao"
|
||||
|
||||
3
.github/workflows/trunk.yml
vendored
3
.github/workflows/trunk.yml
vendored
@ -204,6 +204,7 @@ jobs:
|
||||
{ include: [
|
||||
{ config: "default", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
|
||||
{ config: "default", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
|
||||
{ config: "distributed", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.gfx942.4" },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
@ -221,7 +222,7 @@ jobs:
|
||||
build-environment: linux-jammy-rocm-py3.10
|
||||
docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
|
||||
tests-to-include: "test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs test_autograd inductor/test_torchinductor"
|
||||
tests-to-include: "test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs test_autograd inductor/test_torchinductor distributed/test_c10d_common distributed/test_c10d_nccl"
|
||||
secrets: inherit
|
||||
|
||||
inductor-build:
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@ -127,6 +127,7 @@ torch/test/
|
||||
torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h
|
||||
torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h
|
||||
torch/version.py
|
||||
torch/_inductor/kernel/vendored_templates/*
|
||||
minifier_launcher.py
|
||||
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_fwd_d*
|
||||
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd_d*
|
||||
|
||||
@ -211,7 +211,6 @@ exclude_patterns = [
|
||||
'**/*pb.h',
|
||||
'**/*inl.h',
|
||||
'aten/src/ATen/cpu/FlushDenormal.cpp',
|
||||
'aten/src/ATen/cpu/Utils.cpp',
|
||||
'aten/src/ATen/cpu/vml.h',
|
||||
'aten/src/ATen/CPUFixedAllocator.h',
|
||||
'aten/src/ATen/Parallel*.h',
|
||||
@ -230,8 +229,6 @@ exclude_patterns = [
|
||||
'c10/util/win32-headers.h',
|
||||
'c10/test/**/*.h',
|
||||
'third_party/**/*',
|
||||
'torch/csrc/api/include/torch/nn/modules/common.h',
|
||||
'torch/csrc/api/include/torch/linalg.h',
|
||||
'torch/csrc/autograd/generated/**',
|
||||
'torch/csrc/distributed/**/*.cu',
|
||||
'torch/csrc/distributed/c10d/WinSockUtils.hpp',
|
||||
@ -243,7 +240,6 @@ exclude_patterns = [
|
||||
'torch/csrc/utils/generated_serialization_types.h',
|
||||
'torch/csrc/utils/pythoncapi_compat.h',
|
||||
'torch/csrc/inductor/aoti_runtime/sycl_runtime_wrappers.h',
|
||||
'aten/src/ATen/ExpandBase.h',
|
||||
]
|
||||
init_command = [
|
||||
'python3',
|
||||
@ -1752,15 +1748,3 @@ command = [
|
||||
"python3",
|
||||
"tools/linter/adapters/gb_registry_linter.py",
|
||||
]
|
||||
|
||||
[[linter]]
|
||||
code = 'STABLE_SHIM_VERSION'
|
||||
include_patterns = [
|
||||
'torch/csrc/stable/c/shim.h',
|
||||
]
|
||||
command = [
|
||||
'python3',
|
||||
'tools/linter/adapters/stable_shim_version_linter.py',
|
||||
'--',
|
||||
'@{{PATHSFILE}}'
|
||||
]
|
||||
|
||||
20
SECURITY.md
20
SECURITY.md
@ -1,7 +1,7 @@
|
||||
# Security Policy
|
||||
|
||||
- [**Reporting a Vulnerability**](#reporting-a-vulnerability)
|
||||
- [**Using Pytorch Securely**](#using-pytorch-securely)
|
||||
- [**Using PyTorch Securely**](#using-pytorch-securely)
|
||||
- [Untrusted models](#untrusted-models)
|
||||
- [TorchScript models](#torchscript-models)
|
||||
- [Untrusted inputs](#untrusted-inputs)
|
||||
@ -10,28 +10,28 @@
|
||||
- [**CI/CD security principles**](#cicd-security-principles)
|
||||
## Reporting Security Issues
|
||||
|
||||
Beware that none of the topics under [Using Pytorch Securely](#using-pytorch-securely) are considered vulnerabilities of Pytorch.
|
||||
Beware that none of the topics under [Using PyTorch Securely](#using-pytorch-securely) are considered vulnerabilities of PyTorch.
|
||||
|
||||
However, if you believe you have found a security vulnerability in PyTorch, we encourage you to let us know right away. We will investigate all legitimate reports and do our best to quickly fix the problem.
|
||||
|
||||
Please report security issues using https://github.com/pytorch/pytorch/security/advisories/new
|
||||
|
||||
All reports submitted thru the security advisories mechanism would **either be made public or dismissed by the team within 90 days of the submission**. If advisory has been closed on the grounds that it is not a security issue, please do not hesitate to create an [new issue](https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml) as it is still likely a valid issue within the framework.
|
||||
All reports submitted through the security advisories mechanism would **either be made public or dismissed by the team within 90 days of the submission**. If advisory has been closed on the grounds that it is not a security issue, please do not hesitate to create an [new issue](https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml) as it is still likely a valid issue within the framework.
|
||||
|
||||
Please refer to the following page for our responsible disclosure policy, reward guidelines, and those things that should not be reported:
|
||||
|
||||
https://www.facebook.com/whitehat
|
||||
|
||||
|
||||
## Using Pytorch Securely
|
||||
**Pytorch models are programs**, so treat its security seriously -- running untrusted models is equivalent to running untrusted code. In general we recommend that model weights and the python code for the model are distributed independently. That said, be careful about where you get the python code from and who wrote it (preferentially check for a provenance or checksums, do not run any pip installed package).
|
||||
## Using PyTorch Securely
|
||||
**PyTorch models are programs**, so treat its security seriously -- running untrusted models is equivalent to running untrusted code. In general we recommend that model weights and the python code for the model are distributed independently. That said, be careful about where you get the python code from and who wrote it (preferentially check for a provenance or checksums, do not run any pip installed package).
|
||||
|
||||
### Untrusted models
|
||||
Be careful when running untrusted models. This classification includes models created by unknown developers or utilizing data obtained from unknown sources[^data-poisoning-sources].
|
||||
|
||||
**Prefer to execute untrusted models within a secure, isolated environment such as a sandbox** (e.g., containers, virtual machines). This helps protect your system from potentially malicious code. You can find further details and instructions in [this page](https://developers.google.com/code-sandboxing).
|
||||
|
||||
**Be mindful of risky model formats**. Give preference to share and load weights with the appropriate format for your use case. [safetensors](https://huggingface.co/docs/safetensors/en/index) gives the most safety but is the most restricted in what it supports. [`torch.load`](https://pytorch.org/docs/stable/generated/torch.load.html#torch.load) has a significantly larger surface of attack but is more flexible in what it can serialize. See the documentation for more details.
|
||||
**Be mindful of risky model formats**. Give preference to share and load weights with the appropriate format for your use case. [Safetensors](https://huggingface.co/docs/safetensors/en/index) gives the most safety but is the most restricted in what it supports. [`torch.load`](https://pytorch.org/docs/stable/generated/torch.load.html#torch.load) has a significantly larger surface of attack but is more flexible in what it can serialize. See the documentation for more details.
|
||||
|
||||
Even for more secure serialization formats, unexpected inputs to the downstream system can cause diverse security threats (e.g. denial of service, out of bound reads/writes) and thus we recommend extensive validation of any untrusted inputs.
|
||||
|
||||
@ -43,7 +43,7 @@ Important Note: The trustworthiness of a model is not binary. You must always de
|
||||
|
||||
### TorchScript models
|
||||
|
||||
TorchScript models should treated the same way as locally executable code from an unknown source. Only run TorchScript models if you trust the provider. Please note, that tools for introspecting TorchScript models (such as `torch.utils.model_dump`) may also execute partial or full code stored in those models, therefore they should be used only if you trust the provider of the binary you are about to load.
|
||||
TorchScript models should be treated the same way as locally executable code from an unknown source. Only run TorchScript models if you trust the provider. Please note, that tools for introspecting TorchScript models (such as `torch.utils.model_dump`) may also execute partial or full code stored in those models, therefore they should be used only if you trust the provider of the binary you are about to load.
|
||||
|
||||
### Untrusted inputs during training and prediction
|
||||
|
||||
@ -59,9 +59,9 @@ If applicable, prepare your model against bad inputs and prompt injections. Some
|
||||
|
||||
### Data privacy
|
||||
|
||||
**Take special security measures if your model if you train models with sensitive data**. Prioritize [sandboxing](https://developers.google.com/code-sandboxing) your models and:
|
||||
- Do not feed sensitive data to untrusted model (even if runs in a sandboxed environment)
|
||||
- If you consider publishing a model that was partially trained with sensitive data, be aware that data can potentially be recovered from the trained weights (especially if model overfits).
|
||||
**Take special security measures if you train your models with sensitive data**. Prioritize [sandboxing](https://developers.google.com/code-sandboxing) your models and:
|
||||
- Do not feed sensitive data to an untrusted model (even if runs in a sandboxed environment)
|
||||
- If you consider publishing a model that was partially trained with sensitive data, be aware that data can potentially be recovered from the trained weights (especially if the model overfits).
|
||||
|
||||
### Using distributed features
|
||||
|
||||
|
||||
@ -260,7 +260,7 @@ IF(USE_FBGEMM_GENAI)
|
||||
if(USE_CUDA)
|
||||
# To avoid increasing the build time/binary size unnecessarily, use an allow-list of kernels to build.
|
||||
# If you want to integrate a kernel from FBGEMM into torch, you have to add it here.
|
||||
set(FBGEMM_CUTLASS_KERNELS_REGEX ".*(mx8mx8bf16_grouped|f4f4bf16_grouped).*")
|
||||
set(FBGEMM_CUTLASS_KERNELS_REGEX ".*(mx8mx8bf16_grouped|f4f4bf16_grouped|f4f4bf16).*")
|
||||
file(GLOB_RECURSE fbgemm_genai_native_cuda_cu
|
||||
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/*.cu"
|
||||
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/**/*.cu")
|
||||
|
||||
@ -23,8 +23,6 @@ C10_DIAGNOSTIC_POP()
|
||||
#endif
|
||||
namespace at {
|
||||
|
||||
namespace {
|
||||
|
||||
/*
|
||||
These const variables defined the fp32 precisions for different backend
|
||||
We have "generic", "cuda", "mkldnn" backend now and we can choose fp32
|
||||
@ -41,16 +39,6 @@ namespace {
|
||||
->rnn
|
||||
*/
|
||||
|
||||
C10_ALWAYS_INLINE void warn_deprecated_fp32_precision_api(){
|
||||
TORCH_WARN_ONCE(
|
||||
"Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' "
|
||||
"or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, "
|
||||
"torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see "
|
||||
"https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices"
|
||||
);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
Float32Backend str2backend(const std::string& name) {
|
||||
if (name == "generic")
|
||||
return Float32Backend::GENERIC;
|
||||
@ -206,7 +194,6 @@ bool Context::allowTF32CuDNN(std::optional<Float32Op> op) const {
|
||||
} else {
|
||||
return float32Precision(Float32Backend::CUDA, op.value()) == Float32Precision::TF32;
|
||||
}
|
||||
warn_deprecated_fp32_precision_api();
|
||||
return allow_tf32_cudnn;
|
||||
}
|
||||
|
||||
@ -214,7 +201,6 @@ void Context::setAllowTF32CuDNN(bool b) {
|
||||
setFloat32Precision(Float32Backend::CUDA, Float32Op::RNN, b ? Float32Precision::TF32 : Float32Precision::NONE);
|
||||
setFloat32Precision(Float32Backend::CUDA, Float32Op::CONV, b ? Float32Precision::TF32 : Float32Precision::NONE);
|
||||
allow_tf32_cudnn = b;
|
||||
warn_deprecated_fp32_precision_api();
|
||||
}
|
||||
|
||||
void Context::setSDPPriorityOrder(const std::vector<int64_t>& order) {
|
||||
@ -325,7 +311,6 @@ bool Context::allowTF32CuBLAS() const {
|
||||
"Current status indicate that you have used mix of the legacy and new APIs to set the TF32 status for cublas matmul. ",
|
||||
"We suggest only using the new API to set the TF32 flag. See also: ",
|
||||
"https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices");
|
||||
warn_deprecated_fp32_precision_api();
|
||||
return allow_tf32_new;
|
||||
}
|
||||
|
||||
@ -349,7 +334,6 @@ Float32MatmulPrecision Context::float32MatmulPrecision() const {
|
||||
"Current status indicate that you have used mix of the legacy and new APIs to set the matmul precision. ",
|
||||
"We suggest only using the new API for matmul precision. See also: ",
|
||||
"https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices");
|
||||
warn_deprecated_fp32_precision_api();
|
||||
return float32_matmul_precision;
|
||||
}
|
||||
|
||||
@ -377,7 +361,6 @@ Float32Precision Context::float32Precision(Float32Backend backend, Float32Op op)
|
||||
|
||||
void Context::setFloat32MatmulPrecision(const std::string &s) {
|
||||
auto match = [this](const std::string & s_) {
|
||||
warn_deprecated_fp32_precision_api();
|
||||
// TODO: consider if CuDNN field needs to also be set for potential future CuDNN ops like multi-headed attention
|
||||
if (s_ == "highest") {
|
||||
float32_matmul_precision = at::Float32MatmulPrecision::HIGHEST;
|
||||
|
||||
@ -191,22 +191,37 @@ inline void convert(const at::Half* src, bool* dst, int64_t n) {
|
||||
}
|
||||
|
||||
#endif
|
||||
#ifdef __ARM_FEATURE_BF16
|
||||
CONVERT_TEMPLATE(bfloat16_t, uint8_t)
|
||||
CONVERT_TEMPLATE(bfloat16_t, int8_t)
|
||||
CONVERT_TEMPLATE(bfloat16_t, int16_t)
|
||||
CONVERT_TEMPLATE(bfloat16_t, int32_t)
|
||||
CONVERT_TEMPLATE(bfloat16_t, int64_t)
|
||||
CONVERT_TEMPLATE(bfloat16_t, bfloat16_t)
|
||||
CONVERT_TEMPLATE(bfloat16_t, float)
|
||||
CONVERT_TEMPLATE(bfloat16_t, double)
|
||||
CONVERT_TEMPLATE(uint8_t, bfloat16_t)
|
||||
CONVERT_TEMPLATE(int8_t, bfloat16_t)
|
||||
CONVERT_TEMPLATE(int16_t, bfloat16_t)
|
||||
CONVERT_TEMPLATE(int32_t, bfloat16_t)
|
||||
CONVERT_TEMPLATE(int64_t, bfloat16_t)
|
||||
CONVERT_TEMPLATE(float, bfloat16_t)
|
||||
CONVERT_TEMPLATE(double, bfloat16_t)
|
||||
|
||||
template <typename to_type>
|
||||
inline void convertFromBf16Impl(
|
||||
const c10::BFloat16* __restrict src,
|
||||
to_type* __restrict dst,
|
||||
int64_t n) {
|
||||
const uint16_t* srcPtr = reinterpret_cast<const uint16_t*>(src);
|
||||
uint64_t len = static_cast<uint64_t>(n);
|
||||
for (uint64_t i = 0; i < len; i++) {
|
||||
uint32_t tmp = static_cast<uint32_t>(srcPtr[i]) << 16;
|
||||
float tmpF;
|
||||
__builtin_memcpy(&tmpF, &tmp, sizeof(float));
|
||||
dst[i] = static_cast<to_type>(tmpF);
|
||||
}
|
||||
}
|
||||
#define CONVERT_FROM_BF16_TEMPLATE(to_type) \
|
||||
template <> \
|
||||
inline void convert(const c10::BFloat16* src, to_type* dst, int64_t n) { \
|
||||
return convertFromBf16Impl<to_type>(src, dst, n); \
|
||||
}
|
||||
|
||||
CONVERT_FROM_BF16_TEMPLATE(uint8_t)
|
||||
CONVERT_FROM_BF16_TEMPLATE(int8_t)
|
||||
CONVERT_FROM_BF16_TEMPLATE(int16_t)
|
||||
CONVERT_FROM_BF16_TEMPLATE(int32_t)
|
||||
CONVERT_FROM_BF16_TEMPLATE(int64_t)
|
||||
CONVERT_FROM_BF16_TEMPLATE(float)
|
||||
CONVERT_FROM_BF16_TEMPLATE(double)
|
||||
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
CONVERT_FROM_BF16_TEMPLATE(float16_t)
|
||||
#endif
|
||||
|
||||
inline void convertBoolToBfloat16Impl(
|
||||
const bool* __restrict src,
|
||||
@ -247,8 +262,6 @@ inline void convert(const c10::BFloat16* src, bool* dst, int64_t n) {
|
||||
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
||||
template <typename src_t>
|
||||
struct VecConvert<
|
||||
float,
|
||||
|
||||
@ -156,6 +156,10 @@ struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {
|
||||
return -1;
|
||||
}
|
||||
|
||||
virtual void mtiagraphDestroy(int64_t handle) const {
|
||||
FAIL_MTIAHOOKS_FUNC(__func__);
|
||||
}
|
||||
|
||||
virtual void mtiagraphCaptureBegin(int64_t handle, MempoolId_t pool) const {
|
||||
FAIL_MTIAHOOKS_FUNC(__func__);
|
||||
}
|
||||
|
||||
@ -92,7 +92,8 @@ void addcdiv_cpu_kernel(TensorIteratorBase& iter, const Scalar& value) {
|
||||
|
||||
void smooth_l1_backward_cpu_kernel(TensorIterator& iter, const Scalar& norm, double beta) {
|
||||
ScalarType dtype = iter.dtype(0);
|
||||
if (dtype == kBFloat16) {
|
||||
if (at::isReducedFloatingType(dtype)) {
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(dtype, "smooth_l1_backward_cpu_out", [&]() {
|
||||
auto norm_val = norm.to<float>();
|
||||
float beta_val(beta);
|
||||
auto norm_val_vec = Vectorized<float>(norm_val);
|
||||
@ -101,9 +102,9 @@ void smooth_l1_backward_cpu_kernel(TensorIterator& iter, const Scalar& norm, dou
|
||||
const auto zero_vec = Vectorized<float>(0);
|
||||
const auto pos_1_vec = Vectorized<float>(1);
|
||||
cpu_kernel_vec(iter,
|
||||
[=](BFloat16 input, BFloat16 target, BFloat16 grad_output) -> BFloat16 {
|
||||
[=](scalar_t input, scalar_t target, scalar_t grad_output) -> scalar_t {
|
||||
const auto x = float(input) - float(target);
|
||||
if (x <= -beta){
|
||||
if (x <= -beta) {
|
||||
return -norm_val * float(grad_output);
|
||||
}else if (x >= beta){
|
||||
return norm_val * float(grad_output);
|
||||
@ -112,14 +113,14 @@ void smooth_l1_backward_cpu_kernel(TensorIterator& iter, const Scalar& norm, dou
|
||||
}
|
||||
},
|
||||
[norm_val_vec, beta_val_vec, neg_1_vec, zero_vec, pos_1_vec](
|
||||
Vectorized<BFloat16> input, Vectorized<BFloat16> target, Vectorized<BFloat16> grad_output) -> Vectorized<BFloat16> {
|
||||
Vectorized<scalar_t> input, Vectorized<scalar_t> target, Vectorized<scalar_t> grad_output) -> Vectorized<scalar_t> {
|
||||
// using two blendv calls to simulate the 3 cases
|
||||
// 1 if x >= beta
|
||||
// -1 if x <= -beta
|
||||
// x / beta if |x| < beta
|
||||
auto [input0, input1] = convert_bfloat16_float(input);
|
||||
auto [target0, target1] = convert_bfloat16_float(target);
|
||||
auto [grad_output0, grad_output1] = convert_bfloat16_float(grad_output);
|
||||
auto [input0, input1] = convert_to_float(input);
|
||||
auto [target0, target1] = convert_to_float(target);
|
||||
auto [grad_output0, grad_output1] = convert_to_float(grad_output);
|
||||
auto x = input0 - target0;
|
||||
auto pos_or_neg_1_vec = Vectorized<float>::blendv(
|
||||
neg_1_vec, pos_1_vec, x > zero_vec);
|
||||
@ -135,9 +136,10 @@ void smooth_l1_backward_cpu_kernel(TensorIterator& iter, const Scalar& norm, dou
|
||||
output = Vectorized<float>::blendv(
|
||||
x / beta_val_vec, pos_or_neg_1_vec, x_abs >= beta_val_vec);
|
||||
input1 = norm_val_vec * output * grad_output1;
|
||||
return convert_float_bfloat16(input0, input1);
|
||||
return convert_from_float<scalar_t>(input0, input1);
|
||||
}
|
||||
);
|
||||
});
|
||||
} else {
|
||||
AT_DISPATCH_ALL_TYPES(dtype, "smooth_l1_backward_cpu_out", [&] {
|
||||
auto norm_val = norm.to<scalar_t>();
|
||||
|
||||
@ -205,8 +205,8 @@ static bool isInputCompliesAddmmCudaLt(Tensor& result, const Tensor& self, const
|
||||
// and the leading stride is at least max(1, other dim length), so we might
|
||||
// end up with contiguous cols but not rows (i.e. holes between different rows)
|
||||
// and vice versa.
|
||||
&& mat2_sizes[0] < 65535 * 32 && mat2_sizes[1] < 65535 * 32 &&
|
||||
mat1_sizes[0] < 65535 * 32 && mat1_sizes[1] < 65535 * 32 &&
|
||||
&& mat2_sizes[0] < 65535 * 32 && mat2_sizes[1] < 65535 * 32
|
||||
&& mat1_sizes[0] < 65535 * 32 && mat1_sizes[1] < 65535 * 32
|
||||
&& (
|
||||
// filter by dtype
|
||||
(scalar_type != at::ScalarType::Half && scalar_type != at::ScalarType::BFloat16) ||
|
||||
|
||||
@ -59,6 +59,24 @@
|
||||
// forward declare
|
||||
class cublasCommonArgs;
|
||||
|
||||
#ifndef _WIN32
|
||||
namespace fbgemm_gpu {
|
||||
|
||||
// NOTE(slayton58): FBGemm_GPU kernels come from <fbgemm_gpu/torch_ops.h> within the FBGemm repo.
|
||||
// To update supported ops means a submodule bump, which is.. painful. Instead, we
|
||||
// can simply forward-declare the methods we want to use.. Works at least as a short-term
|
||||
// thing, but should still be fixed somewhere/somehow.
|
||||
at::Tensor f4f4bf16(
|
||||
at::Tensor,
|
||||
at::Tensor,
|
||||
at::Tensor,
|
||||
at::Tensor,
|
||||
std::optional<at::Tensor>,
|
||||
bool use_mx);
|
||||
|
||||
} // namespace fbgemm_gpu
|
||||
#endif
|
||||
|
||||
using at::blas::ScalingType;
|
||||
using at::blas::SwizzleType;
|
||||
|
||||
@ -767,33 +785,6 @@ _scaled_rowwise_rowwise(
|
||||
return out;
|
||||
}
|
||||
|
||||
// Check the shapes & sizes of scales for deepseek-style (1x128, 128x128) scaling.
|
||||
// Wraps check_size_stride for easier integration, correctly handles cases where a dimension of the scale == 1,
|
||||
// and strides become somewhat meaningless
|
||||
void _check_deepseek_scale_stride(const Tensor& scale, const Tensor& t, const ScalingType scale_type) {
|
||||
if (scale_type == ScalingType::BlockWise1x128) {
|
||||
TORCH_CHECK_VALUE(check_size_stride(scale, 0, t.size(0), 1),
|
||||
"at dim=0 scale should have ", t.size(0), "elements and stride(0) ", 1, "if ", t.size(0), " > 1 - Got: ",
|
||||
"shape=", scale.sizes(), ", stride=", scale.strides());
|
||||
auto expected_size = ceil_div<int64_t>(t.size(1), 128);
|
||||
TORCH_CHECK_VALUE(check_size_stride(scale, 1, expected_size, t.size(0)),
|
||||
"at dim=1 scale should have ", expected_size, "elements and stride ", t.size(0), "if ", expected_size, " > 1 - Got: ",
|
||||
"shape=", scale.sizes(), ", stride=", scale.strides());
|
||||
} else if (scale_type == ScalingType::BlockWise128x128) {
|
||||
TORCH_CHECK_VALUE(check_size_stride(
|
||||
scale,
|
||||
0,
|
||||
ceil_div<int64_t>(t.size(0), 128),
|
||||
ceil_div<int64_t>(t.size(1), 128)),
|
||||
"at dim=0 scale should have ", ceil_div<int64_t>(t.size(0), 128), "elements and stride(0) ", ceil_div<int64_t>(t.size(1), 128), "if ", ceil_div<int64_t>(t.size(0), 128), " > 1 - Got: ",
|
||||
"shape=", scale.sizes(), ", stride=", scale.strides());
|
||||
TORCH_CHECK(check_size_stride(
|
||||
scale, 1, ceil_div<int64_t>(t.size(1), 128), 1),
|
||||
"at dim=1 scale should have ", ceil_div<int64_t>(t.size(1), 128), "elements and stride(1) ", 1, "if ", ceil_div<int64_t>(t.size(1), 128), " > 1 - Got: ",
|
||||
"shape=", scale.sizes(), ", stride=", scale.strides());
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
_check_deepseek_support() {
|
||||
#ifndef USE_ROCM
|
||||
@ -806,7 +797,7 @@ _check_deepseek_support() {
|
||||
}
|
||||
// Only in cublasLt >= 12.9
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
CUBLAS_VERSION < 120900 || cublasLtGetVersion() < 120900,
|
||||
CUBLAS_VERSION >= 120900 && cublasLtGetVersion() >= 120900,
|
||||
"DeepSeek style (1x128, 128x128) scaling requires cublasLt >= 12.9"
|
||||
);
|
||||
#endif
|
||||
@ -823,23 +814,61 @@ _scaled_block1x128_block1x128(
|
||||
#ifndef USE_ROCM
|
||||
// Restrictions:
|
||||
// A, B are FP8, scales are fp32, shape K//128
|
||||
// CUDA: Only Hopper GPUs
|
||||
// As: [M x K // 128], stride: [1, M]
|
||||
// Bs: [N x K // 128], stride: [1, N]
|
||||
_check_deepseek_support();
|
||||
|
||||
TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
|
||||
mat_a.scalar_type(), mat_b.scalar_type());
|
||||
TORCH_CHECK_VALUE(scale_a.sizes()[0] == mat_a.sizes()[0] && scale_a.sizes()[1] == mat_a.sizes()[1] / 128 && scale_a.scalar_type() == kFloat,
|
||||
"scale_a must have shape ", mat_a.sizes()[0], " x ", mat_a.sizes()[1] / 128, " Float elements, got ", scale_a.sizes())
|
||||
TORCH_CHECK_VALUE(scale_b.sizes()[0] == ceil_div<int64_t>(mat_b.sizes()[0], 128) && scale_b.sizes()[1] == mat_b.sizes()[1] && scale_b.scalar_type() == kFloat,
|
||||
"scale_b must have shape ", ceil_div<int64_t>(mat_b.sizes()[0], 128), " x ", mat_b.sizes()[1], " Float elements, got ", scale_b.sizes())
|
||||
// check types
|
||||
TORCH_CHECK_VALUE(
|
||||
isFloat8Type(mat_a.scalar_type()) &&
|
||||
isFloat8Type(mat_b.scalar_type()),
|
||||
"mat_a and mat_b must be fp8 types, got: ", mat_a.scalar_type(), mat_b.scalar_type()
|
||||
);
|
||||
|
||||
const int64_t M = mat_a.sizes()[0];
|
||||
const int64_t K = mat_a.sizes()[1];
|
||||
const int64_t N = mat_b.sizes()[1];
|
||||
|
||||
// scale_a shape
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_a.size(0) == M &&
|
||||
scale_a.size(1) == ceil_div<int64_t>(K, 128) &&
|
||||
scale_a.scalar_type() == kFloat,
|
||||
"scale_a must have shape ", M, " x ", ceil_div<int64_t>(K, 128), " Float elements, got ", scale_a.sizes()
|
||||
);
|
||||
// scale_a stride
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_a.stride(0) == 1 &&
|
||||
(
|
||||
scale_a.stride(1) == M ||
|
||||
(scale_a.size(1) == 1 && scale_b.stride(1) == 1)
|
||||
),
|
||||
"scale_a strides must be (", 1, ", ", M, "); got: ", scale_a.strides()
|
||||
);
|
||||
|
||||
// scale_b shape
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_b.size(0) == N &&
|
||||
scale_b.size(1) == ceil_div<int64_t>(K, 128) &&
|
||||
scale_b.scalar_type() == kFloat,
|
||||
"scale_b must have shape ", N, " x ", ceil_div<int64_t>(K, 128), " Float elements, got ", scale_b.sizes()
|
||||
);
|
||||
// scale_b stride
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_b.stride(0) == 1 &&
|
||||
(
|
||||
scale_b.stride(1) == N ||
|
||||
(
|
||||
scale_b.size(1) == 1 &&
|
||||
scale_b.stride(1) == 1
|
||||
)
|
||||
),
|
||||
"scale_b strides must be (", 1, ", ", N, "); got: ", scale_a.strides()
|
||||
);
|
||||
|
||||
auto scaling_choice_a = ScalingType::BlockWise1x128;
|
||||
auto scaling_choice_b = ScalingType::BlockWise1x128;
|
||||
|
||||
// Check scale strides (including stride=1 small cases)
|
||||
_check_deepseek_scale_stride(scale_a, mat_a, scaling_choice_a);
|
||||
_check_deepseek_scale_stride(scale_b.t(), mat_b.t(), scaling_choice_b);
|
||||
|
||||
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
|
||||
|
||||
return out;
|
||||
@ -861,24 +890,65 @@ _scaled_block128x128_block1x128(
|
||||
Tensor& out) {
|
||||
#ifndef USE_ROCM
|
||||
// Restrictions:
|
||||
// A, B are FP8, scales are fp32, shape K//128
|
||||
// CUDA: Only Hopper GPUs
|
||||
_check_deepseek_support();
|
||||
|
||||
TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
|
||||
mat_a.scalar_type(), mat_b.scalar_type());
|
||||
TORCH_CHECK_VALUE(scale_a.sizes()[0] == ceil_div<int64_t>(mat_a.sizes()[0], 128) && scale_a.sizes()[1] == ceil_div<int64_t>(mat_a.sizes()[1], 128) && scale_a.scalar_type() == kFloat,
|
||||
"scale_a must have shape ", ceil_div<int64_t>(mat_a.sizes()[0], 128), " x ", ceil_div<int64_t>(mat_a.sizes()[1], 128), " Float elements, got ", scale_a.sizes())
|
||||
TORCH_CHECK_VALUE(scale_b.sizes()[0] == ceil_div<int64_t>(mat_b.sizes()[0], 128) && scale_b.sizes()[1] == mat_b.sizes()[1] && scale_b.scalar_type() == kFloat,
|
||||
"scale_b must have shape ", ceil_div<int64_t>(mat_b.sizes()[0], 128), " x ", mat_b.sizes()[1], " Float elements, got ", scale_b.sizes())
|
||||
// A: [M, K], B: [K, N] are FP8, scales are fp32
|
||||
// As: [round_up(K // 128, 4), M // 128], stride: [M // 128, 1]
|
||||
// Bs: [N x K // 128], stride: [1, N]
|
||||
TORCH_CHECK_VALUE(
|
||||
isFloat8Type(mat_a.scalar_type()) &&
|
||||
isFloat8Type(mat_b.scalar_type()),
|
||||
"mat_a and mat_b must be fp8 types, got: ", mat_a.scalar_type(), mat_b.scalar_type()
|
||||
);
|
||||
|
||||
const int64_t M = mat_a.sizes()[0];
|
||||
const int64_t K = mat_a.sizes()[1];
|
||||
const int64_t N = mat_b.sizes()[1];
|
||||
|
||||
// scale_a shape
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_a.size(0) == round_up<int64_t>(ceil_div<int64_t>(K, 128), 4) &&
|
||||
scale_a.size(1) == ceil_div<int64_t>(M, 128) &&
|
||||
scale_a.scalar_type() == kFloat,
|
||||
"scale_a must have shape ", round_up<int64_t>(ceil_div<int64_t>(K, 128), 4), " x ",
|
||||
ceil_div<int64_t>(M, 128), " Float elements, got ", scale_a.sizes()
|
||||
);
|
||||
// scale_a stride
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_a.stride(0) == 1 &&
|
||||
(
|
||||
scale_a.stride(1) == round_up<int64_t>(ceil_div<int64_t>(K, 128), 4) ||
|
||||
(
|
||||
scale_a.size(1) == 1 &&
|
||||
scale_a.stride(1) == 1
|
||||
)
|
||||
),
|
||||
"scale_a must have strides (1, ", round_up<int64_t>(ceil_div<int64_t>(K, 128), 4), "); got ", scale_b.strides()
|
||||
);
|
||||
|
||||
// scale_b shape
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_b.size(0) == N &&
|
||||
scale_b.size(1) == ceil_div<int64_t>(K, 128) &&
|
||||
scale_b.scalar_type() == kFloat,
|
||||
"scale_b must have shape ", N, " x ", ceil_div<int64_t>(K, 128), " Float elements, got ", scale_b.sizes()
|
||||
);
|
||||
// scale_b stride
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_b.stride(0) == 1 &&
|
||||
(
|
||||
scale_b.stride(1) == N ||
|
||||
(
|
||||
scale_b.size(1) == 1 &&
|
||||
scale_b.stride(1) == 1
|
||||
)
|
||||
),
|
||||
"scale_b must have strides (1, ", N, "); got ", scale_b.strides()
|
||||
);
|
||||
|
||||
auto scaling_choice_a = ScalingType::BlockWise128x128;
|
||||
auto scaling_choice_b = ScalingType::BlockWise1x128;
|
||||
|
||||
// Check scale strides (including stride=1 small cases)
|
||||
_check_deepseek_scale_stride(scale_a, mat_a, scaling_choice_a);
|
||||
_check_deepseek_scale_stride(scale_b.t(), mat_b.t(), scaling_choice_b);
|
||||
|
||||
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
|
||||
|
||||
return out;
|
||||
@ -900,24 +970,62 @@ _scaled_block1x128_block128x128(
|
||||
Tensor& out) {
|
||||
#ifndef USE_ROCM
|
||||
// Restrictions:
|
||||
// A, B are FP8, scales are fp32, A: shape K//128, B: K//128, N//128
|
||||
// CUDA: Only Hopper GPUs
|
||||
_check_deepseek_support();
|
||||
// A: [M, K], B: [K, N] are FP8, scales are fp32
|
||||
// As: [M x K // 128], stride: [1, M]
|
||||
// Bs: [round_up(K // 128, 4) x N // 128], stride: [1, N // 128]
|
||||
TORCH_CHECK_VALUE(
|
||||
isFloat8Type(mat_a.scalar_type()) &&
|
||||
isFloat8Type(mat_b.scalar_type()),
|
||||
"mat_a and mat_b must be fp8 types, got: ", mat_a.scalar_type(), mat_b.scalar_type()
|
||||
);
|
||||
|
||||
TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
|
||||
mat_a.scalar_type(), mat_b.scalar_type());
|
||||
TORCH_CHECK_VALUE(scale_a.sizes()[0] == mat_a.sizes()[0] && scale_a.sizes()[1] == mat_a.sizes()[1] / 128 && scale_a.scalar_type() == kFloat,
|
||||
"scale_a must have shape ", mat_a.sizes()[0], " x ", mat_a.sizes()[1] / 128, " Float elements, got ", scale_a.sizes())
|
||||
TORCH_CHECK_VALUE(scale_b.sizes()[0] == mat_b.sizes()[0] / 128 && scale_b.sizes()[1] == mat_b.sizes()[1] / 128 && scale_b.scalar_type() == kFloat,
|
||||
"scale_b must have shape ", mat_b.sizes()[0] / 128, " x ", mat_b.sizes()[1] / 128, " Float elements, got ", scale_b.sizes())
|
||||
int64_t M = mat_a.size(0);
|
||||
int64_t K = mat_a.size(1);
|
||||
int64_t N = mat_b.size(1);
|
||||
|
||||
// scale_a shape
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_a.size(0) == M &&
|
||||
scale_a.size(1) == ceil_div<int64_t>(K, 128) &&
|
||||
scale_a.scalar_type() == kFloat,
|
||||
"scale_a must have shape ", M, " x ", ceil_div<int64_t>(K, 128), " Float elements, got ", scale_a.sizes()
|
||||
);
|
||||
// scale_a stride
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_a.stride(0) == 1 &&
|
||||
(
|
||||
scale_a.stride(1) == M ||
|
||||
(
|
||||
scale_a.size(1) == 1 &&
|
||||
scale_a.stride(1) == 1
|
||||
)
|
||||
),
|
||||
"scale_a must have strides (1, ", M, "); got ", scale_b.strides()
|
||||
);
|
||||
// scale_b shape
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_b.size(0) == round_up<int64_t>(ceil_div<int64_t>(K, 128), 4) &&
|
||||
scale_b.size(1) == ceil_div<int64_t>(N, 128) &&
|
||||
scale_b.scalar_type() == kFloat,
|
||||
"scale_b must have shape ", round_up<int64_t>(ceil_div<int64_t>(K, 128), 4), " x ", ceil_div<int64_t>(N, 128), " Float elements, got ", scale_b.sizes()
|
||||
);
|
||||
// scale_b stride
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_b.stride(0) == 1 &&
|
||||
(
|
||||
scale_b.stride(1) == round_up<int64_t>(ceil_div<int64_t>(K, 128), 4) ||
|
||||
(
|
||||
scale_b.size(1) == 1 &&
|
||||
scale_b.stride(1) == 1
|
||||
)
|
||||
),
|
||||
"scale_b must have strides (1, ", round_up<int64_t>(ceil_div<int64_t>(K, 128), 4), "); got ", scale_b.strides()
|
||||
);
|
||||
|
||||
auto scaling_choice_a = ScalingType::BlockWise1x128;
|
||||
auto scaling_choice_b = ScalingType::BlockWise128x128;
|
||||
|
||||
// Check scale strides (including stride=1 small cases)
|
||||
_check_deepseek_scale_stride(scale_a, mat_a, scaling_choice_a);
|
||||
_check_deepseek_scale_stride(scale_b.t(), mat_b.t(), scaling_choice_b);
|
||||
|
||||
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
|
||||
|
||||
return out;
|
||||
@ -997,26 +1105,47 @@ _scaled_mxfp4_mxfp4(
|
||||
const std::optional<Tensor>& bias,
|
||||
const c10::ScalarType out_dtype,
|
||||
Tensor& out) {
|
||||
#ifndef USE_ROCM
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "MXFP4 scaling supported on ROCM only");
|
||||
#endif
|
||||
#if defined(_WIN32) || (!defined(USE_ROCM) && !defined(USE_FBGEMM_GENAI))
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "MXFP4 scaling supported on ROCM and CUDA+FBGEMM_GENAI only");
|
||||
#else
|
||||
// Restrictions:
|
||||
// A, B are FP4, scales are e8m0, A: shape K//32, B: K, N//32
|
||||
TORCH_CHECK_VALUE(mat_a.scalar_type() == at::kFloat4_e2m1fn_x2 && mat_b.scalar_type() == at::kFloat4_e2m1fn_x2, "mat_a and mat_b must be fp4 types, got: ",
|
||||
mat_a.scalar_type(), mat_b.scalar_type());
|
||||
|
||||
auto scale_a_elems = ceil_div<int64_t>(2 * mat_a.size(0), 32) * mat_a.size(1);
|
||||
auto scale_b_elems = ceil_div<int64_t>(2 * mat_b.size(1), 32) * mat_b.size(0);
|
||||
// Packed FP4 format means actual-K = 2 * reported-K -- adjust
|
||||
auto K_multiplier = 2;
|
||||
#ifdef USE_ROCM
|
||||
// AMD
|
||||
auto scale_a_elems = ceil_div<int64_t>(K_multiplier * mat_a.size(0), 32) * mat_a.size(1);
|
||||
auto scale_b_elems = ceil_div<int64_t>(K_multiplier * mat_b.size(1), 32) * mat_b.size(0);
|
||||
#else
|
||||
// NVIDIA
|
||||
auto scale_a_elems = round_up<int64_t>(mat_a.size(0), 128) * round_up<int64_t>(ceil_div<int64_t>(K_multiplier * mat_a.size(1), 32), 4);
|
||||
auto scale_b_elems = round_up<int64_t>(mat_b.size(1), 128) * round_up<int64_t>(ceil_div<int64_t>(K_multiplier * mat_b.size(0), 32), 4);
|
||||
#endif
|
||||
TORCH_CHECK_VALUE(scale_a_elems == scale_a.numel(),
|
||||
"For Blockwise scaling scale_a should have ", scale_a_elems, " elements, got: ", scale_a.numel());
|
||||
TORCH_CHECK_VALUE(scale_b_elems == scale_b.numel(),
|
||||
"For Blockwise scaling scale_b should have ", scale_b_elems, " elements, got: ", scale_b.numel());
|
||||
|
||||
#ifdef USE_ROCM
|
||||
// AMD
|
||||
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::NO_SWIZZLE, "scale_a must not be swizzled (NO_SWIZZLE format)");
|
||||
TORCH_CHECK_VALUE(swizzle_b == SwizzleType::NO_SWIZZLE, "scale_b must not be swizzled (NO_SWIZZLE format)");
|
||||
#else
|
||||
// NVIDIA
|
||||
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::SWIZZLE_32_4_4, "scale_a must be swizzled to SWIZZLE_32_4_4 format");
|
||||
TORCH_CHECK_VALUE(swizzle_b == SwizzleType::SWIZZLE_32_4_4, "scale_b must be swizzled to SWIZZLE_32_4_4 format");
|
||||
#endif
|
||||
|
||||
TORCH_CHECK_VALUE(scale_a.is_contiguous() && scale_b.is_contiguous(),
|
||||
"For Blockwise scaling both scales should be contiguous");
|
||||
|
||||
TORCH_CHECK_VALUE(out.scalar_type() == out_dtype, "expected out.scalar_type() to be ", out_dtype, ", but got ", out_dtype);
|
||||
|
||||
#ifdef USE_ROCM
|
||||
// AMD
|
||||
auto scaling_choice_a = ScalingType::BlockWise1x32;
|
||||
auto scaling_choice_b = ScalingType::BlockWise1x32;
|
||||
|
||||
@ -1031,11 +1160,30 @@ _scaled_mxfp4_mxfp4(
|
||||
TORCH_CHECK_VALUE(out.scalar_type() == ScalarType::BFloat16 ||
|
||||
out.scalar_type() == ScalarType::Half,
|
||||
"Block-wise scaling only supports BFloat16 or Half output types");
|
||||
#else
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "Block-wise scaling for Float8_e8m0fnu requires ROCm 7.0 or later");
|
||||
#endif
|
||||
|
||||
return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
|
||||
#else
|
||||
// NVIDIA
|
||||
// NOTE(slayton58): fbgemm_gpu::f4f4bf16 does *not* allow passing an output tensor,
|
||||
// but we have one we need to use. Two clear options are to copy into
|
||||
// our output (slow), or use a move-assignment-operator (faster).
|
||||
// However, the compiler can complain about the explicit move preventing
|
||||
// copy elision because the return from f4f4bf16 is a temporary object.
|
||||
// So we don't explicitly move, and trust the compiler here...
|
||||
// In the longer term this should be fixed on the FBGemm side.
|
||||
out = fbgemm_gpu::f4f4bf16(
|
||||
mat_a,
|
||||
mat_b.transpose(-2, -1),
|
||||
scale_a,
|
||||
scale_b,
|
||||
std::nullopt, /* global_scale */
|
||||
true /* use_mx */
|
||||
);
|
||||
|
||||
return out;
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
Tensor&
|
||||
@ -1160,17 +1308,20 @@ _scaled_mm_cuda_v2_out(
|
||||
mat_a.size(0), "x", mat_a.size(1), " and ", mat_b.size(0), "x", mat_b.size(1), ")");
|
||||
}
|
||||
|
||||
// Handle fp4 packed-K dimension
|
||||
int K_multiplier = (mat_a.scalar_type() == ScalarType::Float4_e2m1fn_x2) ? 2 : 1;
|
||||
|
||||
TORCH_CHECK_VALUE(!bias || bias->numel() == mat_b.sizes()[1], "Bias must be size ", mat_b.sizes()[1],
|
||||
" but got ", bias->numel());
|
||||
TORCH_CHECK_VALUE(
|
||||
mat_a.sizes()[1] % 16 == 0,
|
||||
K_multiplier * mat_a.sizes()[1] % 16 == 0,
|
||||
"Expected trailing dimension of mat1 to be divisible by 16 ",
|
||||
"but got mat1 shape: (",
|
||||
mat_a.sizes()[0],
|
||||
"x",
|
||||
mat_a.sizes()[1],
|
||||
K_multiplier * mat_a.sizes()[1],
|
||||
").");
|
||||
TORCH_CHECK_VALUE(mat_b.sizes()[0] % 16 == 0 && mat_b.sizes()[1] % 16 == 0, "mat2 shape (", mat_b.sizes()[0], "x",
|
||||
TORCH_CHECK_VALUE(K_multiplier * mat_b.sizes()[0] % 16 == 0 && mat_b.sizes()[1] % 16 == 0, "mat2 shape (", mat_b.sizes()[0], "x",
|
||||
mat_b.sizes()[1], ") must be divisible by 16");
|
||||
|
||||
// TODO(slayton): Existing checks, not sure if they should really be here.
|
||||
|
||||
@ -157,10 +157,10 @@ bool onednn_strides_check(const Tensor& src) {
|
||||
return true;
|
||||
|
||||
dnnl_dims_t blocks = {0};
|
||||
int perm[DNNL_MAX_NDIMS] = {0};
|
||||
std::array<int, DNNL_MAX_NDIMS> perm = {0};
|
||||
for (int d = 0; d < md_ndims; ++d) {
|
||||
// no strides check needed for empty tensor
|
||||
if (md_padded_dims[d] == nullptr)
|
||||
if ((*md_padded_dims)[d] == 0)
|
||||
return true;
|
||||
|
||||
// no strides verification for runtime dims
|
||||
@ -178,14 +178,15 @@ bool onednn_strides_check(const Tensor& src) {
|
||||
|
||||
// A custom comparator to yield linear order on perm
|
||||
auto idx_sorter = [&](const int a, const int b) -> bool {
|
||||
if (strides[a] == strides[b] && md_padded_dims[a] == md_padded_dims[b])
|
||||
if (strides[a] == strides[b] &&
|
||||
(*md_padded_dims)[a] == (*md_padded_dims)[b])
|
||||
return a < b;
|
||||
else if (strides[a] == strides[b])
|
||||
return md_padded_dims[a] < md_padded_dims[b];
|
||||
return (*md_padded_dims)[a] < (*md_padded_dims)[b];
|
||||
else
|
||||
return strides[a] < strides[b];
|
||||
};
|
||||
std::sort(perm, perm + md_ndims, idx_sorter);
|
||||
std::sort(perm.begin(), perm.begin() + md_ndims, idx_sorter);
|
||||
|
||||
auto min_stride = block_size;
|
||||
for (int idx = 0; idx < md_ndims; ++idx) {
|
||||
@ -199,9 +200,10 @@ bool onednn_strides_check(const Tensor& src) {
|
||||
return false;
|
||||
|
||||
// update min_stride for next iteration
|
||||
const auto padded_dim = *md_padded_dims[d];
|
||||
const auto padded_dim = (*md_padded_dims)[d];
|
||||
min_stride = block_size * strides[d] * (padded_dim / blocks[d]);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
@ -370,7 +370,7 @@ static void nllnd_loss_backward_impl(Tensor& grad_input_arg,
|
||||
onValue:-1.0f
|
||||
offValue:0.0f
|
||||
name:nil];
|
||||
oneHotTensor = castMPSTensor(mpsGraph, oneHotTensor, inputTensor.dataType);
|
||||
oneHotTensor = castMPSTensor(mpsGraph, oneHotTensor, [inputTensor dataType]);
|
||||
if (isWeightsArrayValid) {
|
||||
oneHotTensor = [mpsGraph multiplicationWithPrimaryTensor:oneHotTensor
|
||||
secondaryTensor:weightTensor
|
||||
@ -705,6 +705,7 @@ static void smooth_l1_loss_template(const Tensor& input,
|
||||
TORCH_CHECK(beta >= 0, "smooth_l1_loss does not support negative values for beta.");
|
||||
TORCH_CHECK(input.is_mps());
|
||||
TORCH_CHECK(target.is_mps());
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(input.scalar_type() != kLong, "MPS doesn't know how to do square_i64");
|
||||
if ((input.numel() == 0) || (target.numel() == 0)) {
|
||||
reduction == Reduction::Mean ? output.fill_(std::numeric_limits<float>::quiet_NaN()) : output.zero_();
|
||||
return;
|
||||
@ -771,7 +772,7 @@ static void smooth_l1_loss_backward_impl(const Tensor& grad_output,
|
||||
MPSGraphTensor* targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, target);
|
||||
MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
|
||||
|
||||
MPSGraphTensor* betaTensor = [mpsGraph constantWithScalar:beta dataType:MPSDataTypeFloat32];
|
||||
MPSGraphTensor* betaTensor = [mpsGraph constantWithScalar:beta dataType:[inputTensor dataType]];
|
||||
// xn - yn
|
||||
MPSGraphTensor* diffTensor = [mpsGraph subtractionWithPrimaryTensor:inputTensor
|
||||
secondaryTensor:targetTensor
|
||||
@ -797,7 +798,8 @@ static void smooth_l1_loss_backward_impl(const Tensor& grad_output,
|
||||
name:@"lossTensor"];
|
||||
MPSGraphTensor* outputTensor = lossTensor;
|
||||
if (reduction == Reduction::Mean) {
|
||||
MPSGraphTensor* numelTensor = [mpsGraph constantWithScalar:(double)input.numel() dataType:MPSDataTypeFloat32];
|
||||
MPSGraphTensor* numelTensor = [mpsGraph constantWithScalar:(double)input.numel()
|
||||
dataType:[lossTensor dataType]];
|
||||
outputTensor = [mpsGraph divisionWithPrimaryTensor:lossTensor secondaryTensor:numelTensor name:nil];
|
||||
}
|
||||
MPSGraphTensor* gradInputTensor = [mpsGraph multiplicationWithPrimaryTensor:outputTensor
|
||||
|
||||
63
aten/src/ATen/xpu/PeerToPeerAccess.cpp
Normal file
63
aten/src/ATen/xpu/PeerToPeerAccess.cpp
Normal file
@ -0,0 +1,63 @@
|
||||
#include <ATen/xpu/PeerToPeerAccess.h>
|
||||
#include <ATen/xpu/XPUContext.h>
|
||||
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/irange.h>
|
||||
#include <c10/xpu/XPUCachingAllocator.h>
|
||||
|
||||
namespace at::xpu {
|
||||
|
||||
// p2pAccessEnabled_ is a flattened 2D matrix of size [num_devices x
|
||||
// num_devices].
|
||||
// Each element represents whether device[i] can access device[j]:
|
||||
// 1 -> access allowed
|
||||
// 0 -> access not allowed
|
||||
// -1 -> unknown (not yet queried)
|
||||
static std::vector<int8_t> p2pAccessEnabled_;
|
||||
|
||||
namespace detail {
|
||||
|
||||
// Initializes the peer-to-peer (P2P) access capability cache.
|
||||
void init_p2p_access_cache(c10::DeviceIndex num_devices) {
|
||||
// By default, each device can always access itself (diagonal entries = 1).
|
||||
// For simplicity, all entries are initialized to -1 except the diagonal.
|
||||
static bool once [[maybe_unused]] = [num_devices]() {
|
||||
p2pAccessEnabled_.clear();
|
||||
p2pAccessEnabled_.resize(num_devices * num_devices, -1);
|
||||
|
||||
for (const auto i : c10::irange(num_devices)) {
|
||||
p2pAccessEnabled_[i * num_devices + i] = 1;
|
||||
}
|
||||
return true;
|
||||
}();
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
bool get_p2p_access(c10::DeviceIndex dev, c10::DeviceIndex dev_to_access) {
|
||||
at::globalContext().lazyInitDevice(c10::DeviceType::XPU);
|
||||
|
||||
check_device_index(dev);
|
||||
check_device_index(dev_to_access);
|
||||
|
||||
auto& cache =
|
||||
p2pAccessEnabled_[dev * c10::xpu::device_count() + dev_to_access];
|
||||
|
||||
if (cache != -1) {
|
||||
return static_cast<bool>(cache);
|
||||
}
|
||||
|
||||
// Query the hardware to determine if P2P access is supported
|
||||
cache = static_cast<int8_t>(
|
||||
c10::xpu::get_raw_device(dev).ext_oneapi_can_access_peer(
|
||||
c10::xpu::get_raw_device(dev_to_access),
|
||||
sycl::ext::oneapi::peer_access::access_supported));
|
||||
|
||||
if (cache) {
|
||||
XPUCachingAllocator::enablePeerAccess(dev, dev_to_access);
|
||||
}
|
||||
|
||||
return static_cast<bool>(cache);
|
||||
}
|
||||
|
||||
} // namespace at::xpu
|
||||
15
aten/src/ATen/xpu/PeerToPeerAccess.h
Normal file
15
aten/src/ATen/xpu/PeerToPeerAccess.h
Normal file
@ -0,0 +1,15 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/core/Device.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
|
||||
namespace at::xpu {
|
||||
namespace detail {
|
||||
void init_p2p_access_cache(c10::DeviceIndex num_devices);
|
||||
} // namespace detail
|
||||
|
||||
TORCH_XPU_API bool get_p2p_access(
|
||||
c10::DeviceIndex dev,
|
||||
c10::DeviceIndex dev_to_access);
|
||||
|
||||
} // namespace at::xpu
|
||||
@ -1,3 +1,4 @@
|
||||
#include <ATen/xpu/PeerToPeerAccess.h>
|
||||
#include <ATen/xpu/PinnedMemoryAllocator.h>
|
||||
#include <ATen/xpu/XPUContext.h>
|
||||
#include <ATen/xpu/XPUDevice.h>
|
||||
@ -12,6 +13,7 @@ void XPUHooks::init() const {
|
||||
C10_LOG_API_USAGE_ONCE("aten.init.xpu");
|
||||
const auto device_count = c10::xpu::device_count_ensure_non_zero();
|
||||
c10::xpu::XPUCachingAllocator::init(device_count);
|
||||
at::xpu::detail::init_p2p_access_cache(device_count);
|
||||
}
|
||||
|
||||
bool XPUHooks::hasXPU() const {
|
||||
|
||||
@ -11,6 +11,11 @@ def remove_cuda(config_list):
|
||||
return [config for config in config_list if cuda_config not in config]
|
||||
|
||||
|
||||
def remove_cpu(config_list):
|
||||
cpu_config = {"device": "cpu"}
|
||||
return [config for config in config_list if cpu_config not in config]
|
||||
|
||||
|
||||
# Configs for conv-1d ops
|
||||
conv_1d_configs_short = op_bench.config_list(
|
||||
attr_names=["IC", "OC", "kernel", "stride", "N", "L"],
|
||||
@ -127,6 +132,18 @@ conv_3d_configs_short = op_bench.config_list(
|
||||
},
|
||||
tags=["short"],
|
||||
)
|
||||
conv_3d_configs_long = op_bench.cross_product_configs(
|
||||
IC=[16, 32],
|
||||
OC=[32, 64],
|
||||
kernel=[3, 5],
|
||||
stride=[1, 2],
|
||||
N=[1],
|
||||
D=[128],
|
||||
H=[128],
|
||||
W=[128],
|
||||
device=["cpu", "cuda"],
|
||||
tags=["long"],
|
||||
)
|
||||
|
||||
linear_configs_short = op_bench.config_list(
|
||||
attr_names=["N", "IN", "OUT"],
|
||||
|
||||
@ -38,6 +38,10 @@ class ConvTranspose1dBenchmark(op_bench.TorchBenchmarkBase):
|
||||
op_bench.generate_pt_test(
|
||||
configs.conv_1d_configs_short + configs.conv_1d_configs_long, Conv1dBenchmark
|
||||
)
|
||||
op_bench.generate_pt_gradient_test(
|
||||
configs.remove_cpu(configs.conv_1d_configs_short + configs.conv_1d_configs_long),
|
||||
Conv1dBenchmark,
|
||||
)
|
||||
|
||||
|
||||
if not torch.backends.mkldnn.is_acl_available():
|
||||
@ -103,6 +107,20 @@ op_bench.generate_pt_test(
|
||||
configs.conv_2d_pw_configs_short + configs.conv_2d_pw_configs_long,
|
||||
Conv2dPointwiseBenchmark,
|
||||
)
|
||||
op_bench.generate_pt_gradient_test(
|
||||
configs.remove_cpu(configs.conv_2d_configs_short + configs.conv_2d_configs_long),
|
||||
Conv2dBenchmark,
|
||||
)
|
||||
op_bench.generate_pt_gradient_test(
|
||||
configs.remove_cpu(configs.conv_2d_configs_short + configs.conv_2d_configs_long),
|
||||
ConvTranspose2dBenchmark,
|
||||
)
|
||||
op_bench.generate_pt_gradient_test(
|
||||
configs.remove_cpu(
|
||||
configs.conv_2d_pw_configs_short + configs.conv_2d_pw_configs_long
|
||||
),
|
||||
Conv2dPointwiseBenchmark,
|
||||
)
|
||||
|
||||
|
||||
"""
|
||||
@ -134,6 +152,12 @@ class ConvTranspose3dBenchmark(op_bench.TorchBenchmarkBase):
|
||||
|
||||
op_bench.generate_pt_test(configs.conv_3d_configs_short, Conv3dBenchmark)
|
||||
op_bench.generate_pt_test(configs.conv_3d_configs_short, ConvTranspose3dBenchmark)
|
||||
op_bench.generate_pt_gradient_test(
|
||||
configs.remove_cpu(configs.conv_3d_configs_long), Conv3dBenchmark
|
||||
)
|
||||
op_bench.generate_pt_gradient_test(
|
||||
configs.remove_cpu(configs.conv_3d_configs_long), ConvTranspose3dBenchmark
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -929,6 +929,7 @@ libtorch_python_core_sources = [
|
||||
"torch/csrc/dynamo/guards.cpp",
|
||||
"torch/csrc/dynamo/utils.cpp",
|
||||
"torch/csrc/dynamo/init.cpp",
|
||||
"torch/csrc/dynamo/stackref_bridge.c",
|
||||
"torch/csrc/functorch/init.cpp",
|
||||
"torch/csrc/fx/node.cpp",
|
||||
"torch/csrc/mps/Module.cpp",
|
||||
|
||||
@ -21,13 +21,20 @@ using stream_set = ska::flat_hash_set<xpu::XPUStream>;
|
||||
struct Block;
|
||||
typedef bool (*Comparison)(const Block*, const Block*);
|
||||
bool BlockComparatorSize(const Block* a, const Block* b);
|
||||
bool BlockComparatorAddress(const Block* a, const Block* b);
|
||||
|
||||
struct BlockPool {
|
||||
BlockPool(bool small) : blocks(BlockComparatorSize), is_small(small) {}
|
||||
BlockPool(bool small)
|
||||
: blocks(BlockComparatorSize),
|
||||
unmapped(BlockComparatorAddress),
|
||||
is_small(small) {}
|
||||
std::set<Block*, Comparison> blocks;
|
||||
std::set<Block*, Comparison> unmapped;
|
||||
const bool is_small;
|
||||
};
|
||||
|
||||
struct ExpandableSegment;
|
||||
|
||||
struct Block {
|
||||
DeviceIndex device;
|
||||
sycl::queue* queue{nullptr}; // underlying queue of the allocation stream
|
||||
@ -37,9 +44,11 @@ struct Block {
|
||||
BlockPool* pool{nullptr}; // owning memory pool
|
||||
void* ptr{nullptr}; // memory address
|
||||
bool allocated{false}; // in-use flag
|
||||
bool mapped{true}; // True if this Block is backed by physical pages
|
||||
Block* prev{nullptr}; // prev block if split from a larger allocation
|
||||
Block* next{nullptr}; // next block if split from a larger allocation
|
||||
int event_count{0}; // number of outstanding XPU events
|
||||
ExpandableSegment* expandable_segment{nullptr}; // owning expandable segment
|
||||
|
||||
Block(
|
||||
DeviceIndex device,
|
||||
@ -66,6 +75,20 @@ struct Block {
|
||||
bool is_split() const {
|
||||
return (prev != nullptr) || (next != nullptr);
|
||||
}
|
||||
|
||||
// Inserts this block between two existing blocks with [before, this, after].
|
||||
void splice(Block* before, Block* after) {
|
||||
if (before) {
|
||||
TORCH_INTERNAL_ASSERT(before->next == after);
|
||||
before->next = this;
|
||||
}
|
||||
prev = before;
|
||||
if (after) {
|
||||
TORCH_INTERNAL_ASSERT(after->prev == before);
|
||||
after->prev = this;
|
||||
}
|
||||
next = after;
|
||||
}
|
||||
};
|
||||
|
||||
bool BlockComparatorSize(const Block* a, const Block* b) {
|
||||
@ -80,6 +103,221 @@ bool BlockComparatorSize(const Block* a, const Block* b) {
|
||||
reinterpret_cast<uintptr_t>(b->ptr);
|
||||
}
|
||||
|
||||
bool BlockComparatorAddress(const Block* a, const Block* b) {
|
||||
if (a->queue != b->queue) {
|
||||
return reinterpret_cast<uintptr_t>(a->queue) <
|
||||
reinterpret_cast<uintptr_t>(b->queue);
|
||||
}
|
||||
return reinterpret_cast<uintptr_t>(a->ptr) <
|
||||
reinterpret_cast<uintptr_t>(b->ptr);
|
||||
}
|
||||
|
||||
// Represents a contiguous virtual memory segment mapped for allocation.
|
||||
struct SegmentRange {
|
||||
SegmentRange(void* addr, size_t bytes)
|
||||
: ptr(static_cast<char*>(addr)), size(bytes) {}
|
||||
char* ptr; // Starting address of the mapped range.
|
||||
size_t size; // Size in bytes of the mapped range.
|
||||
};
|
||||
|
||||
struct ExpandableSegment {
|
||||
ExpandableSegment(
|
||||
c10::DeviceIndex device,
|
||||
std::optional<sycl::queue*> queue,
|
||||
size_t segment_size,
|
||||
std::vector<c10::DeviceIndex> peers)
|
||||
: device_(device),
|
||||
queue_(queue),
|
||||
// 2MB for small pool, 20MB for large pool
|
||||
segment_size_(segment_size),
|
||||
peers_(std::move(peers)) {
|
||||
const auto device_total =
|
||||
c10::xpu::get_raw_device(device)
|
||||
.get_info<sycl::info::device::global_mem_size>();
|
||||
// The extra 1/8 allows flexibility for remapping or moving pages within the
|
||||
// segment when unmapping earlier regions.
|
||||
constexpr float kVirtualMemOversubscriptFactor = 1.125f; // 1 + 1/8
|
||||
max_handles_ = numSegments(device_total * kVirtualMemOversubscriptFactor);
|
||||
ptr_ = sycl::ext::oneapi::experimental::reserve_virtual_mem(
|
||||
segment_size_ * max_handles_, xpu::get_device_context());
|
||||
}
|
||||
|
||||
C10_DISABLE_COPY_AND_ASSIGN(ExpandableSegment);
|
||||
ExpandableSegment(ExpandableSegment&&) = delete;
|
||||
ExpandableSegment& operator=(ExpandableSegment&&) = delete;
|
||||
|
||||
// Maps a virtual memory range to physical memory.
|
||||
SegmentRange map(SegmentRange range) {
|
||||
auto begin = segmentLeft(range.ptr);
|
||||
auto end = segmentRight(range.ptr + range.size);
|
||||
TORCH_INTERNAL_ASSERT(ptr() + begin * segment_size_ == range.ptr);
|
||||
if (begin == end) {
|
||||
return rangeFromHandles(begin, end);
|
||||
}
|
||||
|
||||
// Ensure handles_ vector is large enough to hold all segments.
|
||||
if (end > handles_.size()) {
|
||||
handles_.resize(end, std::nullopt);
|
||||
}
|
||||
|
||||
// Allocate and map physical memory for each segment.
|
||||
for (const auto i : c10::irange(begin, end)) {
|
||||
TORCH_INTERNAL_ASSERT(!handles_.at(i));
|
||||
try {
|
||||
// Allocate physical memory for each segment. Construct the physical_mem
|
||||
// in-place to avoid copies.
|
||||
handles_.at(i).emplace(
|
||||
xpu::get_raw_device(device_),
|
||||
xpu::get_device_context(),
|
||||
segment_size_);
|
||||
// Map the allocated physical memory into the virtual address space.
|
||||
handles_.at(i).value().map(
|
||||
ptr_ + i * segment_size_,
|
||||
segment_size_,
|
||||
sycl::ext::oneapi::experimental::address_access_mode::read_write);
|
||||
} catch (const sycl::exception& e) {
|
||||
// Allocation failure: typically sycl::errc::memory_allocation.
|
||||
// Mapping failure: typically sycl::errc::runtime (e.g., OOM due to
|
||||
// over-subscription).
|
||||
// Note: constructing physical_mem may over-subscribe device memory but
|
||||
// not immediately trigger OOM. The actual OOM can occur during map().
|
||||
// Roll back all segments allocated or mapped in this operation.
|
||||
handles_.at(i) = std::nullopt;
|
||||
for (const auto j : c10::irange(begin, i)) {
|
||||
sycl::ext::oneapi::experimental::unmap(
|
||||
reinterpret_cast<void*>(ptr_ + segment_size_ * j),
|
||||
segment_size_,
|
||||
xpu::get_device_context());
|
||||
handles_.at(j) = std::nullopt;
|
||||
}
|
||||
trimHandles();
|
||||
return rangeFromHandles(begin, begin);
|
||||
}
|
||||
}
|
||||
return rangeFromHandles(begin, end);
|
||||
}
|
||||
|
||||
// Unmap a virtual memory range from physical memory.
|
||||
SegmentRange unmap(SegmentRange range) {
|
||||
auto begin = segmentRight(range.ptr);
|
||||
auto end = segmentLeft(range.ptr + range.size);
|
||||
if (begin >= end) {
|
||||
return SegmentRange{range.ptr, 0};
|
||||
}
|
||||
unmapHandles(begin, end);
|
||||
return rangeFromHandles(begin, end);
|
||||
}
|
||||
|
||||
// Returns the base pointer of the virtual memory segment.
|
||||
char* ptr() const {
|
||||
// NOLINTNEXTLINE(performance-no-int-to-ptr)
|
||||
return reinterpret_cast<char*>(ptr_);
|
||||
}
|
||||
|
||||
// Returns the total size of the virtual memory segment.
|
||||
size_t size() const {
|
||||
return max_handles_ * segment_size_;
|
||||
}
|
||||
|
||||
~ExpandableSegment() {
|
||||
forEachAllocatedRange(
|
||||
[&](size_t begin, size_t end) { unmapHandles(begin, end); });
|
||||
sycl::ext::oneapi::experimental::free_virtual_mem(
|
||||
ptr_, segment_size_ * max_handles_, xpu::get_device_context());
|
||||
}
|
||||
|
||||
private:
|
||||
// Unmaps the physical memory handles in the range [begin, end) from the
|
||||
// segment.
|
||||
void unmapHandles(size_t begin, size_t end) {
|
||||
// Currently, we don't support IPC shared memory with expandable segments.
|
||||
TORCH_INTERNAL_ASSERT(queue_);
|
||||
// As explained in Note [Safe to Free Blocks on BlockPool], additional
|
||||
// synchronization is unnecessary here because the memory is already safe to
|
||||
// release.
|
||||
for (const auto i : c10::irange(begin, end)) {
|
||||
// Note: physical_mem's destructor does NOT automatically unmap any mapped
|
||||
// ranges. Users must explicitly call unmap on all ranges before
|
||||
// destroying the physical_mem object.
|
||||
sycl::ext::oneapi::experimental::unmap(
|
||||
reinterpret_cast<void*>(ptr_ + segment_size_ * i),
|
||||
segment_size_,
|
||||
xpu::get_device_context());
|
||||
// Here physical_mem object is being destructed.
|
||||
handles_.at(i) = std::nullopt;
|
||||
}
|
||||
trimHandles();
|
||||
}
|
||||
|
||||
// Remove trailing unused handles from the end of handles_.
|
||||
void trimHandles() {
|
||||
while (!handles_.empty() && !handles_.back()) {
|
||||
handles_.pop_back();
|
||||
}
|
||||
}
|
||||
|
||||
// Iterates over all contiguous ranges of allocated segments in `handles_`,
|
||||
// and invokes the provided function `fn(start, end)` for each range.
|
||||
// Each range is defined as a half-open interval [start, end).
|
||||
void forEachAllocatedRange(const std::function<void(size_t, size_t)>& fn) {
|
||||
size_t start = 0;
|
||||
for (const auto i : c10::irange(handles_.size())) {
|
||||
if (handles_.at(i) && (i == 0 || !handles_.at(i - 1))) {
|
||||
start = i;
|
||||
}
|
||||
if (handles_.at(i) && (i + 1 == handles_.size() || !handles_.at(i + 1))) {
|
||||
fn(start, i + 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the number of full segments required to cover `size` bytes.
|
||||
// Rounds up to ensure partial segments are counted.
|
||||
size_t numSegments(size_t size) const {
|
||||
return (size + segment_size_ - 1) / segment_size_;
|
||||
}
|
||||
|
||||
// Returns the index of the segment that contains the pointer `p`,
|
||||
// relative to the base pointer `ptr_`. This is the *inclusive* lower bound
|
||||
// of the segment that includes `p`.
|
||||
size_t segmentLeft(char* p) const {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(p >= ptr() && p < ptr() + size());
|
||||
size_t offset = p - ptr();
|
||||
return offset / segment_size_;
|
||||
}
|
||||
|
||||
// Returns the index of the segment just *past* the one containing pointer
|
||||
// `p`, relative to the base pointer `ptr_`. This is the *exclusive* upper
|
||||
// bound, useful for [begin, end) style ranges.
|
||||
// If `p` lies exactly on a segment boundary, this is equal to segmentLeft(p).
|
||||
// Otherwise, it rounds up and returns segmentLeft(p) + 1.
|
||||
size_t segmentRight(char* p) const {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(p >= ptr() && p < ptr() + size());
|
||||
size_t offset = p - ptr();
|
||||
return numSegments(offset);
|
||||
}
|
||||
|
||||
// Constructs a SegmentRange spanning indices [start, end).
|
||||
SegmentRange rangeFromHandles(size_t begin, size_t end) {
|
||||
return SegmentRange(
|
||||
ptr() + segment_size_ * begin, segment_size_ * (end - begin));
|
||||
}
|
||||
|
||||
c10::DeviceIndex device_{-1};
|
||||
std::optional<sycl::queue*> queue_;
|
||||
// Virtual memory address used for reservation.
|
||||
uintptr_t ptr_{0};
|
||||
// Size of each segment in bytes.
|
||||
size_t segment_size_{0};
|
||||
// Maximum number of segments that can be allocated in this segment.
|
||||
size_t max_handles_{0};
|
||||
// Physical memory handles for the segments.
|
||||
std::vector<std::optional<sycl::ext::oneapi::experimental::physical_mem>>
|
||||
handles_{};
|
||||
// Peer devices on which this memory could be accessible, reserved.
|
||||
std::vector<c10::DeviceIndex> peers_{};
|
||||
};
|
||||
|
||||
struct AllocParams {
|
||||
AllocParams(
|
||||
DeviceIndex device,
|
||||
@ -125,10 +363,12 @@ class DeviceCachingAllocator {
|
||||
DeviceIndex device_index;
|
||||
size_t allowed_memory_maximum = 0;
|
||||
bool set_fraction = false;
|
||||
std::vector<ExpandableSegment*> expandable_segments;
|
||||
std::vector<c10::DeviceIndex> devices_with_peer_access; // reserved
|
||||
|
||||
size_t try_merge_blocks(Block* dst, Block* src, BlockPool& pool) {
|
||||
if (!src || src->allocated || src->event_count > 0 ||
|
||||
!src->stream_uses.empty()) {
|
||||
!src->stream_uses.empty() || dst->mapped != src->mapped) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -147,7 +387,8 @@ class DeviceCachingAllocator {
|
||||
}
|
||||
const size_t subsumed_size = src->size;
|
||||
dst->size += subsumed_size;
|
||||
auto erased = pool.blocks.erase(src);
|
||||
auto erased =
|
||||
src->mapped ? pool.blocks.erase(src) : pool.unmapped.erase(src);
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(erased == 1);
|
||||
delete src;
|
||||
|
||||
@ -230,12 +471,175 @@ class DeviceCachingAllocator {
|
||||
}
|
||||
}
|
||||
|
||||
// Finds the first (lowest-address) block in any segment that has sufficient
|
||||
// contiguous free virtual address space to satisfy `size`. The available
|
||||
// space may span multiple adjacent blocks, which can include both free and
|
||||
// unmapped segments.
|
||||
Block* find_expandable_block(
|
||||
c10::DeviceIndex device,
|
||||
sycl::queue* queue,
|
||||
BlockPool* pool,
|
||||
size_t size) {
|
||||
Block key(device, queue, 0);
|
||||
|
||||
auto allocatable = [](Block* b) {
|
||||
return b && !b->allocated && b->event_count == 0 &&
|
||||
b->stream_uses.empty();
|
||||
};
|
||||
auto has_available_address_space = [&](Block* b) {
|
||||
size_t bytes = 0;
|
||||
while (bytes < size && allocatable(b)) {
|
||||
bytes += b->size;
|
||||
b = b->next;
|
||||
}
|
||||
return bytes >= size;
|
||||
};
|
||||
for (auto it = pool->unmapped.lower_bound(&key);
|
||||
it != pool->unmapped.end() && (*it)->queue == queue;
|
||||
++it) {
|
||||
Block* c = *it;
|
||||
// The unmapped block might have a free mapped block right before it.
|
||||
// By starting from the previous block, we can use both:
|
||||
// [Free Mapped Block] + [Unmapped Block] = More contiguous space
|
||||
if (allocatable(c->prev)) {
|
||||
c = c->prev;
|
||||
}
|
||||
if (has_available_address_space(c)) {
|
||||
return c;
|
||||
}
|
||||
}
|
||||
auto segment_size = pool->is_small ? kSmallBuffer : kLargeBuffer;
|
||||
expandable_segments.emplace_back(new ExpandableSegment(
|
||||
device, queue, segment_size, devices_with_peer_access));
|
||||
|
||||
ExpandableSegment* es = expandable_segments.back();
|
||||
Block* candidate = new Block(device, queue, es->size(), pool, es->ptr());
|
||||
candidate->mapped = false;
|
||||
candidate->expandable_segment = es;
|
||||
pool->unmapped.insert(candidate);
|
||||
return candidate;
|
||||
}
|
||||
|
||||
bool map_block(Block* to_map, size_t size) {
|
||||
TORCH_INTERNAL_ASSERT(!to_map->mapped && size <= to_map->size);
|
||||
auto mapped_range =
|
||||
to_map->expandable_segment->map(SegmentRange{to_map->ptr, size});
|
||||
// Failed to map the memory
|
||||
if (mapped_range.size == 0) {
|
||||
return false;
|
||||
}
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
mapped_range.ptr == to_map->ptr && mapped_range.size >= size);
|
||||
|
||||
BlockPool& pool = *to_map->pool;
|
||||
pool.unmapped.erase(to_map);
|
||||
to_map->mapped = true;
|
||||
|
||||
if (mapped_range.size < to_map->size) {
|
||||
// to_map -> remaining -> to_map->next(?)
|
||||
Block* remaining = new Block(
|
||||
to_map->device,
|
||||
to_map->queue,
|
||||
to_map->size - mapped_range.size,
|
||||
&pool,
|
||||
static_cast<char*>(to_map->ptr) + mapped_range.size);
|
||||
remaining->mapped = false;
|
||||
remaining->expandable_segment = to_map->expandable_segment;
|
||||
remaining->splice(to_map, to_map->next);
|
||||
pool.unmapped.insert(remaining);
|
||||
to_map->size = mapped_range.size;
|
||||
}
|
||||
|
||||
try_merge_blocks(to_map, to_map->prev, pool);
|
||||
try_merge_blocks(to_map, to_map->next, pool);
|
||||
|
||||
pool.blocks.insert(to_map);
|
||||
|
||||
StatTypes stat_types = get_stat_types_for_pool(*to_map->pool);
|
||||
for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
|
||||
stats.reserved_bytes[stat_type].increase(mapped_range.size);
|
||||
});
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
Block* try_allocate_expandable_block(
|
||||
c10::DeviceIndex device,
|
||||
sycl::queue* queue,
|
||||
BlockPool* pool,
|
||||
size_t size) {
|
||||
// Candidate points to the start of a chain of contiguous blocks with
|
||||
// sufficient virtual address space (>= size). The chain may consist of:
|
||||
// Case 1: [Unmapped Block] -> null
|
||||
// Case 2: [Unmapped Block] -> [Free Mapped Block]
|
||||
// Case 3: [Free Mapped Block] -> [Unmapped Block]
|
||||
Block* candidate = find_expandable_block(device, queue, pool, size);
|
||||
|
||||
// Map first block if unmapped (Case 1 & 2), use std::min to avoid
|
||||
// over-mapping.
|
||||
if (!candidate->mapped &&
|
||||
!map_block(candidate, std::min(candidate->size, size))) {
|
||||
return nullptr;
|
||||
}
|
||||
TORCH_INTERNAL_ASSERT(candidate->mapped);
|
||||
|
||||
// Map additional blocks until we have enough continuous space (Case 3).
|
||||
// Each map_block() call merges newly mapped blocks with adjacent free
|
||||
// blocks
|
||||
while (candidate->size < size) {
|
||||
auto remaining = size - candidate->size;
|
||||
auto new_candidate = candidate->next;
|
||||
// Map only what we need from the `new_candidate` block.
|
||||
if (!map_block(new_candidate, std::min(remaining, new_candidate->size))) {
|
||||
return nullptr;
|
||||
}
|
||||
candidate = new_candidate;
|
||||
}
|
||||
|
||||
// Remove from the free pool; block will be marked as `allocated` in
|
||||
// alloc_found_block()
|
||||
pool->blocks.erase(candidate);
|
||||
return candidate;
|
||||
}
|
||||
|
||||
bool get_free_block(AllocParams& p) {
|
||||
BlockPool& pool = *p.pool;
|
||||
auto it = pool.blocks.lower_bound(&p.search_key);
|
||||
if (it == pool.blocks.end() || (*it)->queue != p.queue()) {
|
||||
return false;
|
||||
}
|
||||
if ((*it)->expandable_segment) {
|
||||
if (AcceleratorAllocatorConfig::use_expandable_segments()) {
|
||||
// When expandable segments are enabled, consider both the current block
|
||||
// and any immediately adjacent unmapped region as a single expandable
|
||||
// area. For "best fit" allocation, we use the total expandable size
|
||||
// instead of just the block's current size, so that blocks which can
|
||||
// grow into a larger contiguous range are preferred.
|
||||
auto expandable_size = [](Block* b) {
|
||||
// b->next may belong to pool.unmapped (reserved but not mapped)
|
||||
return b->size + (b->next && !b->next->mapped ? b->next->size : 0);
|
||||
};
|
||||
auto next = it;
|
||||
next++;
|
||||
// Looks for the best fit block with expandable size.
|
||||
while ((*it)->expandable_segment && next != pool.blocks.end() &&
|
||||
(*next)->queue == p.queue() &&
|
||||
expandable_size(*next) < expandable_size(*it)) {
|
||||
it = next++;
|
||||
}
|
||||
} else {
|
||||
// Expandable segments were previously enabled, but are now disabled
|
||||
// (e.g. to avoid IPC issues). Skip any expandable blocks and only
|
||||
// find from regular non-expandable segments.
|
||||
do {
|
||||
it++;
|
||||
} while (it != pool.blocks.end() && (*it)->expandable_segment &&
|
||||
(*it)->queue == p.queue());
|
||||
if (it == pool.blocks.end() || (*it)->queue != p.queue()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
p.block = *it;
|
||||
pool.blocks.erase(it);
|
||||
return true;
|
||||
@ -252,6 +656,10 @@ class DeviceCachingAllocator {
|
||||
size >
|
||||
allowed_memory_maximum) {
|
||||
return false;
|
||||
} else if (AcceleratorAllocatorConfig::use_expandable_segments()) {
|
||||
p.block =
|
||||
try_allocate_expandable_block(device, p.queue(), p.pool, p.size());
|
||||
return bool(p.block);
|
||||
}
|
||||
void* ptr = sycl::aligned_alloc_device(
|
||||
kDeviceAlignment,
|
||||
@ -265,6 +673,7 @@ class DeviceCachingAllocator {
|
||||
for_each_selected_stat_type(p.stat_types, [&](size_t stat_type) {
|
||||
stats.reserved_bytes[stat_type].increase(size);
|
||||
});
|
||||
TORCH_INTERNAL_ASSERT(p.block != nullptr && p.block->ptr != nullptr);
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -283,6 +692,27 @@ class DeviceCachingAllocator {
|
||||
xpu_events.clear();
|
||||
}
|
||||
|
||||
void release_expandable_segment(Block* block) {
|
||||
// See Note [Safe to Free Blocks on BlockPool], additional synchronization
|
||||
// is unnecessary here because this function is only called by
|
||||
// release_cached_blocks().
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
block->size == block->expandable_segment->size(),
|
||||
"block disagrees with segment");
|
||||
TORCH_INTERNAL_ASSERT(!block->mapped);
|
||||
|
||||
auto it = std::find(
|
||||
expandable_segments.begin(),
|
||||
expandable_segments.end(),
|
||||
block->expandable_segment);
|
||||
TORCH_INTERNAL_ASSERT(it != expandable_segments.end());
|
||||
|
||||
expandable_segments.erase(it);
|
||||
block->pool->unmapped.erase(block);
|
||||
delete block->expandable_segment;
|
||||
delete block;
|
||||
}
|
||||
|
||||
void release_block(Block* block) {
|
||||
/*
|
||||
* Note [Safe to Free Blocks on BlockPool]
|
||||
@ -293,6 +723,7 @@ class DeviceCachingAllocator {
|
||||
* We have to do a device-level synchronization before free these blocks to
|
||||
* guarantee that all kernels can access to the blocks have finished.
|
||||
*/
|
||||
TORCH_INTERNAL_ASSERT(!block->expandable_segment);
|
||||
sycl::free(block->ptr, xpu::get_device_context());
|
||||
auto* pool = block->pool;
|
||||
pool->blocks.erase(block);
|
||||
@ -305,15 +736,80 @@ class DeviceCachingAllocator {
|
||||
delete block;
|
||||
}
|
||||
|
||||
void unmap_block(Block* block) {
|
||||
auto unmapped =
|
||||
block->expandable_segment->unmap(SegmentRange{block->ptr, block->size});
|
||||
if (unmapped.size == 0) {
|
||||
return;
|
||||
}
|
||||
block->pool->blocks.erase(block);
|
||||
|
||||
ptrdiff_t before_size = unmapped.ptr - static_cast<char*>(block->ptr);
|
||||
if (before_size > 0) {
|
||||
// If the actual unmapped region starts after block->ptr due to alignment,
|
||||
// the region before unmapped.ptr is still mapped.
|
||||
// [Prev Block?] -> [Before Block] -> [Unmapped Block]
|
||||
Block* before_free = new Block(
|
||||
block->device, block->queue, before_size, block->pool, block->ptr);
|
||||
before_free->expandable_segment = block->expandable_segment;
|
||||
before_free->splice(block->prev, block);
|
||||
block->pool->blocks.insert(before_free);
|
||||
}
|
||||
|
||||
auto after_size = block->size - (before_size + unmapped.size);
|
||||
if (after_size > 0) {
|
||||
// If the actual unmapped region ends before block->ptr + block->size,
|
||||
// the region after (unmapped.ptr + unmapped.size) is still mapped.
|
||||
// [Unmapped Block] -> [After Block] -> [Next Block?]
|
||||
Block* after_free = new Block(
|
||||
block->device,
|
||||
block->queue,
|
||||
after_size,
|
||||
block->pool,
|
||||
unmapped.ptr + unmapped.size);
|
||||
after_free->expandable_segment = block->expandable_segment;
|
||||
after_free->splice(block, block->next);
|
||||
block->pool->blocks.insert(after_free);
|
||||
}
|
||||
|
||||
// [Before Mapped Block?] -> [Unmapped Block] -> [After Mapped Block?]
|
||||
block->ptr = unmapped.ptr;
|
||||
block->size = unmapped.size;
|
||||
block->mapped = false;
|
||||
|
||||
try_merge_blocks(block, block->prev, *block->pool);
|
||||
try_merge_blocks(block, block->next, *block->pool);
|
||||
block->pool->unmapped.insert(block);
|
||||
|
||||
StatTypes stat_types = get_stat_types_for_pool(*block->pool);
|
||||
for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
|
||||
stats.reserved_bytes[stat_type].decrease(unmapped.size);
|
||||
});
|
||||
}
|
||||
|
||||
void release_blocks(BlockPool& pool) {
|
||||
std::vector<Block*> to_unmap;
|
||||
// Frees all non-split blocks in the given pool.
|
||||
auto it = pool.blocks.begin();
|
||||
while (it != pool.blocks.end()) {
|
||||
Block* block = *it;
|
||||
++it;
|
||||
if (!block->prev && !block->next) {
|
||||
if (block->expandable_segment) {
|
||||
// unmap_block() modifies the free pool, so collect items to free first
|
||||
// to avoid iterator invalidation.
|
||||
to_unmap.push_back(block);
|
||||
} else if (!block->prev && !block->next) {
|
||||
release_block(block);
|
||||
}
|
||||
}
|
||||
for (Block* block : to_unmap) {
|
||||
unmap_block(block);
|
||||
// After unmap_block(), expandable segment blocks with no neighbors are
|
||||
// also released.
|
||||
if (!block->prev && !block->next) {
|
||||
release_expandable_segment(block);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool release_cached_blocks() {
|
||||
@ -328,7 +824,8 @@ class DeviceCachingAllocator {
|
||||
|
||||
bool should_split(const Block* block, size_t size) {
|
||||
size_t remaining = block->size - size;
|
||||
if (block->pool->is_small) {
|
||||
if (block->pool->is_small ||
|
||||
AcceleratorAllocatorConfig::use_expandable_segments()) {
|
||||
return remaining >= kMinBlockSize;
|
||||
} else {
|
||||
return remaining > kSmallSize;
|
||||
@ -361,6 +858,7 @@ class DeviceCachingAllocator {
|
||||
remaining = block;
|
||||
|
||||
block = new Block(device, queue, size, pool, block->ptr);
|
||||
block->expandable_segment = remaining->expandable_segment;
|
||||
block->prev = remaining->prev;
|
||||
if (block->prev) {
|
||||
block->prev->next = block;
|
||||
@ -599,6 +1097,15 @@ class XPUAllocator : public DeviceAllocator {
|
||||
return block;
|
||||
}
|
||||
|
||||
void assertValidDevice(DeviceIndex device) {
|
||||
const auto device_num = device_allocators.size();
|
||||
TORCH_CHECK(
|
||||
0 <= device && device < static_cast<int64_t>(device_num),
|
||||
"Invalid device argument ",
|
||||
device,
|
||||
": did you call init?");
|
||||
}
|
||||
|
||||
public:
|
||||
std::vector<std::unique_ptr<DeviceCachingAllocator>> device_allocators;
|
||||
|
||||
@ -711,15 +1218,6 @@ class XPUAllocator : public DeviceAllocator {
|
||||
xpu::getCurrentXPUStream().queue().memcpy(dest, src, count);
|
||||
}
|
||||
|
||||
void assertValidDevice(DeviceIndex device) {
|
||||
const auto device_num = device_allocators.size();
|
||||
TORCH_CHECK(
|
||||
0 <= device && device < static_cast<int64_t>(device_num),
|
||||
"Invalid device argument ",
|
||||
device,
|
||||
": did you call init?");
|
||||
}
|
||||
|
||||
DeviceStats getDeviceStats(DeviceIndex device) override {
|
||||
assertValidDevice(device);
|
||||
return device_allocators[device]->getStats();
|
||||
@ -735,6 +1233,13 @@ class XPUAllocator : public DeviceAllocator {
|
||||
device_allocators[device]->resetAccumulatedStats();
|
||||
}
|
||||
|
||||
void enablePeerAccess(c10::DeviceIndex dev, c10::DeviceIndex dev_to_access) {
|
||||
assertValidDevice(dev);
|
||||
assertValidDevice(dev_to_access);
|
||||
c10::xpu::get_raw_device(dev).ext_oneapi_enable_peer_access(
|
||||
c10::xpu::get_raw_device(dev_to_access));
|
||||
}
|
||||
|
||||
double getMemoryFraction(DeviceIndex device) {
|
||||
assertValidDevice(device);
|
||||
return device_allocators[device]->getMemoryFraction();
|
||||
@ -793,6 +1298,10 @@ void recordStream(const DataPtr& dataPtr, XPUStream stream) {
|
||||
return allocator.recordStream(dataPtr, stream);
|
||||
}
|
||||
|
||||
void enablePeerAccess(c10::DeviceIndex dev, c10::DeviceIndex dev_to_access) {
|
||||
return allocator.enablePeerAccess(dev, dev_to_access);
|
||||
}
|
||||
|
||||
double getMemoryFraction(DeviceIndex device) {
|
||||
return allocator.getMemoryFraction(device);
|
||||
}
|
||||
|
||||
@ -25,6 +25,10 @@ C10_XPU_API void raw_delete(void* ptr);
|
||||
|
||||
C10_XPU_API void recordStream(const DataPtr& dataPtr, XPUStream stream);
|
||||
|
||||
C10_XPU_API void enablePeerAccess(
|
||||
c10::DeviceIndex dev,
|
||||
c10::DeviceIndex dev_to_access);
|
||||
|
||||
C10_XPU_API double getMemoryFraction(DeviceIndex device);
|
||||
|
||||
C10_XPU_API void setMemoryFraction(double fraction, DeviceIndex device);
|
||||
|
||||
@ -206,6 +206,41 @@ templates_path = [
|
||||
os.path.join(os.path.dirname(pytorch_sphinx_theme2.__file__), "templates"),
|
||||
]
|
||||
# TODO: document these and remove them from here.
|
||||
# Fixes the duplicated
|
||||
autosummary_filename_map = {
|
||||
"torch.nn.utils.prune.identity": "torch.nn.utils.prune.identity_function",
|
||||
"torch.nn.utils.prune.Identity": "torch.nn.utils.prune.Identity_class",
|
||||
"torch.optim.adamw.adamw": "torch.optim.adamw.adamw_function",
|
||||
"torch.optim.adamw.AdamW": "torch.optim.adamw.AdamW_class",
|
||||
"torch.optim.asgd.asgd": "torch.optim.asgd.asgd_function",
|
||||
"torch.optim.asgd.ASGD": "torch.optim.asgd.ASGD_class",
|
||||
"torch.optim.nadam.nadam": "torch.optim.nadam.nadam_function",
|
||||
"torch.optim.nadam.NAdam": "torch.optim.nadam.NAdam_class",
|
||||
"torch.optim.radam.radam": "torch.optim.radam.radam_function",
|
||||
"torch.optim.radam.RAdam": "torch.optim.radam.RAdam_class",
|
||||
"torch.optim.rmsprop.rmsprop": "torch.optim.rmsprop.rmsprop_function",
|
||||
"torch.optim.rmsprop.RMSprop": "torch.optim.rmsprop.RMSprop_class",
|
||||
"torch.optim.rprop.rprop": "torch.optim.rprop.rprop_function",
|
||||
"torch.optim.rprop.Rprop": "torch.optim.rprop.Rprop_class",
|
||||
"torch.optim.sgd.sgd": "torch.optim.sgd.sgd_function",
|
||||
"torch.optim.sgd.SGD": "torch.optim.sgd.SGD_class",
|
||||
"torch.optim.adadelta.adadelta": "torch.optim.adadelta.adadelta_function",
|
||||
"torch.optim.adadelta.Adadelta": "torch.optim.adadelta.Adadelta_class",
|
||||
"torch.optim.adagrad.adagrad": "torch.optim.adagrad.adagrad_function",
|
||||
"torch.optim.adagrad.Adagrad": "torch.optim.adagrad.Adagrad_class",
|
||||
"torch.optim.adam.adam": "torch.optim.adam.adam_function",
|
||||
"torch.optim.adam.Adam": "torch.optim.adam.Adam_class",
|
||||
"torch.optim.adamax.adamax": "torch.optim.adamax.adamax_function",
|
||||
"torch.optim.adamax.Adamax": "torch.optim.adamax.Adamax_class",
|
||||
"torch.mtia.stream": "torch.mtia.stream_function",
|
||||
"torch.mtia.Stream": "torch.mtia.Stream_class",
|
||||
"torch.cpu.stream": "torch.cpu.stream_function",
|
||||
"torch.cpu.Stream": "torch.cpu.Stream_class",
|
||||
"torch.cuda.stream": "torch.cuda.stream_function",
|
||||
"torch.cuda.Stream": "torch.cuda.Stream_class",
|
||||
"torch.xpu.stream": "torch.xpu.stream_function",
|
||||
"torch.xpu.Stream": "torch.xpu.Stream_class",
|
||||
}
|
||||
|
||||
coverage_ignore_functions = [
|
||||
# torch
|
||||
@ -3195,6 +3230,11 @@ autodoc_type_aliases = {
|
||||
# Enable overriding of function signatures in the first line of the docstring.
|
||||
autodoc_docstring_signature = True
|
||||
|
||||
# Exclude inherited IntEnum methods that have RST formatting issues in their docstrings
|
||||
autodoc_default_options = {
|
||||
"exclude-members": "from_bytes, to_bytes",
|
||||
}
|
||||
|
||||
# -- katex javascript in header
|
||||
#
|
||||
# def setup(app):
|
||||
|
||||
@ -253,7 +253,6 @@ regular full-precision tensor.
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
:template: classtemplate.rst
|
||||
|
||||
view
|
||||
as_strided
|
||||
|
||||
34
setup.py
34
setup.py
@ -630,6 +630,37 @@ def mirror_files_into_torchgen() -> None:
|
||||
raise RuntimeError("Check the file paths in `mirror_files_into_torchgen()`")
|
||||
|
||||
|
||||
def mirror_inductor_external_kernels() -> None:
|
||||
"""
|
||||
Copy external kernels into Inductor so they are importable.
|
||||
"""
|
||||
paths = [
|
||||
(
|
||||
CWD / "torch/_inductor/kernel/vendored_templates/cutedsl_grouped_gemm.py",
|
||||
CWD
|
||||
/ "third_party/cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py",
|
||||
),
|
||||
]
|
||||
for new_path, orig_path in paths:
|
||||
# Create the dirs involved in new_path if they don't exist
|
||||
if not new_path.exists():
|
||||
new_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Copy the files from the orig location to the new location
|
||||
if orig_path.is_file():
|
||||
shutil.copyfile(orig_path, new_path)
|
||||
continue
|
||||
if orig_path.is_dir():
|
||||
if new_path.exists():
|
||||
# copytree fails if the tree exists already, so remove it.
|
||||
shutil.rmtree(new_path)
|
||||
shutil.copytree(orig_path, new_path)
|
||||
continue
|
||||
raise RuntimeError(
|
||||
"Check the file paths in `mirror_inductor_external_kernels()`"
|
||||
)
|
||||
|
||||
|
||||
# ATTENTION: THIS IS AI SLOP
|
||||
def extract_variant_from_version(version: str) -> str:
|
||||
"""Extract variant from version string, defaulting to 'cpu'."""
|
||||
@ -1616,6 +1647,8 @@ def main() -> None:
|
||||
if RUN_BUILD_DEPS:
|
||||
build_deps()
|
||||
|
||||
mirror_inductor_external_kernels()
|
||||
|
||||
(
|
||||
ext_modules,
|
||||
cmdclass,
|
||||
@ -1649,6 +1682,7 @@ def main() -> None:
|
||||
"_inductor/codegen/aoti_runtime/*.cpp",
|
||||
"_inductor/script.ld",
|
||||
"_inductor/kernel/flex/templates/*.jinja",
|
||||
"_inductor/kernel/templates/*.jinja",
|
||||
"_export/serde/*.yaml",
|
||||
"_export/serde/*.thrift",
|
||||
"share/cmake/ATen/*.cmake",
|
||||
|
||||
@ -256,23 +256,25 @@ class TestSDPA(NNTestCase):
|
||||
)
|
||||
rand_upward_privateuse1 = rand_upward.to("openreg")
|
||||
grad_input_mask = [True, True, True, True]
|
||||
torch.ops.aten._scaled_dot_product_fused_attention_overrideable_backward(
|
||||
rand_upward_privateuse1,
|
||||
q_privateuse1,
|
||||
k_privateuse1,
|
||||
v_privateuse1,
|
||||
attn_mask_privateuse1,
|
||||
grad_input_mask,
|
||||
output,
|
||||
logsumexp,
|
||||
cum_seq_q,
|
||||
cum_seq_k,
|
||||
max_q,
|
||||
max_k,
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
philox_seed=philox_seed,
|
||||
philox_offset=philox_offset,
|
||||
_grad_q, _grad_k, _grad_v, _grad_attn_mask = (
|
||||
torch.ops.aten._scaled_dot_product_fused_attention_overrideable_backward(
|
||||
rand_upward_privateuse1,
|
||||
q_privateuse1,
|
||||
k_privateuse1,
|
||||
v_privateuse1,
|
||||
attn_mask_privateuse1,
|
||||
grad_input_mask,
|
||||
output,
|
||||
logsumexp,
|
||||
cum_seq_q,
|
||||
cum_seq_k,
|
||||
max_q,
|
||||
max_k,
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
philox_seed=philox_seed,
|
||||
philox_offset=philox_offset,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -392,11 +392,11 @@ class ComposabilityTest(MultiProcessTestCase):
|
||||
replicate_size = self.world_size // (pp_size)
|
||||
device_mesh = init_device_mesh(
|
||||
device_type,
|
||||
mesh_shape=(replicate_size, 1, pp_size),
|
||||
mesh_dim_names=("replicate", "shard", "pp"),
|
||||
mesh_shape=(replicate_size, pp_size),
|
||||
mesh_dim_names=("replicate", "pp"),
|
||||
)
|
||||
torch.manual_seed(42)
|
||||
dp_mesh = device_mesh["replicate", "shard"]
|
||||
dp_mesh = device_mesh["replicate"]
|
||||
pp_mesh = device_mesh["pp"]
|
||||
pp_group = device_mesh["pp"].get_group()
|
||||
|
||||
@ -416,15 +416,13 @@ class ComposabilityTest(MultiProcessTestCase):
|
||||
param_dtype=MixedPrecisionParam,
|
||||
reduce_dtype=torch.float32,
|
||||
)
|
||||
replicate_config = {"mp_policy": mp_policy}
|
||||
replicate_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
|
||||
for layer_id in range(len(partial_model)):
|
||||
replicate(
|
||||
partial_model[layer_id],
|
||||
device_mesh=dp_mesh,
|
||||
**replicate_config,
|
||||
reshard_after_forward=False,
|
||||
)
|
||||
dp_model = replicate(partial_model, device_mesh=dp_mesh, **replicate_config)
|
||||
dp_model = replicate(partial_model, **replicate_config)
|
||||
return dp_model
|
||||
|
||||
# Apply same precision to reference model (without replicate)
|
||||
@ -582,11 +580,11 @@ class ComposabilityTest(MultiProcessTestCase):
|
||||
replicate_size = self.world_size // (pp_size)
|
||||
device_mesh = init_device_mesh(
|
||||
device_type,
|
||||
mesh_shape=(replicate_size, 1, pp_size),
|
||||
mesh_dim_names=("replicate", "shard", "pp"),
|
||||
mesh_shape=(replicate_size, pp_size),
|
||||
mesh_dim_names=("replicate", "pp"),
|
||||
)
|
||||
torch.manual_seed(42)
|
||||
dp_mesh = device_mesh["replicate", "shard"]
|
||||
dp_mesh = device_mesh["replicate"]
|
||||
pp_mesh = device_mesh["pp"]
|
||||
pp_group = device_mesh["pp"].get_group()
|
||||
dp_group = device_mesh["replicate"].get_group()
|
||||
@ -648,10 +646,9 @@ class ComposabilityTest(MultiProcessTestCase):
|
||||
for layer_id in range(len(partial_model)):
|
||||
replicate(
|
||||
partial_model[layer_id],
|
||||
device_mesh=dp_mesh,
|
||||
reshard_after_forward=False,
|
||||
mesh=dp_mesh,
|
||||
)
|
||||
dp_model = replicate(partial_model, device_mesh=dp_mesh)
|
||||
dp_model = replicate(partial_model, mesh=dp_mesh)
|
||||
return dp_model
|
||||
|
||||
def pipelined_models_parameters(start_layer, model):
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
import copy
|
||||
import dataclasses
|
||||
import functools
|
||||
from typing import Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@ -14,7 +14,6 @@ from torch.distributed.fsdp import MixedPrecisionPolicy
|
||||
from torch.distributed.fsdp._fully_shard._fsdp_collectives import (
|
||||
_get_gradient_divide_factors,
|
||||
)
|
||||
from torch.distributed.tensor import Shard
|
||||
from torch.testing._internal.common_distributed import (
|
||||
requires_nccl_version,
|
||||
SaveForwardInputsModel,
|
||||
@ -46,35 +45,20 @@ class TestReplicateMixedPrecisionTraining(FSDPTest):
|
||||
|
||||
def _init_models_and_optims(
|
||||
self,
|
||||
reshard_after_forward: Union[bool, int],
|
||||
param_dtype: Optional[torch.dtype],
|
||||
reduce_dtype: Optional[torch.dtype],
|
||||
use_shard_placement_fn,
|
||||
):
|
||||
torch.manual_seed(42)
|
||||
model = nn.Sequential(*[MLP(16, torch.device("cpu")) for _ in range(3)])
|
||||
ref_model = copy.deepcopy(model).to(device_type)
|
||||
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
|
||||
|
||||
def _shard_placement_fn(param: nn.Parameter) -> Optional[Shard]:
|
||||
largest_dim = -1
|
||||
largest_dim_size = -1
|
||||
for dim, dim_size in enumerate(param.shape):
|
||||
if dim_size > largest_dim_size:
|
||||
largest_dim = dim
|
||||
largest_dim_size = dim_size
|
||||
assert largest_dim >= 0, f"{param.shape}"
|
||||
return Shard(largest_dim)
|
||||
|
||||
mp_policy = MixedPrecisionPolicy(
|
||||
param_dtype=param_dtype, reduce_dtype=reduce_dtype
|
||||
)
|
||||
shard_placement_fn = _shard_placement_fn if use_shard_placement_fn else None
|
||||
replicate_fn = functools.partial(
|
||||
replicate,
|
||||
reshard_after_forward=reshard_after_forward,
|
||||
mp_policy=mp_policy,
|
||||
shard_placement_fn=shard_placement_fn,
|
||||
)
|
||||
for mlp in model:
|
||||
replicate_fn(mlp)
|
||||
@ -82,27 +66,13 @@ class TestReplicateMixedPrecisionTraining(FSDPTest):
|
||||
optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=True)
|
||||
return ref_model, ref_optim, model, optim
|
||||
|
||||
def _get_use_shard_placement_fn_vals_for_bf16_reduce(self):
|
||||
use_shard_placement_fn_vals = [False]
|
||||
if self.world_size == 2:
|
||||
# For world size >2, gradient elements get reduced in different
|
||||
# orders for the baseline vs. dim-1 sharding, leading to numeric
|
||||
# differences for bf16 reduction, so only test world size 2.
|
||||
use_shard_placement_fn_vals.append(True)
|
||||
return use_shard_placement_fn_vals
|
||||
|
||||
@skipIfRocmVersionLessThan((7, 0))
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@requires_nccl_version((2, 10), "Need NCCL 2.10+ for bf16 collectives")
|
||||
def test_compute_dtype(self):
|
||||
use_shard_placement_fn_vals = (
|
||||
self._get_use_shard_placement_fn_vals_for_bf16_reduce()
|
||||
)
|
||||
self.run_subtests(
|
||||
{
|
||||
"param_dtype": [torch.bfloat16, torch.float16],
|
||||
"reshard_after_forward": [False, True],
|
||||
"use_shard_placement_fn": use_shard_placement_fn_vals,
|
||||
},
|
||||
self._test_compute_dtype,
|
||||
)
|
||||
@ -110,14 +80,10 @@ class TestReplicateMixedPrecisionTraining(FSDPTest):
|
||||
def _test_compute_dtype(
|
||||
self,
|
||||
param_dtype: torch.dtype,
|
||||
reshard_after_forward: Union[bool, int],
|
||||
use_shard_placement_fn: bool,
|
||||
):
|
||||
ref_model, ref_optim, model, optim = self._init_models_and_optims(
|
||||
reshard_after_forward,
|
||||
param_dtype=param_dtype,
|
||||
reduce_dtype=None,
|
||||
use_shard_placement_fn=use_shard_placement_fn,
|
||||
)
|
||||
ref_model_bf16 = copy.deepcopy(ref_model).to(param_dtype)
|
||||
orig_reduce_scatter = dist.reduce_scatter_tensor
|
||||
@ -175,39 +141,14 @@ class TestReplicateMixedPrecisionTraining(FSDPTest):
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@requires_nccl_version((2, 10), "Need NCCL 2.10+ for bf16 collectives")
|
||||
def test_reduce_dtype(self):
|
||||
self.run_subtests(
|
||||
{
|
||||
"reshard_after_forward": [False, True],
|
||||
"use_shard_placement_fn": [False, True],
|
||||
},
|
||||
self._test_reduce_dtype_fp32_reduce,
|
||||
)
|
||||
use_shard_placement_fn_vals = (
|
||||
self._get_use_shard_placement_fn_vals_for_bf16_reduce()
|
||||
)
|
||||
self.run_subtests(
|
||||
{
|
||||
"reshard_after_forward": [False, True],
|
||||
"use_shard_placement_fn": use_shard_placement_fn_vals,
|
||||
},
|
||||
self._test_reduce_dtype_bf16_reduce,
|
||||
)
|
||||
self._test_reduce_dtype_fp32_reduce()
|
||||
self._test_reduce_dtype_bf16_reduce()
|
||||
|
||||
def _test_reduce_dtype_fp32_reduce(
|
||||
self, reshard_after_forward: Union[bool, int], use_shard_placement_fn: bool
|
||||
):
|
||||
if (
|
||||
self.world_size > 2
|
||||
and isinstance(reshard_after_forward, int)
|
||||
and use_shard_placement_fn
|
||||
):
|
||||
return
|
||||
def _test_reduce_dtype_fp32_reduce(self):
|
||||
param_dtype, reduce_dtype = torch.bfloat16, torch.float32
|
||||
ref_model, ref_optim, model, optim = self._init_models_and_optims(
|
||||
reshard_after_forward,
|
||||
param_dtype=param_dtype,
|
||||
reduce_dtype=reduce_dtype,
|
||||
use_shard_placement_fn=use_shard_placement_fn,
|
||||
)
|
||||
ref_model_bf16 = copy.deepcopy(ref_model).to(param_dtype)
|
||||
orig_reduce_scatter = dist.reduce_scatter_tensor
|
||||
@ -249,14 +190,12 @@ class TestReplicateMixedPrecisionTraining(FSDPTest):
|
||||
check_sharded_parity(self, ref_model, model)
|
||||
|
||||
def _test_reduce_dtype_bf16_reduce(
|
||||
self, reshard_after_forward: Union[bool, int], use_shard_placement_fn: bool
|
||||
self,
|
||||
):
|
||||
param_dtype, reduce_dtype = torch.float32, torch.bfloat16
|
||||
ref_model, ref_optim, model, optim = self._init_models_and_optims(
|
||||
reshard_after_forward,
|
||||
param_dtype=param_dtype,
|
||||
reduce_dtype=reduce_dtype,
|
||||
use_shard_placement_fn=use_shard_placement_fn,
|
||||
)
|
||||
group = dist.distributed_c10d._get_default_group()
|
||||
orig_reduce_scatter = dist.reduce_scatter_tensor
|
||||
@ -321,12 +260,8 @@ class TestReplicateMixedPrecisionTraining(FSDPTest):
|
||||
ref_model_compute = copy.deepcopy(ref_model).to(param_dtype)
|
||||
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
|
||||
for mlp in model:
|
||||
replicate(
|
||||
mlp, reshard_after_forward=reshard_after_forward, mp_policy=mp_policy
|
||||
)
|
||||
replicate(
|
||||
model, reshard_after_forward=reshard_after_forward, mp_policy=mp_policy
|
||||
)
|
||||
replicate(mlp, mp_policy=mp_policy)
|
||||
replicate(model, mp_policy=mp_policy)
|
||||
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
|
||||
orig_reduce_scatter = dist.reduce_scatter_tensor
|
||||
|
||||
|
||||
@ -108,84 +108,70 @@ class TestReplicateRegisteredParams(FSDPTestMultiThread):
|
||||
"""Tests the parameter registration after forward."""
|
||||
device = torch.device(device_type.type, 0)
|
||||
# Single Replicate group
|
||||
for reshard_after_forward in (True, False, None):
|
||||
torch.manual_seed(42)
|
||||
model = MLP(3, device)
|
||||
# Since seed is per process, not per thread, we broadcast to ensure
|
||||
# the same parameters across ranks
|
||||
for param in model.parameters():
|
||||
dist.broadcast(param, src=0)
|
||||
ref_model = copy.deepcopy(model)
|
||||
replicate(model, reshard_after_forward=reshard_after_forward) # root only
|
||||
inp = torch.randn((2, 3), device=device_type.type)
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
self._assert_same_params(model.parameters(), ref_model.parameters())
|
||||
model(inp)
|
||||
if reshard_after_forward:
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
else:
|
||||
self._assert_tensor_params(model.parameters())
|
||||
self._assert_same_params(model.parameters(), ref_model.parameters())
|
||||
model.reshard() # however, we can manually reshard
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
self._assert_same_params(model.parameters(), ref_model.parameters())
|
||||
torch.manual_seed(42)
|
||||
model = MLP(3, device)
|
||||
# Since seed is per process, not per thread, we broadcast to ensure
|
||||
# the same parameters across ranks
|
||||
for param in model.parameters():
|
||||
dist.broadcast(param, src=0)
|
||||
ref_model = copy.deepcopy(model)
|
||||
replicate(model) # root only
|
||||
inp = torch.randn((2, 3), device=device_type.type)
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
self._assert_same_params(model.parameters(), ref_model.parameters())
|
||||
model(inp)
|
||||
self._assert_tensor_params(model.parameters())
|
||||
self._assert_same_params(model.parameters(), ref_model.parameters())
|
||||
model.reshard() # however, we can manually reshard
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
self._assert_same_params(model.parameters(), ref_model.parameters())
|
||||
|
||||
# Multiple Replicate groups
|
||||
for reshard_after_forward in (True, False, None):
|
||||
torch.manual_seed(42)
|
||||
model = nn.Sequential(MLP(3, device), MLP(3, device))
|
||||
for param in model.parameters():
|
||||
dist.broadcast(param, src=0)
|
||||
ref_model = copy.deepcopy(model)
|
||||
replicate(model[0].in_proj, reshard_after_forward=reshard_after_forward)
|
||||
replicate(model[0].out_proj, reshard_after_forward=reshard_after_forward)
|
||||
replicate(model, reshard_after_forward=reshard_after_forward)
|
||||
torch.manual_seed(42)
|
||||
model = nn.Sequential(MLP(3, device), MLP(3, device))
|
||||
for param in model.parameters():
|
||||
dist.broadcast(param, src=0)
|
||||
ref_model = copy.deepcopy(model)
|
||||
replicate(model[0].in_proj)
|
||||
replicate(model[0].out_proj)
|
||||
replicate(model)
|
||||
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
self._assert_same_params(model.parameters(), ref_model.parameters())
|
||||
model(inp)
|
||||
non_root_params = list(model[0].in_proj.parameters()) + list(
|
||||
model[0].out_proj.parameters()
|
||||
)
|
||||
root_params = list(set(model.parameters()) - set(non_root_params))
|
||||
if reshard_after_forward is None:
|
||||
self._assert_dtensor_params(non_root_params)
|
||||
self._assert_tensor_params(root_params)
|
||||
elif reshard_after_forward:
|
||||
self._assert_dtensor_params(non_root_params)
|
||||
self._assert_dtensor_params(root_params)
|
||||
else:
|
||||
self._assert_tensor_params(non_root_params)
|
||||
self._assert_tensor_params(root_params)
|
||||
self._assert_same_params(model.parameters(), ref_model.parameters())
|
||||
for module in model.modules():
|
||||
if isinstance(module, FSDPModule):
|
||||
module.reshard() # however, we can manually reshard
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
self._assert_same_params(model.parameters(), ref_model.parameters())
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
self._assert_same_params(model.parameters(), ref_model.parameters())
|
||||
model(inp)
|
||||
non_root_params = list(model[0].in_proj.parameters()) + list(
|
||||
model[0].out_proj.parameters()
|
||||
)
|
||||
root_params = list(set(model.parameters()) - set(non_root_params))
|
||||
self._assert_tensor_params(non_root_params)
|
||||
self._assert_tensor_params(root_params)
|
||||
self._assert_same_params(model.parameters(), ref_model.parameters())
|
||||
for module in model.modules():
|
||||
if isinstance(module, FSDPModule):
|
||||
module.reshard() # however, we can manually reshard
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
self._assert_same_params(model.parameters(), ref_model.parameters())
|
||||
|
||||
@skip_if_lt_x_gpu(1)
|
||||
def test_param_registration_after_backward(self):
|
||||
"""Tests the parameter registration after backward."""
|
||||
device = torch.device(device_type.type, 0)
|
||||
# Single Replicate group
|
||||
for reshard_after_forward in (True, False):
|
||||
model = MLP(8, device)
|
||||
replicate(model, reshard_after_forward=reshard_after_forward) # root only
|
||||
inp = torch.randn((2, 8), device=device_type.type)
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
model(inp).sum().backward()
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
model = MLP(8, device)
|
||||
replicate(model) # root only
|
||||
inp = torch.randn((2, 8), device=device_type.type)
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
model(inp).sum().backward()
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
|
||||
# Multiple Replicate groups
|
||||
for reshard_after_forward in (True, False):
|
||||
model = MLP(8, device)
|
||||
replicate(model.in_proj, reshard_after_forward=reshard_after_forward)
|
||||
replicate(model.out_proj, reshard_after_forward=reshard_after_forward)
|
||||
replicate(model, reshard_after_forward=reshard_after_forward)
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
model(inp).sum().backward()
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
model = MLP(8, device)
|
||||
replicate(model.in_proj)
|
||||
replicate(model.out_proj)
|
||||
replicate(model)
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
model(inp).sum().backward()
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
|
||||
def _assert_tensor_params(self, params: Iterable[nn.Parameter]):
|
||||
# need to iterate over the list multiple times
|
||||
@ -287,14 +273,11 @@ class TestReplicate1DTrainingCore(FSDPTest):
|
||||
[(7, 15), (15, 3)],
|
||||
[(16, 17), (17, 8)],
|
||||
],
|
||||
"use_shard_placement_fn": [False],
|
||||
},
|
||||
self._test_train_parity_single_group,
|
||||
)
|
||||
|
||||
def _test_train_parity_single_group(
|
||||
self, lin_shapes: list[tuple[int, int]], use_shard_placement_fn: bool
|
||||
):
|
||||
def _test_train_parity_single_group(self, lin_shapes: list[tuple[int, int]]):
|
||||
torch.manual_seed(42)
|
||||
model = nn.Sequential(
|
||||
nn.Linear(*lin_shapes[0]), nn.ReLU(), nn.Linear(*lin_shapes[1])
|
||||
@ -333,7 +316,6 @@ class TestReplicate1DTrainingCore(FSDPTest):
|
||||
"""
|
||||
self.run_subtests(
|
||||
{
|
||||
"reshard_after_forward": [True, False],
|
||||
"test_device_type": [device_type.type],
|
||||
"offload_policy": [OffloadPolicy()],
|
||||
"delay_after_forward": [False, True],
|
||||
@ -354,7 +336,6 @@ class TestReplicate1DTrainingCore(FSDPTest):
|
||||
"""
|
||||
self.run_subtests(
|
||||
{
|
||||
"reshard_after_forward": [True], # save CI time
|
||||
"offload_policy": [
|
||||
CPUOffloadPolicy(pin_memory=True),
|
||||
CPUOffloadPolicy(pin_memory=False),
|
||||
@ -371,7 +352,6 @@ class TestReplicate1DTrainingCore(FSDPTest):
|
||||
|
||||
def _test_train_parity_multi_group(
|
||||
self,
|
||||
reshard_after_forward: Union[bool, int],
|
||||
offload_policy: OffloadPolicy,
|
||||
test_device_type: str,
|
||||
delay_after_forward: bool,
|
||||
@ -405,13 +385,12 @@ class TestReplicate1DTrainingCore(FSDPTest):
|
||||
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
|
||||
mesh = init_device_mesh(
|
||||
test_device_type,
|
||||
(self.world_size, 1),
|
||||
mesh_dim_names=("replicate", "shard"),
|
||||
(self.world_size,),
|
||||
mesh_dim_names=("replicate",),
|
||||
)
|
||||
fully_shard_fn = functools.partial(
|
||||
replicate,
|
||||
device_mesh=mesh,
|
||||
reshard_after_forward=reshard_after_forward,
|
||||
mesh=mesh,
|
||||
offload_policy=offload_policy,
|
||||
)
|
||||
for module in model.modules():
|
||||
@ -527,12 +506,10 @@ class TestReplicate1DTrainingCore(FSDPTest):
|
||||
Tests parity when running a module that participates multiple
|
||||
times in forward.
|
||||
"""
|
||||
self.run_subtests(
|
||||
{"reshard_after_forward": [True, False]},
|
||||
self._test_multi_forward_module,
|
||||
)
|
||||
|
||||
def _test_multi_forward_module(self, reshard_after_forward: Union[bool, int]):
|
||||
self._test_multi_forward_module()
|
||||
|
||||
def _test_multi_forward_module(self):
|
||||
class MultiForwardModule(nn.Module):
|
||||
def __init__(self, device: torch.device):
|
||||
super().__init__()
|
||||
@ -687,7 +664,6 @@ class TestReplicateTrainingCompose(FSDPTest):
|
||||
"""
|
||||
self.run_subtests(
|
||||
{
|
||||
"reshard_after_forward": [True, False],
|
||||
"checkpoint_impl": ["composable", "utils", "wrapper"],
|
||||
"module_grouping": ["block", "mem_eff", "mem_eff_weight_tied"],
|
||||
"test_device_type": [device_type.type],
|
||||
@ -697,7 +673,6 @@ class TestReplicateTrainingCompose(FSDPTest):
|
||||
|
||||
def _test_train_parity_with_activation_checkpointing(
|
||||
self,
|
||||
reshard_after_forward: Union[bool, int],
|
||||
checkpoint_impl: str,
|
||||
module_grouping: str,
|
||||
test_device_type: str,
|
||||
@ -740,12 +715,11 @@ class TestReplicateTrainingCompose(FSDPTest):
|
||||
# Apply Replicate
|
||||
device_mesh = init_device_mesh(
|
||||
test_device_type,
|
||||
(self.world_size, 1),
|
||||
mesh_dim_names=("replicate", "shard"),
|
||||
(self.world_size,),
|
||||
mesh_dim_names=("replicate",),
|
||||
)
|
||||
fsdp_kwargs = {
|
||||
"reshard_after_forward": reshard_after_forward,
|
||||
"device_mesh": device_mesh,
|
||||
"mesh": device_mesh,
|
||||
}
|
||||
if module_grouping == "mem_eff":
|
||||
assert model_args.n_layers == 3
|
||||
@ -809,7 +783,6 @@ class TestReplicateSharedParams(FSDPTest):
|
||||
def test_train_parity_with_shared_params(self):
|
||||
self.run_subtests(
|
||||
{
|
||||
"reshard_after_forward": [False, True],
|
||||
"use_activation_checkpointing": [False, True],
|
||||
},
|
||||
self._test_train_shared_params,
|
||||
@ -817,7 +790,6 @@ class TestReplicateSharedParams(FSDPTest):
|
||||
|
||||
def _test_train_shared_params(
|
||||
self,
|
||||
reshard_after_forward: bool,
|
||||
use_activation_checkpointing: bool,
|
||||
):
|
||||
torch.manual_seed(42)
|
||||
@ -830,8 +802,8 @@ class TestReplicateSharedParams(FSDPTest):
|
||||
if isinstance(module, TransformerBlock):
|
||||
if use_activation_checkpointing:
|
||||
checkpoint(module)
|
||||
replicate(module, reshard_after_forward=reshard_after_forward)
|
||||
replicate(model, reshard_after_forward=reshard_after_forward)
|
||||
replicate(module)
|
||||
replicate(model)
|
||||
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
|
||||
|
||||
torch.manual_seed(42 + self.rank + 1)
|
||||
@ -868,11 +840,11 @@ class TestReplicateGradientAccumulation(FSDPTest):
|
||||
with/without resharding after backward.
|
||||
"""
|
||||
|
||||
shard_size, replicate_size = 1, self.world_size
|
||||
replicate_size = self.world_size
|
||||
meshes = init_device_mesh(
|
||||
device_type.type,
|
||||
(replicate_size, shard_size),
|
||||
mesh_dim_names=("replicate", "shard"),
|
||||
(replicate_size,),
|
||||
mesh_dim_names=("replicate",),
|
||||
)
|
||||
self.run_subtests(
|
||||
{
|
||||
@ -928,8 +900,7 @@ class TestReplicateGradientAccumulation(FSDPTest):
|
||||
ref_model = copy.deepcopy(model).to(device_type)
|
||||
replicate_fn = functools.partial(
|
||||
replicate,
|
||||
device_mesh=mesh,
|
||||
reshard_after_forward=reshard_after_forward,
|
||||
mesh=mesh,
|
||||
offload_policy=offload_policy,
|
||||
)
|
||||
for mlp in model[1:]:
|
||||
@ -1040,8 +1011,8 @@ class TestReplicateGradientAccumulation(FSDPTest):
|
||||
ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2)
|
||||
for module in model.modules():
|
||||
if isinstance(module, TransformerBlock):
|
||||
replicate(module, reshard_after_forward=False)
|
||||
replicate(model, reshard_after_forward=False)
|
||||
replicate(module)
|
||||
replicate(model)
|
||||
optim = torch.optim.AdamW(model.parameters(), lr=1e-2)
|
||||
|
||||
num_microbatches = 3
|
||||
@ -1145,8 +1116,8 @@ class TestReplicateTPTraining(FSDPTest):
|
||||
def init_global_mesh(self) -> DeviceMesh:
|
||||
return init_device_mesh(
|
||||
device_type.type,
|
||||
(2, 1, 2),
|
||||
mesh_dim_names=("dp_replicate", "dp_shard", "tp"),
|
||||
(2, 2),
|
||||
mesh_dim_names=("dp_replicate", "tp"),
|
||||
)
|
||||
|
||||
@skip_if_lt_x_gpu(8)
|
||||
@ -1154,7 +1125,6 @@ class TestReplicateTPTraining(FSDPTest):
|
||||
global_mesh = self.init_global_mesh()
|
||||
self.run_subtests(
|
||||
{
|
||||
"reshard_after_forward": [False, True],
|
||||
"use_activation_checkpointing": [False, True],
|
||||
"mlp_dim": [3, 5, 16, 17],
|
||||
"foreach": [False],
|
||||
@ -1165,12 +1135,11 @@ class TestReplicateTPTraining(FSDPTest):
|
||||
def _test_replicate_tp(
|
||||
self,
|
||||
global_mesh: DeviceMesh,
|
||||
reshard_after_forward: bool,
|
||||
use_activation_checkpointing: bool,
|
||||
mlp_dim: int,
|
||||
foreach: bool,
|
||||
):
|
||||
dp_mesh, tp_mesh = global_mesh["dp_replicate", "dp_shard"], global_mesh["tp"]
|
||||
dp_mesh, tp_mesh = global_mesh["dp_replicate"], global_mesh["tp"]
|
||||
dp_pg = dp_mesh._flatten().get_group() # used for `replicate()`
|
||||
|
||||
torch.manual_seed(42)
|
||||
@ -1197,8 +1166,8 @@ class TestReplicateTPTraining(FSDPTest):
|
||||
continue
|
||||
if use_activation_checkpointing:
|
||||
checkpoint(module)
|
||||
replicate(module, device_mesh=dp_mesh)
|
||||
replicate(model, device_mesh=dp_mesh)
|
||||
replicate(module, mesh=dp_mesh)
|
||||
replicate(model, mesh=dp_mesh)
|
||||
|
||||
# Checking parameters match orig model is critical to validate .full_tensor correctly replicates the
|
||||
# strided-sharded layers.
|
||||
@ -1229,11 +1198,9 @@ class TestReplicateTPTraining(FSDPTest):
|
||||
|
||||
for _, p in model.named_parameters():
|
||||
self.assertIsInstance(p, DTensor)
|
||||
self.assertEqual(p.device_mesh.ndim, 3)
|
||||
self.assertEqual(len(p.placements), 3)
|
||||
self.assertEqual(
|
||||
p.device_mesh.mesh_dim_names, ("dp_replicate", "dp_shard", "tp")
|
||||
)
|
||||
self.assertEqual(p.device_mesh.ndim, 2)
|
||||
self.assertEqual(len(p.placements), 2)
|
||||
self.assertEqual(p.device_mesh.mesh_dim_names, ("dp_replicate", "tp"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -120,7 +120,7 @@ class ReplicateTest(MultiProcessTestCase):
|
||||
if i % 2 == 0:
|
||||
self.assertTrue("replicate" in _get_registry(layer))
|
||||
for parameter in layer.parameters():
|
||||
self.assertEqual(parameter.placements, (Replicate(), Shard(dim=0)))
|
||||
self.assertEqual(parameter.placements, (Replicate(),))
|
||||
elif i % 2 == 1:
|
||||
self.assertTrue("fully_shard" in _get_registry(layer))
|
||||
for parameter in layer.parameters():
|
||||
@ -197,14 +197,14 @@ class ReplicateTest(MultiProcessTestCase):
|
||||
]
|
||||
|
||||
global_mesh = self.init_replicate_tp_mesh()
|
||||
replicate_mesh = global_mesh["replicate", "shard"]
|
||||
replicate_mesh = global_mesh["replicate"]
|
||||
|
||||
for layer in layers:
|
||||
replicate(layer, device_mesh=replicate_mesh)
|
||||
replicate(layer, mesh=replicate_mesh)
|
||||
|
||||
for parameter in layer.parameters():
|
||||
self.assertEqual(parameter.device_mesh.shape, (2, 1))
|
||||
self.assertEqual(parameter.placements, (Replicate(), Shard(dim=0)))
|
||||
self.assertEqual(parameter.device_mesh.shape, (2,))
|
||||
self.assertEqual(parameter.placements, (Replicate(),))
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_train_replicate_fsdp(self):
|
||||
@ -263,7 +263,6 @@ class ReplicateTest(MultiProcessTestCase):
|
||||
run_subtests(
|
||||
self,
|
||||
{
|
||||
"reshard_after_forward": [False, True],
|
||||
"use_activation_checkpointing": [False, True],
|
||||
"mlp_dim": [3, 16, 17],
|
||||
},
|
||||
@ -273,7 +272,6 @@ class ReplicateTest(MultiProcessTestCase):
|
||||
def _test_train_parity_2d_mlp(
|
||||
self,
|
||||
global_mesh: DeviceMesh,
|
||||
reshard_after_forward: bool,
|
||||
use_activation_checkpointing: bool,
|
||||
mlp_dim: int,
|
||||
):
|
||||
@ -287,13 +285,12 @@ class ReplicateTest(MultiProcessTestCase):
|
||||
torch.manual_seed(42)
|
||||
model = MLPStack(mlp_dim)
|
||||
ref_model = copy.deepcopy(model).cuda()
|
||||
replicate(ref_model, device_mesh=replicate_shard_mesh)
|
||||
replicate(ref_model, mesh=replicate_mesh)
|
||||
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=False)
|
||||
model.parallelize(
|
||||
tp_mesh,
|
||||
replicate_shard_mesh,
|
||||
use_activation_checkpointing,
|
||||
reshard_after_forward=reshard_after_forward,
|
||||
)
|
||||
optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=False)
|
||||
|
||||
|
||||
@ -1,16 +1,26 @@
|
||||
# Owner(s): ["oncall: distributed checkpointing"]
|
||||
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
import torch.testing._internal.common_utils as common
|
||||
from torch import distributed as dist
|
||||
from torch.distributed.checkpoint._async_process_executor import (
|
||||
_ProcessBasedAsyncCheckpointExecutor,
|
||||
_ProcessGroupInitInfo,
|
||||
)
|
||||
from torch.distributed.checkpoint.api import CheckpointException
|
||||
from torch.distributed.checkpoint.storage import StorageWriter
|
||||
from torch.distributed.elastic.utils.distributed import get_free_port
|
||||
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
|
||||
from torch.testing._internal.common_distributed import skip_if_win32
|
||||
from torch.testing._internal.common_utils import (
|
||||
retry_on_connect_failures,
|
||||
run_tests,
|
||||
TEST_WITH_DEV_DBG_ASAN,
|
||||
TestCase,
|
||||
)
|
||||
from torch.testing._internal.distributed._tensor.common_dtensor import (
|
||||
DTensorTestBase,
|
||||
with_comms,
|
||||
@ -110,47 +120,184 @@ class TestAsyncProcessExecutor(DTensorTestBase):
|
||||
"epoch": 5,
|
||||
}
|
||||
|
||||
# 1. Simulate a failure in creating PG in background process.
|
||||
with patch(
|
||||
"torch.distributed.checkpoint._async_process_executor.get_free_port",
|
||||
return_value=-1,
|
||||
):
|
||||
with self.assertRaises(ValueError) as _:
|
||||
with patch.dict(os.environ, {}, clear=False):
|
||||
os.environ.pop("DCP_USE_PREFIX_STORE", None)
|
||||
|
||||
# 1. Simulate a failure in creating PG in background process.
|
||||
with patch(
|
||||
"torch.distributed.checkpoint._async_process_executor.get_free_port",
|
||||
return_value=-1,
|
||||
):
|
||||
with self.assertRaises(ValueError) as _:
|
||||
proc_executor = _ProcessBasedAsyncCheckpointExecutor()
|
||||
fut = proc_executor.execute_save(
|
||||
staging_future_or_state_dict=test_state_dict,
|
||||
)
|
||||
fut.result()
|
||||
|
||||
# 2. Attempt save with failing storage writer
|
||||
with patch(
|
||||
"torch.distributed.checkpoint._async_process_executor.get_free_port",
|
||||
return_value=get_free_port(),
|
||||
) as mock_get_free_port:
|
||||
proc_executor = _ProcessBasedAsyncCheckpointExecutor()
|
||||
fut = proc_executor.execute_save(
|
||||
staging_future_or_state_dict=test_state_dict,
|
||||
storage_writer=TestStorageWriter(behavior="fail_once"),
|
||||
)
|
||||
fut.result()
|
||||
self.assertIn(
|
||||
"fail_once policy triggered failure", str(fut.exception())
|
||||
)
|
||||
# Verify new process was created for this attempt
|
||||
if dist.get_rank() == 0:
|
||||
mock_get_free_port.assert_called_once()
|
||||
|
||||
# 2. Attempt save with failing storage writer
|
||||
with patch(
|
||||
"torch.distributed.checkpoint._async_process_executor.get_free_port",
|
||||
return_value=get_free_port(),
|
||||
) as mock_get_free_port:
|
||||
proc_executor = _ProcessBasedAsyncCheckpointExecutor()
|
||||
fut = proc_executor.execute_save(
|
||||
staging_future_or_state_dict=test_state_dict,
|
||||
storage_writer=TestStorageWriter(behavior="fail_once"),
|
||||
)
|
||||
self.assertIn("fail_once policy triggered failure", str(fut.exception()))
|
||||
# Verify new process was created for this attempt
|
||||
if dist.get_rank() == 0:
|
||||
mock_get_free_port.assert_called_once()
|
||||
# 3. Second save attempt with successful storage writer - process should still be alive
|
||||
with patch(
|
||||
"torch.distributed.checkpoint._async_process_executor.get_free_port",
|
||||
) as mock_get_free_port:
|
||||
proc_executor = _ProcessBasedAsyncCheckpointExecutor()
|
||||
fut = proc_executor.execute_save(
|
||||
staging_future_or_state_dict=test_state_dict,
|
||||
storage_writer=TestStorageWriter(behavior="success"),
|
||||
)
|
||||
result = fut.result()
|
||||
# Verify process is still alive
|
||||
mock_get_free_port.assert_not_called()
|
||||
# Verify successful save
|
||||
self.assertIsNotNone(result)
|
||||
|
||||
# 3. Second save attempt with successful storage writer - process should still be alive
|
||||
with patch(
|
||||
"torch.distributed.checkpoint._async_process_executor.get_free_port",
|
||||
) as mock_get_free_port:
|
||||
proc_executor = _ProcessBasedAsyncCheckpointExecutor()
|
||||
fut = proc_executor.execute_save(
|
||||
staging_future_or_state_dict=test_state_dict,
|
||||
storage_writer=TestStorageWriter(behavior="success"),
|
||||
)
|
||||
result = fut.result()
|
||||
# Verify process is still alive
|
||||
mock_get_free_port.assert_not_called()
|
||||
# Verify successful save
|
||||
self.assertIsNotNone(result)
|
||||
|
||||
class TestAsyncProcessExecutorPrefixStore(TestCase):
|
||||
@skip_if_win32()
|
||||
@retry_on_connect_failures
|
||||
def test_checkpoint_save_with_prefix_store_enabled(self) -> None:
|
||||
"""Test that checkpoint save works when DCP_USE_PREFIX_STORE is enabled."""
|
||||
|
||||
test_state_dict = {
|
||||
"model": {"weight": torch.randn(4, 4), "bias": torch.randn(4)},
|
||||
"optimizer": {"param_groups": [{"lr": 0.01}]},
|
||||
"epoch": 5,
|
||||
}
|
||||
|
||||
master_addr = "localhost"
|
||||
master_port = str(common.find_free_port())
|
||||
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"DCP_USE_PREFIX_STORE": "1",
|
||||
"MASTER_ADDR": master_addr,
|
||||
"MASTER_PORT": master_port,
|
||||
},
|
||||
):
|
||||
with patch(
|
||||
"torch.distributed.checkpoint._async_process_executor.get_free_port"
|
||||
) as mock_get_free_port:
|
||||
dist.init_process_group(
|
||||
backend=dist.Backend.GLOO,
|
||||
rank=0,
|
||||
world_size=1,
|
||||
)
|
||||
|
||||
proc_executor = _ProcessBasedAsyncCheckpointExecutor()
|
||||
fut = proc_executor.execute_save(
|
||||
staging_future_or_state_dict=test_state_dict,
|
||||
storage_writer=TestStorageWriter(behavior="success"),
|
||||
)
|
||||
result = fut.result()
|
||||
self.assertIsNotNone(result)
|
||||
mock_get_free_port.assert_not_called()
|
||||
|
||||
|
||||
class TestProcessGroupInitInfo(DTensorTestBase):
|
||||
"""Test suite for _ProcessGroupInitInfo."""
|
||||
|
||||
@with_comms
|
||||
def test_process_group_init_info_with_default_pg(self) -> None:
|
||||
"""Test that ProcessGroupInitInfo correctly initializes."""
|
||||
with patch.dict(os.environ, {}, clear=False):
|
||||
os.environ.pop("DCP_USE_PREFIX_STORE", None)
|
||||
|
||||
pg_init_info = _ProcessGroupInitInfo()
|
||||
|
||||
self.assertEqual(pg_init_info.global_rank, dist.get_rank())
|
||||
self.assertEqual(pg_init_info.world_size, dist.get_world_size())
|
||||
self.assertIsNotNone(pg_init_info.tcp_store_master_addr)
|
||||
self.assertGreater(pg_init_info.tcp_store_master_port, 0)
|
||||
self.assertEqual(pg_init_info.use_prefix_store, False)
|
||||
|
||||
@with_comms
|
||||
def test_process_group_init_info_with_prefix_store_env_var(self) -> None:
|
||||
"""Test that ProcessGroupInitInfo handles DCP_USE_PREFIX_STORE environment variable."""
|
||||
|
||||
# Flag enabled, addr/port correctly defined
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"DCP_USE_PREFIX_STORE": "1",
|
||||
"MASTER_ADDR": "localhost",
|
||||
"MASTER_PORT": "12345",
|
||||
},
|
||||
):
|
||||
pg_init_info = _ProcessGroupInitInfo()
|
||||
self.assertTrue(pg_init_info.use_prefix_store)
|
||||
|
||||
# Missing port
|
||||
with patch.dict(
|
||||
os.environ, {"DCP_USE_PREFIX_STORE": "1", "MASTER_ADDR": "localhost"}
|
||||
):
|
||||
with self.assertRaises(CheckpointException):
|
||||
pg_init_info = _ProcessGroupInitInfo()
|
||||
# Missing addr
|
||||
with patch.dict(
|
||||
os.environ, {"DCP_USE_PREFIX_STORE": "1", "MASTER_PORT": "12345"}
|
||||
):
|
||||
with self.assertRaises(CheckpointException):
|
||||
pg_init_info = _ProcessGroupInitInfo()
|
||||
# Invalid port
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"DCP_USE_PREFIX_STORE": "1",
|
||||
"MASTER_ADDR": "localhost",
|
||||
"MASTER_PORT": "a",
|
||||
},
|
||||
):
|
||||
with self.assertRaises(CheckpointException):
|
||||
pg_init_info = _ProcessGroupInitInfo()
|
||||
|
||||
@with_comms
|
||||
def test_process_group_init_info_without_prefix_store_env_var(self) -> None:
|
||||
"""Test that ProcessGroupInitInfo defaults to not using prefix store."""
|
||||
|
||||
# Env var set to 0
|
||||
with patch.dict(os.environ, {"DCP_USE_PREFIX_STORE": "0"}):
|
||||
pg_init_info = _ProcessGroupInitInfo()
|
||||
self.assertFalse(pg_init_info.use_prefix_store)
|
||||
|
||||
# Missing env var
|
||||
with patch.dict(os.environ, {}, clear=False):
|
||||
os.environ.pop("DCP_USE_PREFIX_STORE", None)
|
||||
pg_init_info = _ProcessGroupInitInfo()
|
||||
self.assertFalse(pg_init_info.use_prefix_store)
|
||||
|
||||
# Invalid env var
|
||||
with patch.dict(os.environ, {"DCP_USE_PREFIX_STORE": "2"}):
|
||||
pg_init_info = _ProcessGroupInitInfo()
|
||||
self.assertFalse(pg_init_info.use_prefix_store)
|
||||
|
||||
with patch.dict(os.environ, {"DCP_USE_PREFIX_STORE": "true"}):
|
||||
pg_init_info = _ProcessGroupInitInfo()
|
||||
self.assertFalse(pg_init_info.use_prefix_store)
|
||||
|
||||
with patch.dict(os.environ, {"DCP_USE_PREFIX_STORE": "false"}):
|
||||
pg_init_info = _ProcessGroupInitInfo()
|
||||
self.assertFalse(pg_init_info.use_prefix_store)
|
||||
|
||||
with patch.dict(os.environ, {"DCP_USE_PREFIX_STORE": ""}):
|
||||
pg_init_info = _ProcessGroupInitInfo()
|
||||
self.assertFalse(pg_init_info.use_prefix_store)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -415,6 +415,15 @@ class TestDTensorDebugMode(TestCase):
|
||||
aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])""",
|
||||
)
|
||||
|
||||
with DebugMode(record_stack_trace=True) as debug_mode:
|
||||
out = mod(inp).sum()
|
||||
out.backward()
|
||||
|
||||
sum_op = [
|
||||
op for op in debug_mode.operators if str(op.op) == "aten.sum.dim_IntList"
|
||||
][-1]
|
||||
self.assertTrue("self.l2(self.l1(x))" in sum_op.fwd_stack_trace)
|
||||
|
||||
|
||||
instantiate_parametrized_tests(TestDTensorDebugMode)
|
||||
|
||||
|
||||
@ -1019,6 +1019,28 @@ class DTensorMeshTest(DTensorTestBase):
|
||||
except ValueError:
|
||||
self.fail("Unexpected ValueError raised with run_check=False")
|
||||
|
||||
@with_comms
|
||||
def test_as_strided_identity(self):
|
||||
# Test calling as_strided with the same size/stride/offset as input tensor
|
||||
# This should be a no-op but currently fails
|
||||
device_mesh = self.build_device_mesh()
|
||||
placements = [Shard(0)]
|
||||
local_tensor = torch.randn(3, 4, device=self.device_type)
|
||||
dtensor = DTensor.from_local(local_tensor, device_mesh, placements)
|
||||
|
||||
# Get the current size, stride, and storage_offset
|
||||
size = dtensor.size()
|
||||
stride = dtensor.stride()
|
||||
storage_offset = dtensor.storage_offset()
|
||||
|
||||
# Call as_strided with the exact same parameters
|
||||
result = dtensor.as_strided(size, stride, storage_offset)
|
||||
|
||||
# The result should be identical to the input
|
||||
self.assertEqual(result.size(), dtensor.size())
|
||||
self.assertEqual(result.stride(), dtensor.stride())
|
||||
self.assertEqual(result.to_local(), dtensor.to_local())
|
||||
|
||||
|
||||
DTensorMeshTestWithLocalTensor = create_local_tensor_test_class(
|
||||
DTensorMeshTest,
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
diff --git a/test/dynamo/cpython/3_13/test_heapq.py b/test/dynamo/cpython/3_13/test_heapq.py
|
||||
index 1aa8e4e2897..94315fa68b4 100644
|
||||
index 1aa8e4e2897..bc177c2943e 100644
|
||||
--- a/test/dynamo/cpython/3_13/test_heapq.py
|
||||
+++ b/test/dynamo/cpython/3_13/test_heapq.py
|
||||
@@ -1,3 +1,23 @@
|
||||
@ -35,7 +35,7 @@ index 1aa8e4e2897..94315fa68b4 100644
|
||||
def test_py_functions(self):
|
||||
for fname in func_names:
|
||||
self.assertEqual(getattr(py_heapq, fname).__module__, 'heapq')
|
||||
@@ -27,24 +47,7 @@ class TestModules(TestCase):
|
||||
@@ -27,24 +47,12 @@ class TestModules(TestCase):
|
||||
self.assertEqual(getattr(c_heapq, fname).__module__, '_heapq')
|
||||
|
||||
|
||||
@ -46,12 +46,15 @@ index 1aa8e4e2897..94315fa68b4 100644
|
||||
- # However, doctest can't easily find all docstrings in the module (loading
|
||||
- # it through import_fresh_module seems to confuse it), so we specifically
|
||||
- # create a finder which returns the doctests from the merge method.
|
||||
-
|
||||
+@torch._dynamo.disable
|
||||
+def randrange(*args):
|
||||
+ return random.randrange(*args)
|
||||
|
||||
- class HeapqMergeDocTestFinder:
|
||||
- def find(self, *args, **kwargs):
|
||||
- dtf = doctest.DocTestFinder()
|
||||
- return dtf.find(py_heapq.merge)
|
||||
-
|
||||
|
||||
- tests.addTests(doctest.DocTestSuite(py_heapq,
|
||||
- test_finder=HeapqMergeDocTestFinder()))
|
||||
- return tests
|
||||
@ -61,7 +64,155 @@ index 1aa8e4e2897..94315fa68b4 100644
|
||||
|
||||
def test_push_pop(self):
|
||||
# 1) Push 256 random numbers and pop them off, verifying all's OK.
|
||||
@@ -264,12 +267,12 @@ class TestHeap:
|
||||
@@ -52,7 +60,8 @@ class TestHeap:
|
||||
data = []
|
||||
self.check_invariant(heap)
|
||||
for i in range(256):
|
||||
- item = random.random()
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ item = random.random()
|
||||
data.append(item)
|
||||
self.module.heappush(heap, item)
|
||||
self.check_invariant(heap)
|
||||
@@ -83,14 +92,16 @@ class TestHeap:
|
||||
|
||||
def test_heapify(self):
|
||||
for size in list(range(30)) + [20000]:
|
||||
- heap = [random.random() for dummy in range(size)]
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ heap = [random.random() for dummy in range(size)]
|
||||
self.module.heapify(heap)
|
||||
self.check_invariant(heap)
|
||||
|
||||
self.assertRaises(TypeError, self.module.heapify, None)
|
||||
|
||||
def test_naive_nbest(self):
|
||||
- data = [random.randrange(2000) for i in range(1000)]
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ data = [randrange(2000) for i in range(1000)]
|
||||
heap = []
|
||||
for item in data:
|
||||
self.module.heappush(heap, item)
|
||||
@@ -113,7 +124,8 @@ class TestHeap:
|
||||
# heap instead of a min heap, it could go faster still via
|
||||
# heapify'ing all of data (linear time), then doing 10 heappops
|
||||
# (10 log-time steps).
|
||||
- data = [random.randrange(2000) for i in range(1000)]
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ data = [randrange(2000) for i in range(1000)]
|
||||
heap = data[:10]
|
||||
self.module.heapify(heap)
|
||||
for item in data[10:]:
|
||||
@@ -126,7 +138,8 @@ class TestHeap:
|
||||
self.assertRaises(IndexError, self.module.heapreplace, [], None)
|
||||
|
||||
def test_nbest_with_pushpop(self):
|
||||
- data = [random.randrange(2000) for i in range(1000)]
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ data = [randrange(2000) for i in range(1000)]
|
||||
heap = data[:10]
|
||||
self.module.heapify(heap)
|
||||
for item in data[10:]:
|
||||
@@ -163,8 +176,9 @@ class TestHeap:
|
||||
def test_heapsort(self):
|
||||
# Exercise everything with repeated heapsort checks
|
||||
for trial in range(100):
|
||||
- size = random.randrange(50)
|
||||
- data = [random.randrange(25) for i in range(size)]
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ size = randrange(50)
|
||||
+ data = [randrange(25) for i in range(size)]
|
||||
if trial & 1: # Half of the time, use heapify
|
||||
heap = data[:]
|
||||
self.module.heapify(heap)
|
||||
@@ -177,12 +191,13 @@ class TestHeap:
|
||||
|
||||
def test_merge(self):
|
||||
inputs = []
|
||||
- for i in range(random.randrange(25)):
|
||||
- row = []
|
||||
- for j in range(random.randrange(100)):
|
||||
- tup = random.choice('ABC'), random.randrange(-500, 500)
|
||||
- row.append(tup)
|
||||
- inputs.append(row)
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ for i in range(randrange(25)):
|
||||
+ row = []
|
||||
+ for j in range(randrange(100)):
|
||||
+ tup = random.choice('ABC'), randrange(-500, 500)
|
||||
+ row.append(tup)
|
||||
+ inputs.append(row)
|
||||
|
||||
for key in [None, itemgetter(0), itemgetter(1), itemgetter(1, 0)]:
|
||||
for reverse in [False, True]:
|
||||
@@ -209,12 +224,14 @@ class TestHeap:
|
||||
list(self.module.merge(iterable(), iterable()))
|
||||
|
||||
def test_merge_stability(self):
|
||||
- class Int(int):
|
||||
- pass
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class Int(int):
|
||||
+ pass
|
||||
inputs = [[], [], [], []]
|
||||
for i in range(20000):
|
||||
- stream = random.randrange(4)
|
||||
- x = random.randrange(500)
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ stream = randrange(4)
|
||||
+ x = randrange(500)
|
||||
obj = Int(x)
|
||||
obj.pair = (x, stream)
|
||||
inputs[stream].append(obj)
|
||||
@@ -224,7 +241,8 @@ class TestHeap:
|
||||
self.assertEqual(result, sorted(result))
|
||||
|
||||
def test_nsmallest(self):
|
||||
- data = [(random.randrange(2000), i) for i in range(1000)]
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ data = [(randrange(2000), i) for i in range(1000)]
|
||||
for f in (None, lambda x: x[0] * 547 % 2000):
|
||||
for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
|
||||
self.assertEqual(list(self.module.nsmallest(n, data)),
|
||||
@@ -233,7 +251,8 @@ class TestHeap:
|
||||
sorted(data, key=f)[:n])
|
||||
|
||||
def test_nlargest(self):
|
||||
- data = [(random.randrange(2000), i) for i in range(1000)]
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ data = [(randrange(2000), i) for i in range(1000)]
|
||||
for f in (None, lambda x: x[0] * 547 % 2000):
|
||||
for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
|
||||
self.assertEqual(list(self.module.nlargest(n, data)),
|
||||
@@ -248,28 +267,29 @@ class TestHeap:
|
||||
data = [comp(x) for x in data]
|
||||
self.module.heapify(data)
|
||||
return [self.module.heappop(data).x for i in range(len(data))]
|
||||
- class LT:
|
||||
- def __init__(self, x):
|
||||
- self.x = x
|
||||
- def __lt__(self, other):
|
||||
- return self.x > other.x
|
||||
- class LE:
|
||||
- def __init__(self, x):
|
||||
- self.x = x
|
||||
- def __le__(self, other):
|
||||
- return self.x >= other.x
|
||||
- data = [random.random() for i in range(100)]
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class LT:
|
||||
+ def __init__(self, x):
|
||||
+ self.x = x
|
||||
+ def __lt__(self, other):
|
||||
+ return self.x > other.x
|
||||
+ class LE:
|
||||
+ def __init__(self, x):
|
||||
+ self.x = x
|
||||
+ def __le__(self, other):
|
||||
+ return self.x >= other.x
|
||||
+ data = [random.random() for i in range(100)]
|
||||
target = sorted(data, reverse=True)
|
||||
self.assertEqual(hsort(data, LT), target)
|
||||
self.assertRaises(TypeError, data, LE)
|
||||
|
||||
|
||||
@ -76,7 +227,7 @@ index 1aa8e4e2897..94315fa68b4 100644
|
||||
module = c_heapq
|
||||
|
||||
|
||||
@@ -374,7 +377,7 @@ class SideEffectLT:
|
||||
@@ -374,7 +394,7 @@ class SideEffectLT:
|
||||
return self.value < other.value
|
||||
|
||||
|
||||
@ -85,7 +236,48 @@ index 1aa8e4e2897..94315fa68b4 100644
|
||||
|
||||
def test_non_sequence(self):
|
||||
for f in (self.module.heapify, self.module.heappop):
|
||||
@@ -464,13 +467,13 @@ class TestErrorHandling:
|
||||
@@ -435,10 +455,11 @@ class TestErrorHandling:
|
||||
def test_comparison_operator_modifiying_heap(self):
|
||||
# See bpo-39421: Strong references need to be taken
|
||||
# when comparing objects as they can alter the heap
|
||||
- class EvilClass(int):
|
||||
- def __lt__(self, o):
|
||||
- heap.clear()
|
||||
- return NotImplemented
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class EvilClass(int):
|
||||
+ def __lt__(self, o):
|
||||
+ heap.clear()
|
||||
+ return NotImplemented
|
||||
|
||||
heap = []
|
||||
self.module.heappush(heap, EvilClass(0))
|
||||
@@ -446,15 +467,16 @@ class TestErrorHandling:
|
||||
|
||||
def test_comparison_operator_modifiying_heap_two_heaps(self):
|
||||
|
||||
- class h(int):
|
||||
- def __lt__(self, o):
|
||||
- list2.clear()
|
||||
- return NotImplemented
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ class h(int):
|
||||
+ def __lt__(self, o):
|
||||
+ list2.clear()
|
||||
+ return NotImplemented
|
||||
|
||||
- class g(int):
|
||||
- def __lt__(self, o):
|
||||
- list1.clear()
|
||||
- return NotImplemented
|
||||
+ class g(int):
|
||||
+ def __lt__(self, o):
|
||||
+ list1.clear()
|
||||
+ return NotImplemented
|
||||
|
||||
list1, list2 = [], []
|
||||
|
||||
@@ -464,13 +486,13 @@ class TestErrorHandling:
|
||||
self.assertRaises((IndexError, RuntimeError), self.module.heappush, list1, g(1))
|
||||
self.assertRaises((IndexError, RuntimeError), self.module.heappush, list2, h(1))
|
||||
|
||||
|
||||
@ -47,6 +47,11 @@ class TestModules(__TestCase):
|
||||
self.assertEqual(getattr(c_heapq, fname).__module__, '_heapq')
|
||||
|
||||
|
||||
@torch._dynamo.disable
|
||||
def randrange(*args):
|
||||
return random.randrange(*args)
|
||||
|
||||
|
||||
class _TestHeap:
|
||||
|
||||
def test_push_pop(self):
|
||||
@ -55,7 +60,8 @@ class _TestHeap:
|
||||
data = []
|
||||
self.check_invariant(heap)
|
||||
for i in range(256):
|
||||
item = random.random()
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
item = random.random()
|
||||
data.append(item)
|
||||
self.module.heappush(heap, item)
|
||||
self.check_invariant(heap)
|
||||
@ -86,14 +92,16 @@ class _TestHeap:
|
||||
|
||||
def test_heapify(self):
|
||||
for size in list(range(30)) + [20000]:
|
||||
heap = [random.random() for dummy in range(size)]
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
heap = [random.random() for dummy in range(size)]
|
||||
self.module.heapify(heap)
|
||||
self.check_invariant(heap)
|
||||
|
||||
self.assertRaises(TypeError, self.module.heapify, None)
|
||||
|
||||
def test_naive_nbest(self):
|
||||
data = [random.randrange(2000) for i in range(1000)]
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
data = [randrange(2000) for i in range(1000)]
|
||||
heap = []
|
||||
for item in data:
|
||||
self.module.heappush(heap, item)
|
||||
@ -116,7 +124,8 @@ class _TestHeap:
|
||||
# heap instead of a min heap, it could go faster still via
|
||||
# heapify'ing all of data (linear time), then doing 10 heappops
|
||||
# (10 log-time steps).
|
||||
data = [random.randrange(2000) for i in range(1000)]
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
data = [randrange(2000) for i in range(1000)]
|
||||
heap = data[:10]
|
||||
self.module.heapify(heap)
|
||||
for item in data[10:]:
|
||||
@ -129,7 +138,8 @@ class _TestHeap:
|
||||
self.assertRaises(IndexError, self.module.heapreplace, [], None)
|
||||
|
||||
def test_nbest_with_pushpop(self):
|
||||
data = [random.randrange(2000) for i in range(1000)]
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
data = [randrange(2000) for i in range(1000)]
|
||||
heap = data[:10]
|
||||
self.module.heapify(heap)
|
||||
for item in data[10:]:
|
||||
@ -166,8 +176,9 @@ class _TestHeap:
|
||||
def test_heapsort(self):
|
||||
# Exercise everything with repeated heapsort checks
|
||||
for trial in range(100):
|
||||
size = random.randrange(50)
|
||||
data = [random.randrange(25) for i in range(size)]
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
size = randrange(50)
|
||||
data = [randrange(25) for i in range(size)]
|
||||
if trial & 1: # Half of the time, use heapify
|
||||
heap = data[:]
|
||||
self.module.heapify(heap)
|
||||
@ -180,12 +191,13 @@ class _TestHeap:
|
||||
|
||||
def test_merge(self):
|
||||
inputs = []
|
||||
for i in range(random.randrange(25)):
|
||||
row = []
|
||||
for j in range(random.randrange(100)):
|
||||
tup = random.choice('ABC'), random.randrange(-500, 500)
|
||||
row.append(tup)
|
||||
inputs.append(row)
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
for i in range(randrange(25)):
|
||||
row = []
|
||||
for j in range(randrange(100)):
|
||||
tup = random.choice('ABC'), randrange(-500, 500)
|
||||
row.append(tup)
|
||||
inputs.append(row)
|
||||
|
||||
for key in [None, itemgetter(0), itemgetter(1), itemgetter(1, 0)]:
|
||||
for reverse in [False, True]:
|
||||
@ -212,12 +224,14 @@ class _TestHeap:
|
||||
list(self.module.merge(iterable(), iterable()))
|
||||
|
||||
def test_merge_stability(self):
|
||||
class Int(int):
|
||||
pass
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class Int(int):
|
||||
pass
|
||||
inputs = [[], [], [], []]
|
||||
for i in range(20000):
|
||||
stream = random.randrange(4)
|
||||
x = random.randrange(500)
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
stream = randrange(4)
|
||||
x = randrange(500)
|
||||
obj = Int(x)
|
||||
obj.pair = (x, stream)
|
||||
inputs[stream].append(obj)
|
||||
@ -227,7 +241,8 @@ class _TestHeap:
|
||||
self.assertEqual(result, sorted(result))
|
||||
|
||||
def test_nsmallest(self):
|
||||
data = [(random.randrange(2000), i) for i in range(1000)]
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
data = [(randrange(2000), i) for i in range(1000)]
|
||||
for f in (None, lambda x: x[0] * 547 % 2000):
|
||||
for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
|
||||
self.assertEqual(list(self.module.nsmallest(n, data)),
|
||||
@ -236,7 +251,8 @@ class _TestHeap:
|
||||
sorted(data, key=f)[:n])
|
||||
|
||||
def test_nlargest(self):
|
||||
data = [(random.randrange(2000), i) for i in range(1000)]
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
data = [(randrange(2000), i) for i in range(1000)]
|
||||
for f in (None, lambda x: x[0] * 547 % 2000):
|
||||
for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
|
||||
self.assertEqual(list(self.module.nlargest(n, data)),
|
||||
@ -251,17 +267,18 @@ class _TestHeap:
|
||||
data = [comp(x) for x in data]
|
||||
self.module.heapify(data)
|
||||
return [self.module.heappop(data).x for i in range(len(data))]
|
||||
class LT:
|
||||
def __init__(self, x):
|
||||
self.x = x
|
||||
def __lt__(self, other):
|
||||
return self.x > other.x
|
||||
class LE:
|
||||
def __init__(self, x):
|
||||
self.x = x
|
||||
def __le__(self, other):
|
||||
return self.x >= other.x
|
||||
data = [random.random() for i in range(100)]
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class LT:
|
||||
def __init__(self, x):
|
||||
self.x = x
|
||||
def __lt__(self, other):
|
||||
return self.x > other.x
|
||||
class LE:
|
||||
def __init__(self, x):
|
||||
self.x = x
|
||||
def __le__(self, other):
|
||||
return self.x >= other.x
|
||||
data = [random.random() for i in range(100)]
|
||||
target = sorted(data, reverse=True)
|
||||
self.assertEqual(hsort(data, LT), target)
|
||||
self.assertRaises(TypeError, data, LE)
|
||||
@ -438,10 +455,11 @@ class _TestErrorHandling:
|
||||
def test_comparison_operator_modifiying_heap(self):
|
||||
# See bpo-39421: Strong references need to be taken
|
||||
# when comparing objects as they can alter the heap
|
||||
class EvilClass(int):
|
||||
def __lt__(self, o):
|
||||
heap.clear()
|
||||
return NotImplemented
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class EvilClass(int):
|
||||
def __lt__(self, o):
|
||||
heap.clear()
|
||||
return NotImplemented
|
||||
|
||||
heap = []
|
||||
self.module.heappush(heap, EvilClass(0))
|
||||
@ -449,15 +467,16 @@ class _TestErrorHandling:
|
||||
|
||||
def test_comparison_operator_modifiying_heap_two_heaps(self):
|
||||
|
||||
class h(int):
|
||||
def __lt__(self, o):
|
||||
list2.clear()
|
||||
return NotImplemented
|
||||
with torch._dynamo.error_on_graph_break(False):
|
||||
class h(int):
|
||||
def __lt__(self, o):
|
||||
list2.clear()
|
||||
return NotImplemented
|
||||
|
||||
class g(int):
|
||||
def __lt__(self, o):
|
||||
list1.clear()
|
||||
return NotImplemented
|
||||
class g(int):
|
||||
def __lt__(self, o):
|
||||
list1.clear()
|
||||
return NotImplemented
|
||||
|
||||
list1, list2 = [], []
|
||||
|
||||
|
||||
@ -427,17 +427,29 @@ from user code:
|
||||
optree.tree_flatten_with_path(d)
|
||||
return torch.sin(x)
|
||||
|
||||
def post_munge(s):
|
||||
s = re.sub(
|
||||
r"optree\.\S*\.flatten_with_path",
|
||||
"optree.<path>.flatten_with_path",
|
||||
s,
|
||||
)
|
||||
return re.sub(
|
||||
r"qualname: \S*flatten_with_path",
|
||||
"qualname: <path>.flatten_with_path",
|
||||
s,
|
||||
)
|
||||
|
||||
fn(torch.randn(4))
|
||||
self.assertEqual(len(counters["graph_break"]), 1)
|
||||
first_graph_break = next(iter(counters["graph_break"].keys()))
|
||||
self.assertExpectedInline(
|
||||
first_graph_break,
|
||||
post_munge(first_graph_break),
|
||||
"""\
|
||||
Attempted to call function marked as skipped
|
||||
Explanation: Dynamo cannot trace optree C/C++ function optree._C.PyCapsule.flatten_with_path.
|
||||
Explanation: Dynamo cannot trace optree C/C++ function optree.<path>.flatten_with_path.
|
||||
Hint: Consider using torch.utils._pytree - https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py
|
||||
|
||||
Developer debug context: module: optree._C, qualname: PyCapsule.flatten_with_path, skip reason: <missing reason>
|
||||
Developer debug context: module: optree._C, qualname: <path>.flatten_with_path, skip reason: <missing reason>
|
||||
|
||||
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0007.html""",
|
||||
)
|
||||
|
||||
@ -5241,6 +5241,63 @@ class DefaultsTests(torch._dynamo.test_case.TestCase):
|
||||
x = torch.randn(1)
|
||||
self.assertEqual(opt_mod(x), x + 1)
|
||||
|
||||
def test_full_with_tensor_fill_value(self):
|
||||
"""Test that torch.full works correctly with dynamic tensor fill_value"""
|
||||
|
||||
# Test with tensor fill_value (the bug case)
|
||||
def func_tensor(x):
|
||||
return torch.full((2,), x, dtype=torch.float64)
|
||||
|
||||
func_compiled = torch.compile(func_tensor)
|
||||
|
||||
# Test with different values
|
||||
x1 = torch.tensor(5.0, dtype=torch.float64)
|
||||
x2 = torch.tensor(10.0, dtype=torch.float64)
|
||||
|
||||
result1 = func_compiled(x1)
|
||||
expected1 = torch.full((2,), x1, dtype=torch.float64)
|
||||
self.assertEqual(result1, expected1)
|
||||
|
||||
# This is where the bug occurred - second call reused first value
|
||||
result2 = func_compiled(x2)
|
||||
expected2 = torch.full((2,), x2, dtype=torch.float64)
|
||||
self.assertEqual(result2, expected2)
|
||||
|
||||
# Test with different dtypes
|
||||
for dtype in [torch.float32, torch.float64, torch.int32, torch.int64]:
|
||||
|
||||
def func_typed(x):
|
||||
return torch.full((3,), x, dtype=dtype)
|
||||
|
||||
func_typed_compiled = torch.compile(func_typed)
|
||||
x_typed = torch.tensor(7, dtype=dtype)
|
||||
result = func_typed_compiled(x_typed)
|
||||
expected = torch.full((3,), x_typed, dtype=dtype)
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
# Test with non-tensor fill_value (scalar) to ensure we didn't break existing behavior
|
||||
def func_scalar(size):
|
||||
return torch.full((size,), 42.0, dtype=torch.float32)
|
||||
|
||||
func_scalar_compiled = torch.compile(func_scalar)
|
||||
|
||||
result_scalar = func_scalar_compiled(5)
|
||||
expected_scalar = torch.full((5,), 42.0, dtype=torch.float32)
|
||||
self.assertEqual(result_scalar, expected_scalar)
|
||||
|
||||
# Test with different scalar values
|
||||
def func_scalar_param():
|
||||
# Test multiple calls with different hardcoded scalar values
|
||||
a = torch.full((2,), 3.14, dtype=torch.float32)
|
||||
b = torch.full((2,), 2.71, dtype=torch.float32)
|
||||
return a, b
|
||||
|
||||
func_scalar_param_compiled = torch.compile(func_scalar_param)
|
||||
result_a, result_b = func_scalar_param_compiled()
|
||||
|
||||
self.assertEqual(result_a, torch.full((2,), 3.14, dtype=torch.float32))
|
||||
self.assertEqual(result_b, torch.full((2,), 2.71, dtype=torch.float32))
|
||||
|
||||
|
||||
instantiate_parametrized_tests(FunctionTests)
|
||||
instantiate_parametrized_tests(DefaultsTests)
|
||||
|
||||
@ -69,6 +69,7 @@ from torch.fx.experimental.symbolic_shapes import (
|
||||
constrain_unify,
|
||||
ConstraintViolationError,
|
||||
expect_true,
|
||||
guard_or_false,
|
||||
guard_size_oblivious,
|
||||
ShapeEnv,
|
||||
)
|
||||
@ -100,7 +101,6 @@ from torch.testing._internal.common_utils import (
|
||||
wrapDeterministicFlagAPITest,
|
||||
)
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
from torch.testing._internal.logging_utils import logs_to_string
|
||||
|
||||
|
||||
pytree_modules = {
|
||||
@ -13636,6 +13636,74 @@ instantiate_device_type_tests(
|
||||
)
|
||||
|
||||
|
||||
class DynamoOpPromotionTests(torch._dynamo.test_case.TestCase):
|
||||
@unittest.skipIf(not TEST_CUDA, "This test requires a CUDA device")
|
||||
def test_symbool_tensor_mul(self):
|
||||
def symbool_mul_fn(x_bool, sentinel):
|
||||
result = x_bool * sentinel
|
||||
return result
|
||||
|
||||
x_true = torch.tensor([True], device="cuda")
|
||||
x_false = torch.tensor([False], device="cuda")
|
||||
sentinel = torch.tensor(2.0, requires_grad=True, device="cuda")
|
||||
eager_result_true = symbool_mul_fn(x_true, sentinel)
|
||||
eager_result_false = symbool_mul_fn(x_false, sentinel)
|
||||
compiled_fn = torch.compile(symbool_mul_fn, fullgraph=True, dynamic=True)
|
||||
compiled_result_true = compiled_fn(x_true, sentinel)
|
||||
compiled_result_false = compiled_fn(x_false, sentinel)
|
||||
self.assertEqual(eager_result_true, compiled_result_true)
|
||||
self.assertEqual(eager_result_false, compiled_result_false)
|
||||
self.assertEqual(compiled_result_true.item(), 2.0)
|
||||
self.assertEqual(compiled_result_false.item(), 0.0)
|
||||
|
||||
@unittest.skipIf(not TEST_CUDA, "This test requires a CUDA device")
|
||||
def test_symbool_guard_or_false(self):
|
||||
def symbool_guard_fn(a_bool_tensor, b):
|
||||
u0 = a_bool_tensor.item()
|
||||
# Make sure guard_or_false still handles SymBool produced by .item()
|
||||
if guard_or_false(u0):
|
||||
return b * 10
|
||||
else:
|
||||
return b * 100
|
||||
|
||||
compiled_guard_fn = torch.compile(
|
||||
symbool_guard_fn, backend="eager", dynamic=True
|
||||
)
|
||||
a_true = torch.tensor(True, device="cuda")
|
||||
a_false = torch.tensor(False, device="cuda")
|
||||
b = torch.randn(6, device="cuda")
|
||||
eager_res_true = symbool_guard_fn(a_true, b)
|
||||
compiled_res_true = compiled_guard_fn(a_true, b)
|
||||
self.assertEqual(eager_res_true, compiled_res_true)
|
||||
eager_res_false = symbool_guard_fn(a_false, b)
|
||||
compiled_res_false = compiled_guard_fn(a_false, b)
|
||||
self.assertEqual(eager_res_false, compiled_res_false)
|
||||
self.assertEqual(compiled_res_true, b * 10)
|
||||
self.assertEqual(compiled_res_false, b * 100)
|
||||
|
||||
@unittest.skipIf(not TEST_CUDA, "This test requires a CUDA device")
|
||||
def test_symbool_tensor_mul_does_not_fail(self):
|
||||
def fuzzed_program(arg_0, sentinel):
|
||||
var_node_2 = arg_0
|
||||
var_node_1 = torch.squeeze(var_node_2)
|
||||
var_node_0 = var_node_1.item()
|
||||
result = var_node_0 * sentinel
|
||||
if result.is_complex():
|
||||
result = result.real
|
||||
return result
|
||||
|
||||
sentinel = torch.tensor(1.0, requires_grad=True, device="cuda")
|
||||
arg_0 = torch.tensor([True], dtype=torch.bool, device="cuda")
|
||||
args = (arg_0,) + (sentinel,)
|
||||
try:
|
||||
compiled_program = torch.compile(
|
||||
fuzzed_program, fullgraph=True, dynamic=True
|
||||
)
|
||||
compiled_program(*args)
|
||||
except Exception as e:
|
||||
self.fail(f"torch.compile failed with error: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
|
||||
@ -1000,6 +1000,18 @@ class ReproTests(torch._dynamo.test_case.TestCase):
|
||||
self.exit_stack.close()
|
||||
super().tearDown()
|
||||
|
||||
def test_compiled_module_truthiness(self):
|
||||
# Test with empty ModuleList
|
||||
original_empty = nn.ModuleList()
|
||||
compiled_empty = torch.compile(original_empty)
|
||||
self.assertEqual(bool(original_empty), bool(compiled_empty))
|
||||
self.assertFalse(bool(compiled_empty))
|
||||
# Test with non-empty ModuleList
|
||||
original_filled = nn.ModuleList([nn.Linear(10, 5)])
|
||||
compiled_filled = torch.compile(original_filled)
|
||||
self.assertEqual(bool(original_filled), bool(compiled_filled))
|
||||
self.assertTrue(bool(compiled_filled))
|
||||
|
||||
def guard_manager_clone_hook_fn(self, guard_manager_wrapper, f_locals, builder):
|
||||
root = guard_manager_wrapper.root
|
||||
cloned_root = root.clone_manager(lambda x: True)
|
||||
|
||||
@ -522,6 +522,83 @@ def forward(self, args_0):
|
||||
)
|
||||
self.assertEqual(ep(*inps), MyModel()(*inps))
|
||||
|
||||
def test_dynamo_graph_capture_full_tracing_context(self) -> None:
|
||||
class Foo(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return x + x.shape[0]
|
||||
|
||||
foo = Foo()
|
||||
|
||||
def make_inputs(b: int):
|
||||
ret = (torch.randn(b, 3),)
|
||||
torch._dynamo.mark_dynamic(ret[0], 0)
|
||||
return ret
|
||||
|
||||
trace_inputs = make_inputs(2)
|
||||
gm = dynamo_graph_capture_for_export(foo)(*trace_inputs)
|
||||
test_inputs = make_inputs(3)
|
||||
self.assertEqual(gm(*test_inputs), foo(*test_inputs))
|
||||
self.assertIsNotNone(gm.meta["tracing_context"].fake_mode)
|
||||
self.assertEqual(len(gm.meta["tracing_context"].tensor_to_context), 1)
|
||||
|
||||
def test_dynamo_graph_capture_dict_keys_getitem(self):
|
||||
class Module(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return x * 2
|
||||
|
||||
foo = Module()
|
||||
|
||||
class BlockMask:
|
||||
def __init__(self, d):
|
||||
self.d = d
|
||||
|
||||
block_mask = BlockMask(torch.randn(4))
|
||||
|
||||
def pre_hook_function(m, input):
|
||||
block_mask.d = input[0] + 1
|
||||
return input # Return a tuple of modified inputs
|
||||
|
||||
foo.register_forward_pre_hook(pre_hook_function)
|
||||
|
||||
def make_inputs():
|
||||
return (torch.randn(4),)
|
||||
|
||||
trace_inputs = make_inputs()
|
||||
gm = dynamo_graph_capture_for_export(foo)(*trace_inputs)
|
||||
test_inputs = make_inputs()
|
||||
self.assertExpectedInline(
|
||||
gm.code.strip("\r\n "),
|
||||
"""\
|
||||
def forward(self, args_0):
|
||||
_tree_leaf_0, _tree_leaf_1, = pytree.tree_leaves((self, args_0,))
|
||||
L_args_0_ , = self._in_shuffle_graph(_tree_leaf_0, _tree_leaf_1)
|
||||
l_args_0_ = L_args_0_
|
||||
add = l_args_0_ + 1
|
||||
mul = l_args_0_ * 2; l_args_0_ = None
|
||||
return pytree.tree_unflatten(self._out_shuffle_graph(_tree_leaf_0, _tree_leaf_1, mul, add), self._out_spec)""",
|
||||
)
|
||||
self.assertEqual(gm(*test_inputs), foo(*test_inputs))
|
||||
|
||||
def test_dynamo_graph_capture_with_tensor_constant(self):
|
||||
outer = torch.randn(2, 3)
|
||||
|
||||
class MyModel(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
z = x + outer
|
||||
return z
|
||||
|
||||
foo = MyModel()
|
||||
|
||||
def make_inputs():
|
||||
return (torch.randn(2, 3),)
|
||||
|
||||
trace_inputs = make_inputs()
|
||||
gm = dynamo_graph_capture_for_export(foo)(*trace_inputs)
|
||||
test_inputs = make_inputs()
|
||||
self.assertEqual(gm(*test_inputs), foo(*test_inputs))
|
||||
self.assertEqual(len(list(gm.buffers())), len(list(foo.buffers())))
|
||||
self.assertEqual(len(list(gm.parameters())), len(list(foo.parameters())))
|
||||
|
||||
@unittest.skipIf(not TEST_CUDA, "CUDA not available")
|
||||
def test_dynamo_graph_capture_fx_graph_annotate_overlap_pass(self):
|
||||
class DummyOp(torch.autograd.Function):
|
||||
|
||||
@ -751,6 +751,29 @@ class TestConstFold(TestCase):
|
||||
)
|
||||
self.assertIsNone(mod_folded.const_subgraph_module)
|
||||
|
||||
def test_const_fold_partial_graph(self):
|
||||
"""
|
||||
If a model graph is partially const folded,
|
||||
the non-const subgraph should be inlined back and erased.
|
||||
"""
|
||||
|
||||
class TestModule(torch.nn.Module):
|
||||
def __init__(self, p):
|
||||
super().__init__()
|
||||
self.p = p
|
||||
|
||||
def forward(self, x):
|
||||
probs = torch.empty_permuted(x.shape, [0, 1])
|
||||
mask = torch.bernoulli(probs, 1 - self.p)
|
||||
return x * mask / (1 - self.p)
|
||||
|
||||
ep = torch.export.export(TestModule(0.4), (torch.randn(5, 10),))
|
||||
|
||||
mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(
|
||||
ep.module(), device_for_folded_attrs="cpu"
|
||||
)
|
||||
self._verify_const_fold_mod(mod_folded)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise_on_run_directly("test/test_fx.py")
|
||||
|
||||
@ -20,8 +20,14 @@ from torch.testing._internal.common_device_type import (
|
||||
dtypes,
|
||||
instantiate_device_type_tests,
|
||||
skipIf,
|
||||
skipXPUIf,
|
||||
)
|
||||
from torch.testing._internal.common_utils import (
|
||||
parametrize,
|
||||
run_tests,
|
||||
TEST_WITH_SLOW,
|
||||
TestCase,
|
||||
)
|
||||
from torch.testing._internal.common_utils import parametrize, run_tests, TestCase
|
||||
from torch.testing._internal.inductor_utils import IS_BIG_GPU
|
||||
|
||||
|
||||
@ -382,7 +388,11 @@ class TestAnalysis(TestCase):
|
||||
|
||||
verify_triton(comp_omni)
|
||||
|
||||
@skipIf(not SM80OrLater, "Requires SM80")
|
||||
@skipIf(
|
||||
(not torch.xpu.is_available()) and (not SM80OrLater),
|
||||
"Requires XPU or CUDA SM80",
|
||||
)
|
||||
@skipXPUIf(TEST_WITH_SLOW, "Skip because test too slow on XPU")
|
||||
@dtypes(torch.float, torch.float16)
|
||||
@parametrize(
|
||||
"maxat",
|
||||
@ -467,6 +477,7 @@ class TestAnalysis(TestCase):
|
||||
"aten::cudnn_convolution",
|
||||
"aten::convolution",
|
||||
"aten::_convolution",
|
||||
"aten::convolution_overrideable",
|
||||
)
|
||||
)
|
||||
or "conv" in name
|
||||
|
||||
@ -4,6 +4,7 @@ import os
|
||||
import tempfile
|
||||
from threading import Event
|
||||
|
||||
import torch._inductor.config as config
|
||||
from torch._inductor.compile_worker.subproc_pool import (
|
||||
raise_testexc,
|
||||
SubprocException,
|
||||
@ -16,9 +17,12 @@ from torch.testing._internal.inductor_utils import HAS_CPU
|
||||
|
||||
|
||||
class TestCompileWorker(TestCase):
|
||||
def make_pool(self, size):
|
||||
return SubprocPool(size)
|
||||
|
||||
@skipIfWindows(msg="pass_fds not supported on Windows.")
|
||||
def test_basic_jobs(self):
|
||||
pool = SubprocPool(2)
|
||||
pool = self.make_pool(2)
|
||||
try:
|
||||
a = pool.submit(operator.add, 100, 1)
|
||||
b = pool.submit(operator.sub, 100, 1)
|
||||
@ -29,7 +33,7 @@ class TestCompileWorker(TestCase):
|
||||
|
||||
@skipIfWindows(msg="pass_fds not supported on Windows.")
|
||||
def test_exception(self):
|
||||
pool = SubprocPool(2)
|
||||
pool = self.make_pool(2)
|
||||
try:
|
||||
a = pool.submit(raise_testexc)
|
||||
with self.assertRaisesRegex(
|
||||
@ -42,7 +46,7 @@ class TestCompileWorker(TestCase):
|
||||
|
||||
@skipIfWindows(msg="pass_fds not supported on Windows.")
|
||||
def test_crash(self):
|
||||
pool = SubprocPool(2)
|
||||
pool = self.make_pool(2)
|
||||
try:
|
||||
with self.assertRaises(Exception):
|
||||
a = pool.submit(os._exit, 1)
|
||||
@ -58,7 +62,7 @@ class TestCompileWorker(TestCase):
|
||||
|
||||
@skipIfWindows(msg="pass_fds not supported on Windows.")
|
||||
def test_quiesce(self):
|
||||
pool = SubprocPool(2)
|
||||
pool = self.make_pool(2)
|
||||
try:
|
||||
a = pool.submit(operator.add, 100, 1)
|
||||
pool.quiesce()
|
||||
@ -75,7 +79,7 @@ class TestCompileWorker(TestCase):
|
||||
os.environ["ROLE_RANK"] = "0"
|
||||
with tempfile.NamedTemporaryFile(delete=True) as temp_log:
|
||||
os.environ["TORCHINDUCTOR_WORKER_LOGPATH"] = temp_log.name
|
||||
pool = SubprocPool(2)
|
||||
pool = self.make_pool(2)
|
||||
try:
|
||||
pool.submit(operator.add, 100, 1)
|
||||
self.assertEqual(os.path.exists(temp_log.name), True)
|
||||
@ -83,6 +87,12 @@ class TestCompileWorker(TestCase):
|
||||
pool.shutdown()
|
||||
|
||||
|
||||
@config.patch("quiesce_async_compile_time", 0.1)
|
||||
class TestCompileWorkerWithTimer(TestCompileWorker):
|
||||
def make_pool(self, size):
|
||||
return SubprocPool(size, quiesce=True)
|
||||
|
||||
|
||||
class TestTimer(TestCase):
|
||||
def test_basics(self):
|
||||
done = Event()
|
||||
|
||||
154
test/inductor/test_cutedsl_grouped_mm.py
Normal file
154
test/inductor/test_cutedsl_grouped_mm.py
Normal file
@ -0,0 +1,154 @@
|
||||
# Owner(s): ["module: inductor"]
|
||||
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch._inductor import config
|
||||
from torch._inductor.codegen.cuda.cuda_env import is_datacenter_blackwell_arch
|
||||
from torch._inductor.test_case import run_tests, TestCase as InductorTestCase
|
||||
from torch._inductor.utils import ensure_cute_available
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
)
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
not (ensure_cute_available() and is_datacenter_blackwell_arch()),
|
||||
"CuTeDSL library or Blackwell device not available",
|
||||
)
|
||||
@instantiate_parametrized_tests
|
||||
class TestCuTeDSLGroupedGemm(InductorTestCase):
|
||||
def _get_inputs(
|
||||
self,
|
||||
group_size: int,
|
||||
M_hint: int,
|
||||
K: int,
|
||||
N: int,
|
||||
device: str,
|
||||
dtype: torch.dtype,
|
||||
alignment: int = 16,
|
||||
) -> tuple[Tensor, Tensor, Tensor]:
|
||||
# --- Random, tile-aligned M sizes ---
|
||||
M_sizes = (
|
||||
torch.randint(1, (M_hint // alignment) + 1, (group_size,), dtype=torch.int)
|
||||
* alignment
|
||||
)
|
||||
|
||||
M_total = torch.sum(M_sizes).item()
|
||||
|
||||
# --- Construct input tensors ---
|
||||
A = torch.randn(int(M_total), K, dtype=dtype, device=device) * 0.1
|
||||
B = torch.randn((group_size, K, N), dtype=dtype, device=device) * 0.01
|
||||
|
||||
# --- Build offsets (no leading zero, strictly increasing) ---
|
||||
offsets = torch.cumsum(M_sizes, dim=0).to(dtype=torch.int32, device=device)
|
||||
|
||||
return (A, B, offsets)
|
||||
|
||||
@parametrize("group_size", (2, 8))
|
||||
@parametrize("M_hint", (256, 1024))
|
||||
@parametrize("K", (64, 128))
|
||||
@parametrize("N", (128, 256))
|
||||
def test_grouped_gemm_basic(self, group_size: int, M_hint: int, K: int, N: int):
|
||||
device = "cuda"
|
||||
dtype = torch.bfloat16
|
||||
|
||||
A, B, offsets = self._get_inputs(group_size, M_hint, K, N, device, dtype)
|
||||
|
||||
def grouped_gemm_fn(A_packed, B_batched, offs):
|
||||
return torch._grouped_mm(A_packed, B_batched, offs=offs)
|
||||
|
||||
# Eager execution
|
||||
c_eager = grouped_gemm_fn(A, B, offsets)
|
||||
|
||||
# Test with Cute backend
|
||||
with config.patch(
|
||||
{
|
||||
"max_autotune": True,
|
||||
"max_autotune_gemm_backends": "CUTEDSL",
|
||||
"test_configs.autotune_choice_name_regex": "cutedsl",
|
||||
"autotune_fallback_to_aten": False,
|
||||
}
|
||||
):
|
||||
grouped_gemm_compiled = torch.compile(
|
||||
grouped_gemm_fn, backend="inductor", dynamic=False
|
||||
)
|
||||
c_compiled = grouped_gemm_compiled(A, B, offsets)
|
||||
|
||||
self.assertEqual(c_eager.dtype, dtype)
|
||||
self.assertEqual(c_compiled.dtype, dtype)
|
||||
torch.testing.assert_close(c_eager, c_compiled)
|
||||
|
||||
@parametrize("layout_A", ("contiguous", "offset", "padded", "view"))
|
||||
@parametrize("layout_B", ("contiguous", "broadcasted"))
|
||||
def test_grouped_gemm_assorted_layouts(
|
||||
self,
|
||||
layout_A: str,
|
||||
layout_B: str,
|
||||
):
|
||||
device = "cuda"
|
||||
dtype = torch.bfloat16
|
||||
|
||||
G, K, N = 8, 64, 128
|
||||
M_sizes = [128] * G
|
||||
sum_M = sum(M_sizes)
|
||||
offsets = torch.tensor(
|
||||
[sum(M_sizes[: i + 1]) for i in range(G)], dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
A_base = torch.randn(sum_M, K, device=device, dtype=dtype)
|
||||
A = A_base
|
||||
|
||||
if layout_A == "offset":
|
||||
# allocate bigger buffer than needed, use nonzero storage offset
|
||||
storage = torch.randn(sum_M * K + 512, device=device, dtype=dtype)
|
||||
offset = 128 # skip first 128 elements
|
||||
A = torch.as_strided(storage[offset:], (sum_M, K), (K, 1))
|
||||
elif layout_A == "padded":
|
||||
# simulate row pitch > K (row_stride = K + pad)
|
||||
row_pitch = K + 8
|
||||
storage = torch.randn(sum_M * row_pitch, device=device, dtype=dtype)
|
||||
A = torch.as_strided(storage, (sum_M, K), (row_pitch, 1))
|
||||
elif layout_A == "view":
|
||||
A_storage = torch.randn(sum_M * K, device=device, dtype=dtype)
|
||||
A = A_storage.view(sum_M, K)
|
||||
assert A._base is not None
|
||||
assert A.shape == (sum_M, K)
|
||||
|
||||
B = torch.randn((G, K, N), dtype=dtype, device=device) * 0.01
|
||||
|
||||
if layout_B == "broadcasted":
|
||||
# Broadcast B across groups (zero stride along G)
|
||||
B = B[0].expand(G, K, N)
|
||||
assert B.stride(0) == 0
|
||||
|
||||
def grouped_gemm_fn(A_packed, B_batched, offs):
|
||||
return torch._grouped_mm(A_packed, B_batched, offs=offs)
|
||||
|
||||
# --- eager ---
|
||||
c_eager = grouped_gemm_fn(A, B, offsets)
|
||||
|
||||
# --- compiled (CUTE backend) ---
|
||||
with config.patch(
|
||||
{
|
||||
"max_autotune": True,
|
||||
"max_autotune_gemm_backends": "CUTEDSL",
|
||||
"test_configs.autotune_choice_name_regex": "cutedsl",
|
||||
"autotune_fallback_to_aten": False,
|
||||
}
|
||||
):
|
||||
grouped_gemm_compiled = torch.compile(
|
||||
grouped_gemm_fn, backend="inductor", dynamic=False
|
||||
)
|
||||
c_compiled = grouped_gemm_compiled(A, B, offsets)
|
||||
|
||||
self.assertEqual(c_eager.dtype, dtype)
|
||||
self.assertEqual(c_compiled.dtype, dtype)
|
||||
torch.testing.assert_close(c_eager, c_compiled)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
@ -15,9 +15,8 @@ from torch.testing._internal.common_utils import (
|
||||
is_navi3_arch,
|
||||
parametrize,
|
||||
patch_test_members,
|
||||
TEST_XPU,
|
||||
)
|
||||
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA_AND_TRITON
|
||||
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU_AND_TRITON
|
||||
from torch.testing._internal.triton_utils import requires_gpu
|
||||
|
||||
|
||||
@ -61,11 +60,6 @@ class TestDecomposeAddMM(torch.nn.Module):
|
||||
|
||||
|
||||
@requires_gpu
|
||||
@unittest.skipIf(
|
||||
TEST_XPU,
|
||||
"Intel GPU has not enabled decompose_mem_bound_mm PASS in "
|
||||
"torch/_inductor/fx_passes/decompose_mem_bound_mm.py",
|
||||
)
|
||||
@torch._inductor.config.patch(
|
||||
post_grad_fusion_options={
|
||||
"decompose_mm_pass": {},
|
||||
@ -144,7 +138,7 @@ class TestDecomposeMemMM(TestCase):
|
||||
|
||||
self.compare_pred(module, traced, input)
|
||||
|
||||
expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0
|
||||
expected_val = 1 if should_decompose and HAS_GPU_AND_TRITON else 0
|
||||
self.assertEqual(
|
||||
counters["inductor"]["decompose_bmm"],
|
||||
expected_val,
|
||||
@ -155,7 +149,7 @@ class TestDecomposeMemMM(TestCase):
|
||||
self.compare_parameters(module, traced)
|
||||
self.compare_gradients(module, traced)
|
||||
|
||||
expected_val = 3 if should_decompose and HAS_CUDA_AND_TRITON else 0
|
||||
expected_val = 3 if should_decompose and HAS_GPU_AND_TRITON else 0
|
||||
self.assertEqual(
|
||||
counters["inductor"]["decompose_bmm"],
|
||||
expected_val,
|
||||
@ -204,7 +198,7 @@ class TestDecomposeMemMM(TestCase):
|
||||
|
||||
self.compare_pred(module, traced, input)
|
||||
|
||||
expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0
|
||||
expected_val = 1 if should_decompose and HAS_GPU_AND_TRITON else 0
|
||||
if has_bias:
|
||||
self.assertEqual(
|
||||
counters["inductor"]["decompose_addmm"],
|
||||
@ -259,7 +253,7 @@ class TestDecomposeMemMM(TestCase):
|
||||
|
||||
self.compare_pred(module, traced, input)
|
||||
|
||||
expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0
|
||||
expected_val = 1 if should_decompose and HAS_GPU_AND_TRITON else 0
|
||||
if has_bias:
|
||||
self.assertEqual(
|
||||
counters["inductor"]["decompose_addmm"],
|
||||
@ -304,7 +298,7 @@ class TestDecomposeMemMM(TestCase):
|
||||
|
||||
self.compare_pred(module, traced, input)
|
||||
|
||||
expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0
|
||||
expected_val = 1 if should_decompose and HAS_GPU_AND_TRITON else 0
|
||||
self.assertEqual(
|
||||
counters["inductor"]["decompose_mm"],
|
||||
expected_val,
|
||||
@ -316,7 +310,7 @@ class TestDecomposeMemMM(TestCase):
|
||||
self.compare_parameters(module, traced)
|
||||
self.compare_gradients(module, traced)
|
||||
|
||||
expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0
|
||||
expected_val = 1 if should_decompose and HAS_GPU_AND_TRITON else 0
|
||||
self.assertEqual(
|
||||
counters["inductor"]["decompose_mm"] - decompose_mm_fwd,
|
||||
expected_val,
|
||||
@ -374,7 +368,7 @@ class TestDecomposeMemMM(TestCase):
|
||||
|
||||
self.compare_pred(module, traced, input)
|
||||
|
||||
expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0
|
||||
expected_val = 1 if should_decompose and HAS_GPU_AND_TRITON else 0
|
||||
self.assertEqual(
|
||||
counters["inductor"]["decompose_mm"],
|
||||
expected_val,
|
||||
@ -386,7 +380,7 @@ class TestDecomposeMemMM(TestCase):
|
||||
self.compare_parameters(module, traced)
|
||||
self.compare_gradients(module, traced)
|
||||
|
||||
expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0
|
||||
expected_val = 1 if should_decompose and HAS_GPU_AND_TRITON else 0
|
||||
self.assertEqual(
|
||||
counters["inductor"]["decompose_mm"] - decompose_mm_fwd,
|
||||
expected_val,
|
||||
@ -410,7 +404,7 @@ class TestDecomposeMemMM(TestCase):
|
||||
|
||||
self.compare_pred(module, traced, input)
|
||||
|
||||
expected_val = 1 if should_decompose and HAS_CUDA_AND_TRITON else 0
|
||||
expected_val = 1 if should_decompose and HAS_GPU_AND_TRITON else 0
|
||||
if has_bias:
|
||||
self.assertEqual(
|
||||
counters["inductor"]["decompose_addmm"],
|
||||
@ -424,7 +418,7 @@ class TestDecomposeMemMM(TestCase):
|
||||
self.compare_gradients(module, traced)
|
||||
|
||||
expected_val = 0
|
||||
if HAS_CUDA_AND_TRITON:
|
||||
if HAS_GPU_AND_TRITON:
|
||||
expected_val = 1 if has_bias else 2
|
||||
|
||||
self.assertEqual(
|
||||
@ -447,12 +441,8 @@ class TestDecomposeMemMM(TestCase):
|
||||
|
||||
_, code = run_and_get_code(foo, input1, input2)
|
||||
|
||||
if GPU_TYPE == "xpu":
|
||||
# only 1 kernel generated on the XPU stack
|
||||
FileCheck().check_count(".run(", 1, exactly=True).run(code[0])
|
||||
else:
|
||||
# two kernels generated
|
||||
FileCheck().check_count(".run(", 2, exactly=True).run(code[0])
|
||||
# two kernels generated
|
||||
FileCheck().check_count(".run(", 2, exactly=True).run(code[0])
|
||||
|
||||
def test_check_device(self):
|
||||
m = 5
|
||||
@ -462,7 +452,7 @@ class TestDecomposeMemMM(TestCase):
|
||||
|
||||
input1 = torch.randn(m, k, device=GPU_TYPE)
|
||||
input2 = torch.randn(k, n, device=GPU_TYPE)
|
||||
self.assertTrue(check_device(input1, input2))
|
||||
self.assertTrue(check_device(input1, input2, device=GPU_TYPE))
|
||||
self.assertFalse(check_device(input1, input2, device="cpu"))
|
||||
|
||||
input1 = torch.randn(m, k)
|
||||
|
||||
@ -806,8 +806,6 @@ class AOTFxirTestCase(InductorTestCase):
|
||||
def check(
|
||||
self, model, inp, dynamic_shapes=None, strict=False
|
||||
) -> torch.fx.GraphModule:
|
||||
if self.device == "xpu":
|
||||
raise unittest.SkipTest("The feature AOTFxir not currently ready for XPU")
|
||||
with torch.no_grad():
|
||||
ep = torch.export.export(
|
||||
model, inp, dynamic_shapes=dynamic_shapes, strict=strict
|
||||
|
||||
@ -500,8 +500,13 @@ class PaddingTest(TestCaseBase):
|
||||
forward_wrapper = wrapper_codes[0]
|
||||
|
||||
# make sure the load for softmax is aligned
|
||||
if bias:
|
||||
# addmm -> mm + bias and bias is fused with softmax
|
||||
softmax_load_str = "tl.load(in_out_ptr0 + (r0_1 + 30528*x0)"
|
||||
else:
|
||||
softmax_load_str = "tl.load(in_ptr0 + (r0_1 + 30528*x0)"
|
||||
self.assertTrue(
|
||||
"tl.load(in_ptr0 + (r0_1 + 30528*x0)" in forward_wrapper,
|
||||
softmax_load_str in forward_wrapper,
|
||||
f"forward_wrapper: {forward_wrapper}",
|
||||
)
|
||||
|
||||
|
||||
@ -1826,9 +1826,14 @@ def run_test_module(
|
||||
test_name = test.name
|
||||
|
||||
# Printing the date here can help diagnose which tests are slow
|
||||
print_to_stderr(f"Running {str(test)} ... [{datetime.now()}]")
|
||||
start = time.perf_counter()
|
||||
print_to_stderr(f"Running {str(test)} ... [{datetime.now()}][{start}]")
|
||||
handler = CUSTOM_HANDLERS.get(test_name, run_test)
|
||||
return_code = handler(test, test_directory, options)
|
||||
end = time.perf_counter()
|
||||
print_to_stderr(
|
||||
f"Finished {str(test)} ... [{datetime.now()}][{end}], took {(end - start) / 60:.2f}min"
|
||||
)
|
||||
assert isinstance(return_code, int) and not isinstance(return_code, bool), (
|
||||
f"While running {str(test)} got non integer return code {return_code}"
|
||||
)
|
||||
|
||||
@ -35,7 +35,6 @@ from torch.cuda._memory_viz import (
|
||||
from torch.testing._internal.autocast_test_lists import AutocastTestLists, TestAutocast
|
||||
from torch.testing._internal.common_cuda import (
|
||||
_create_scaling_case,
|
||||
HAS_WORKING_NVML,
|
||||
SM70OrLater,
|
||||
TEST_CUDNN,
|
||||
TEST_MULTIGPU,
|
||||
@ -4804,7 +4803,6 @@ print(torch.cuda.get_allocator_backend())
|
||||
def test_temperature(self):
|
||||
self.assertTrue(0 <= torch.cuda.temperature() <= 150)
|
||||
|
||||
@unittest.skipIf(not HAS_WORKING_NVML, "pynvml availble but broken")
|
||||
@unittest.skipIf(TEST_WITH_ROCM, "flaky for AMD gpu")
|
||||
@unittest.skipIf(not TEST_PYNVML, "pynvml/amdsmi is not available")
|
||||
def test_device_memory_used(self):
|
||||
@ -7415,6 +7413,140 @@ class TestCudaDeviceParametrized(TestCase):
|
||||
)
|
||||
|
||||
|
||||
class TestFXMemoryProfiler(TestCase):
|
||||
"""Tests for memory profiler augmentation with original stack traces."""
|
||||
|
||||
def collect_frames(
|
||||
self, augmented_snapshot, collect_device_traces=True, collect_segments=True
|
||||
):
|
||||
"""Collects all frames that has node metadata from a memory snapshot."""
|
||||
# Collect all frames with FX metadata
|
||||
fx_frames = []
|
||||
|
||||
# Check device traces for FX debug fields
|
||||
if collect_device_traces and "device_traces" in augmented_snapshot:
|
||||
for trace_list in augmented_snapshot["device_traces"]:
|
||||
for trace_entry in trace_list:
|
||||
if isinstance(trace_entry, dict) and "frames" in trace_entry:
|
||||
for frame in trace_entry["frames"]:
|
||||
if isinstance(frame, dict):
|
||||
# Check for FX debug fields
|
||||
if "fx_node_op" in frame or "fx_node_name" in frame:
|
||||
fx_frames.append(frame)
|
||||
|
||||
# Check segments/blocks for FX debug fields
|
||||
if collect_segments and "segments" in augmented_snapshot:
|
||||
for segment in augmented_snapshot["segments"]:
|
||||
if "blocks" in segment:
|
||||
for block in segment["blocks"]:
|
||||
if "frames" in block:
|
||||
for frame in block["frames"]:
|
||||
if isinstance(frame, dict):
|
||||
if "fx_node_op" in frame or "fx_node_name" in frame:
|
||||
fx_frames.append(frame)
|
||||
return fx_frames
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
|
||||
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
|
||||
def test_fx_memory_profiler_augmentation(self):
|
||||
"""Test that memory snapshots are augmented with FX debug information."""
|
||||
|
||||
# Create a simple model
|
||||
class MLPModule(nn.Module):
|
||||
def __init__(self, device):
|
||||
super().__init__()
|
||||
torch.manual_seed(5)
|
||||
self.net1 = nn.Linear(10, 16, bias=True, device=device)
|
||||
self.relu = nn.ReLU()
|
||||
self.net2 = nn.Linear(16, 10, bias=True, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
a = self.net1(x)
|
||||
b = self.relu(a)
|
||||
c = self.net2(b)
|
||||
return c
|
||||
|
||||
device = "cuda"
|
||||
mod = MLPModule(device)
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
torch.cuda.memory._record_memory_history()
|
||||
compiled = torch.compile(mod, backend="aot_eager", fullgraph=True)
|
||||
result = compiled(torch.randn(10, 10, device=device))
|
||||
augmented_snapshot = torch.cuda.memory._snapshot(
|
||||
augment_with_fx_traces=True
|
||||
)
|
||||
torch.cuda.memory._record_memory_history(enabled=None, clear_history=True)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
fx_frames = self.collect_frames(augmented_snapshot)
|
||||
if TEST_WITH_ROCM:
|
||||
self.assertGreater(len(fx_frames), 0)
|
||||
else:
|
||||
self.assertEqual(len(fx_frames), 12)
|
||||
|
||||
for frame in fx_frames:
|
||||
# Every FX frame should have both node_op and node_name
|
||||
self.assertIn("fx_node_op", frame)
|
||||
self.assertIn("fx_node_name", frame)
|
||||
self.assertIn("fx_node_target", frame)
|
||||
self.assertIn("fx_original_trace", frame)
|
||||
|
||||
self.assertIn(frame["fx_node_name"], ["addmm", "relu", "addmm_1"])
|
||||
fx_node_name = frame["fx_node_name"]
|
||||
if fx_node_name == "addmm":
|
||||
self.assertIn("a = self.net1(x)", frame["fx_original_trace"])
|
||||
elif fx_node_name == "addmm_1":
|
||||
self.assertIn("c = self.net2(b)", frame["fx_original_trace"])
|
||||
elif fx_node_name == "relu":
|
||||
self.assertIn("b = self.relu(a)", frame["fx_original_trace"])
|
||||
|
||||
# Test that when we have two graphs with the same src_code, they're not hashed
|
||||
# to the same metadata
|
||||
class MLPModule2(nn.Module):
|
||||
def __init__(self, device):
|
||||
super().__init__()
|
||||
torch.manual_seed(5)
|
||||
self.net1 = nn.Linear(10, 16, bias=True, device=device)
|
||||
self.relu = nn.ReLU()
|
||||
self.net2 = nn.Linear(16, 10, bias=True, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
d = self.net1(x)
|
||||
e = self.relu(d)
|
||||
f = self.net2(e)
|
||||
return f
|
||||
|
||||
mod = MLPModule2(device)
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
torch.cuda.memory._record_memory_history()
|
||||
compiled = torch.compile(mod, backend="aot_eager", fullgraph=True)
|
||||
result = compiled(torch.randn(10, 10, device=device))
|
||||
augmented_snapshot = torch.cuda.memory._snapshot(
|
||||
augment_with_fx_traces=True
|
||||
)
|
||||
torch.cuda.memory._record_memory_history(enabled=None, clear_history=True)
|
||||
|
||||
# avoid collecting segments from previous run for unit test purpose
|
||||
fx_frames = self.collect_frames(augmented_snapshot, collect_segments=False)
|
||||
self.assertGreater(len(fx_frames), 0)
|
||||
|
||||
for frame in fx_frames:
|
||||
# Every FX frame should have both node_op and node_name
|
||||
self.assertIn("fx_node_op", frame)
|
||||
self.assertIn("fx_node_name", frame)
|
||||
self.assertIn("fx_node_target", frame)
|
||||
self.assertIn("fx_original_trace", frame)
|
||||
|
||||
self.assertIn(frame["fx_node_name"], ["addmm", "relu", "addmm_1"])
|
||||
fx_node_name = frame["fx_node_name"]
|
||||
if fx_node_name == "addmm":
|
||||
self.assertIn("d = self.net1(x)", frame["fx_original_trace"])
|
||||
elif fx_node_name == "addmm_1":
|
||||
self.assertIn("f = self.net2(e)", frame["fx_original_trace"])
|
||||
elif fx_node_name == "relu":
|
||||
self.assertIn("e = self.relu(d)", frame["fx_original_trace"])
|
||||
|
||||
|
||||
instantiate_parametrized_tests(TestCuda)
|
||||
instantiate_parametrized_tests(TestCudaMallocAsync)
|
||||
instantiate_parametrized_tests(TestCompileKernel)
|
||||
|
||||
@ -771,6 +771,7 @@ class TestFX(JitTestCase):
|
||||
gm = GraphModule(tracer.root, graph)
|
||||
expected = {1: 2, 2: 3, 3: 4, 4: 5}
|
||||
self.assertTrue(set(expected.items()).issubset(set(gm._lineno_map.items())))
|
||||
self.assertEqual(gm._prologue_start, 4)
|
||||
|
||||
# test custom codegen
|
||||
def transform_code(code):
|
||||
@ -780,6 +781,7 @@ class TestFX(JitTestCase):
|
||||
gm.recompile()
|
||||
expected = {2: 2, 3: 3, 4: 4, 5: 5}
|
||||
self.assertTrue(set(expected.items()).issubset(set(gm._lineno_map.items())))
|
||||
self.assertEqual(gm._prologue_start, 4)
|
||||
|
||||
def test_graph_unique_names_manual(self):
|
||||
graph: torch.fx.Graph = torch.fx.Graph()
|
||||
|
||||
@ -11,7 +11,7 @@ from typing import Optional
|
||||
import torch
|
||||
|
||||
|
||||
from torch.nn.functional import scaled_mm, scaled_grouped_mm, ScalingType, SwizzleType
|
||||
from torch.nn.functional import pad, scaled_mm, scaled_grouped_mm, ScalingType, SwizzleType
|
||||
from torch.testing._internal.common_cuda import (
|
||||
IS_SM90,
|
||||
_get_torch_cuda_version,
|
||||
@ -107,11 +107,76 @@ def tensor_to_scale_block(
|
||||
x = x.unflatten(1, (-1, block_inner)).unflatten(0, (-1, block_outer))
|
||||
amax = x.abs().amax(dim=[1, 3], keepdim=True).float()
|
||||
scale = torch.finfo(float8_dtype).max / amax
|
||||
# if amax == 0, entire block = 0, set scale = 0 to ensure elements are
|
||||
# zero'd out correctly (and remove bad effects of / 0)
|
||||
scale[amax == 0] = 0
|
||||
|
||||
# Scale x, noting that blocks where amax == 0 are explicitly 0 now.
|
||||
x = x.mul(scale).to(float8_dtype)
|
||||
|
||||
# if amax == 0, all values in the block are 0, scale=0
|
||||
# but we need scale.reciprocal later, which breaks when scale=0...
|
||||
# So. Replace 0 -> 1 in the scale so we don't break things later.
|
||||
# Elements are already zeroed, so don't actually care what the scale
|
||||
# is, as long as it's not inf/nan.
|
||||
scale[scale == 0] = 1.
|
||||
|
||||
x = x.flatten(2, 3).flatten(0, 1)
|
||||
scale = scale.flatten(2, 3).flatten(0, 1)
|
||||
return x, scale
|
||||
|
||||
def hp_from_128x128(x_lp, x_scale):
|
||||
orig_shape = x_lp.shape
|
||||
M, K = orig_shape
|
||||
x_lp = x_lp.view(M // 128, 128, K // 128, 128)
|
||||
x_scale = x_scale.unsqueeze(1).unsqueeze(-1)
|
||||
x_hp = x_lp.to(torch.float32)
|
||||
x_hp = x_hp / x_scale
|
||||
return x_hp.reshape(orig_shape).to(torch.bfloat16)
|
||||
|
||||
def hp_to_128x128(x, x_scale):
|
||||
orig_shape = x.shape
|
||||
M, K = orig_shape
|
||||
x = x.view(M // 128, 128, K // 128, 128)
|
||||
x_scale = x_scale.unsqueeze(1).unsqueeze(-1)
|
||||
x_lp = x * x_scale
|
||||
|
||||
return x_lp.reshape(orig_shape).to(torch.float8_e4m3fn)
|
||||
|
||||
def hp_from_1x128(x_lp, x_scale):
|
||||
orig_shape = x_lp.shape
|
||||
x_lp = x_lp.reshape(x_lp.shape[0], x_lp.shape[-1] // 128, 128)
|
||||
x_hp = x_lp.to(torch.float32)
|
||||
x_hp = x_hp / x_scale.unsqueeze(-1)
|
||||
return x_hp.reshape(orig_shape).to(torch.bfloat16)
|
||||
|
||||
def hp_to_1x128(x, x_scale):
|
||||
orig_shape = x.shape
|
||||
x = x.reshape(x.shape[0], x.shape[-1] // 128, 128)
|
||||
x_lp = x * x_scale.unsqueeze(-1)
|
||||
return x_lp.reshape(orig_shape).to(torch.float8_e4m3fn)
|
||||
|
||||
|
||||
# cublas requires specific padding for 128x128 scales, see:
|
||||
# https://docs.nvidia.com/cuda/cublas/#element-1d-and-128x128-2d-block-scaling-for-fp8-data-types
|
||||
# Notably L = ceil_div(K, 128),
|
||||
# L4 = round_up(L, 4),
|
||||
# and then for A/B the shape must be
|
||||
# scale: [L4, ceil_div({M,N}, 128) and K/L/L4-major in memory.
|
||||
#
|
||||
# This routine pads L -> L4
|
||||
def _pad_128x128_scales(scale: torch.Tensor) -> (torch.Tensor, int):
|
||||
# scale is either [L4, ceil_div(M, 128)] or [L4, ceil_div(N, 128)], stride: [1, L4]
|
||||
# However, we get passed it as [ceil_div(M, 128), L] or [ceil_div(N, 128), L]
|
||||
# so check inner dim % 4, and pad if necessary
|
||||
pad_amount = scale.shape[-1] % 4
|
||||
|
||||
if pad_amount == 0:
|
||||
return scale, 0
|
||||
else:
|
||||
pad_amount = 4 - pad_amount
|
||||
return pad(scale, (0, pad_amount), "constant", 0), pad_amount
|
||||
|
||||
|
||||
def round_up(x: int, y: int) -> int:
|
||||
return ((x + y - 1) // y) * y
|
||||
@ -144,42 +209,36 @@ def infer_scale_swizzle(mat, scale):
|
||||
] == math.ceil(mat.shape[1] // 128):
|
||||
return ScalingType.BlockWise128x128, SwizzleType.NO_SWIZZLE
|
||||
|
||||
# if we're checking for nvfp4, need to adjust for packed-K
|
||||
K_multiplier = 2 if mat.dtype == torch.float4_e2m1fn_x2 else 1
|
||||
# NVFP4
|
||||
if (
|
||||
(scale.numel()
|
||||
== round_up(mat.shape[0], 128) * round_up(math.ceil(2 * mat.shape[1] // 16), 4)
|
||||
== round_up(mat.shape[0], 128) * round_up(math.ceil(K_multiplier * mat.shape[1] // 16), 4)
|
||||
or scale.numel()
|
||||
== round_up(mat.shape[1], 128) * round_up(math.ceil(2 * mat.shape[0] // 16), 4))
|
||||
== round_up(mat.shape[1], 128) * round_up(math.ceil(K_multiplier * mat.shape[0] // 16), 4))
|
||||
and mat.dtype == torch.float4_e2m1fn_x2
|
||||
and scale.dtype == torch.float8_e4m3fn
|
||||
):
|
||||
return ScalingType.BlockWise1x16, SwizzleType.SWIZZLE_32_4_4
|
||||
|
||||
# MXFP4 w/o swizzle
|
||||
if (
|
||||
(scale.numel() == 2 * math.ceil(mat.shape[0] // 32) * mat.shape[1]
|
||||
or scale.numel() == 2 * math.ceil(mat.shape[1] // 32) * mat.shape[0])
|
||||
and mat.dtype == torch.float4_e2m1fn_x2
|
||||
and scale.dtype == torch.float8_e8m0fnu
|
||||
):
|
||||
return ScalingType.BlockWise1x32, SwizzleType.NO_SWIZZLE
|
||||
|
||||
# MX formats
|
||||
if not torch.version.hip:
|
||||
# MXFP8 w/ swizzle
|
||||
# MX w/swizzle (NVIDIA)
|
||||
if (
|
||||
(scale.numel()
|
||||
== round_up(mat.shape[0], 128) * round_up(math.ceil(mat.shape[1] // 32), 4)
|
||||
== round_up(mat.shape[0], 128) * round_up(math.ceil(K_multiplier * mat.shape[1] // 32), 4)
|
||||
or scale.numel()
|
||||
== round_up(mat.shape[1], 128) * round_up(math.ceil(mat.shape[0] // 32), 4))
|
||||
== round_up(mat.shape[1], 128) * round_up(math.ceil(K_multiplier * mat.shape[0] // 32), 4))
|
||||
and scale.dtype == torch.float8_e8m0fnu
|
||||
):
|
||||
return ScalingType.BlockWise1x32, SwizzleType.SWIZZLE_32_4_4
|
||||
|
||||
else:
|
||||
# MXFP8 w/o swizzle
|
||||
# MX w/o swizzle (AMD)
|
||||
if (
|
||||
(scale.numel() == math.ceil(mat.shape[0] // 32) * mat.shape[1]
|
||||
or scale.numel() == math.ceil(mat.shape[1] // 32) * mat.shape[0])
|
||||
(scale.numel() == math.ceil(mat.shape[0] // 32) * K_multiplier * mat.shape[1]
|
||||
or scale.numel() == math.ceil(K_multiplier * mat.shape[1] // 32) * mat.shape[0])
|
||||
and scale.dtype == torch.float8_e8m0fnu
|
||||
):
|
||||
return ScalingType.BlockWise1x32, SwizzleType.NO_SWIZZLE
|
||||
@ -1252,7 +1311,6 @@ class TestFP8Matmul(TestCase):
|
||||
else:
|
||||
test()
|
||||
|
||||
# Note: Removed parameterization over M,N,K from #163829 as it failed tests as-is
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
|
||||
@unittest.skipIf(not IS_SM90, "cuBLAS blockwise scaling requires sm90+")
|
||||
@unittest.skipIf(
|
||||
@ -1261,59 +1319,224 @@ class TestFP8Matmul(TestCase):
|
||||
)
|
||||
@parametrize("output_dtype", [torch.bfloat16, torch.float32])
|
||||
@parametrize("lhs_block,rhs_block", [(1, 1), (128, 1), (1, 128)])
|
||||
@parametrize("M,N,K", [(256, 768, 512)])
|
||||
@with_tf32_off
|
||||
def test_scaled_mm_vs_emulated_block_wise(self, output_dtype, lhs_block, rhs_block, M, N, K):
|
||||
@parametrize("M,N,K", [
|
||||
# Nice size
|
||||
(256, 768, 512),
|
||||
# Requires padding for 128x128 scale
|
||||
(384, 128, 1280),
|
||||
# M=N=K for eyes test
|
||||
(512, 512, 512),
|
||||
])
|
||||
@parametrize("test_case", [
|
||||
"x_eye_b_eye",
|
||||
"x_ones_y_ones_calc_scales",
|
||||
"x_ones_y_ones_set_scales",
|
||||
"x_ones_y_ones_modify_scales",
|
||||
"data_random_scales_one",
|
||||
"data_random_calc_scales",
|
||||
])
|
||||
def test_scaled_mm_block_wise_numerics(self, output_dtype, lhs_block, rhs_block, M, N, K, test_case):
|
||||
"""
|
||||
subsume test_scaled_mm_vs_emulated_block_wise for random inputs, random scales,
|
||||
do some other functional tests as well.
|
||||
|
||||
# Inputs (as generated are):
|
||||
# A: [M, K]
|
||||
# B: [N, K]
|
||||
# then scales are, for the 3 combinations:
|
||||
# 1x128 x 1x128:
|
||||
# As: [M, K // 128], stride: [1, M] -> scale.t().contiguous().t()
|
||||
# Bs: [N, K // 128], stride: [1, N] -> scale.t().contiguous().t()
|
||||
# 1x128 x 128x128
|
||||
# L4 = round_up(K // 128, 4)
|
||||
# As: [M, K // 128], stride: [1, M] -> scale.t().contiguous().t()
|
||||
# Bs: [L4, N // 128], stride: [1, L4] -> scale.t()
|
||||
# 128x128 x 1x128
|
||||
# L4 = round_up(K // 128, 4)
|
||||
# As: [L4, M // 128], stride: [1, L4]
|
||||
# Bs: [N, K // 128], stride: [1, N]
|
||||
"""
|
||||
torch.manual_seed(42)
|
||||
|
||||
x = torch.randn(M, K, device="cuda", dtype=output_dtype).pow(3)
|
||||
y = torch.randn(N, K, device="cuda", dtype=output_dtype).pow(3)
|
||||
def _adjust_lhs_scale(x_fp8, x_scales, lhs_block):
|
||||
M, K = x_fp8.shape
|
||||
x_scales_original = x_scales.clone()
|
||||
# 1x128 blocks need scales to be outer-dim-major
|
||||
if lhs_block == 1:
|
||||
x_scales = x_scales.t().contiguous().t()
|
||||
lhs_recipe = ScalingType.BlockWise1x128
|
||||
assert (x_scales.shape[0] == M and x_scales.shape[1] == K // 128), f"{x_scales.shape=}"
|
||||
assert (x_scales.stride(0) == 1 and x_scales.stride(1) in [1, M]), f"{x_scales.stride=}"
|
||||
x_hp = hp_from_1x128(x_fp8, x_scales_original)
|
||||
else:
|
||||
lhs_recipe = ScalingType.BlockWise128x128
|
||||
x_scales, pad_amount = _pad_128x128_scales(x_scales)
|
||||
# scales in [M // 128, L4] -> [L4, M // 128]
|
||||
x_scales = x_scales.t()
|
||||
x_hp = hp_from_128x128(x_fp8, x_scales_original)
|
||||
|
||||
x_fp8, x_scales = tensor_to_scale_block(x, e4m3_type, lhs_block, 128)
|
||||
y_fp8, y_scales = tensor_to_scale_block(y, e4m3_type, rhs_block, 128)
|
||||
return x_hp, lhs_recipe, x_scales, x_scales_original
|
||||
|
||||
# 1x128 blocks need scales to be outer-dim-major
|
||||
if lhs_block == 1:
|
||||
x_scales = x_scales.t().contiguous().t()
|
||||
lhs_recipe = ScalingType.BlockWise1x128
|
||||
def _adjust_rhs_scale(y_fp8, y_scales, rhs_block):
|
||||
N, K = y_fp8.shape
|
||||
y_scales_original = y_scales.clone()
|
||||
|
||||
if rhs_block == 1:
|
||||
y_scales = y_scales.t().contiguous().t()
|
||||
rhs_recipe = ScalingType.BlockWise1x128
|
||||
assert (y_scales.shape[0] == N and y_scales.shape[1] == K // 128), f"{y_scales.shape=}"
|
||||
assert (y_scales.stride(0) == 1 and y_scales.stride(1) in [1, N]), f"{y_scales.stride=}"
|
||||
y_hp = hp_from_1x128(y_fp8, y_scales_original)
|
||||
else:
|
||||
rhs_recipe = ScalingType.BlockWise128x128
|
||||
y_scales, pad_amount = _pad_128x128_scales(y_scales)
|
||||
# Scale in [N // 128, L4] -> [L4, N // 128]
|
||||
y_scales = y_scales.t()
|
||||
y_hp = hp_from_128x128(y_fp8, y_scales_original)
|
||||
|
||||
return y_hp, rhs_recipe, y_scales, y_scales_original
|
||||
|
||||
def _build_lhs(x, lhs_block):
|
||||
M, K = x.shape
|
||||
|
||||
x_fp8, x_scales = tensor_to_scale_block(x, e4m3_type, lhs_block, 128)
|
||||
x_scales_original = x_scales
|
||||
|
||||
x_hp, x_recipe, x_scales, x_scales_original = _adjust_lhs_scale(x_fp8, x_scales, lhs_block)
|
||||
|
||||
return x_hp, x_recipe, x_fp8, x_scales, x_scales_original
|
||||
|
||||
def _build_rhs(y, rhs_block):
|
||||
N, K = y.shape
|
||||
|
||||
y_fp8, y_scales = tensor_to_scale_block(y, e4m3_type, rhs_block, 128)
|
||||
y_hp, y_recipe, y_scales, y_scales_original = _adjust_rhs_scale(y_fp8, y_scales, rhs_block)
|
||||
|
||||
return y_hp, y_recipe, y_fp8, y_scales, y_scales_original
|
||||
|
||||
def _run_test(x_hp, x_recipe, x_fp8, x_scales, x_scales_original,
|
||||
y_hp, y_recipe, y_fp8, y_scales, y_scales_original):
|
||||
|
||||
# Calculate actual F8 mm
|
||||
out_scaled_mm = scaled_mm_wrap(
|
||||
x_fp8,
|
||||
y_fp8.t(),
|
||||
scale_a=x_scales.reciprocal(),
|
||||
scale_recipe_a=x_recipe,
|
||||
# Note: No more .t() on scale_b, not necessary.
|
||||
scale_b=y_scales.reciprocal(),
|
||||
scale_recipe_b=y_recipe,
|
||||
out_dtype=output_dtype,
|
||||
)
|
||||
|
||||
# Calculate emulated F8 mm
|
||||
out_emulated = mm_float8_emulated_block(
|
||||
x_fp8,
|
||||
x_scales_original,
|
||||
y_fp8.t(),
|
||||
y_scales_original.t(),
|
||||
output_dtype
|
||||
)
|
||||
|
||||
cosine_sim = torch.nn.functional.cosine_similarity(
|
||||
out_emulated.flatten().float(), (x @ y.t()).flatten().float(), dim=0
|
||||
)
|
||||
self.assertGreaterEqual(float(cosine_sim), 0.999)
|
||||
|
||||
cosine_sim = torch.nn.functional.cosine_similarity(
|
||||
out_scaled_mm.flatten().float(), out_emulated.flatten().float(), dim=0
|
||||
)
|
||||
self.assertGreaterEqual(float(cosine_sim), 0.999)
|
||||
|
||||
if output_dtype in {torch.bfloat16, torch.float16}:
|
||||
atol, rtol = 6e-1, 7e-2
|
||||
else:
|
||||
atol, rtol = 7e-1, 2e-3
|
||||
|
||||
self.assertEqual(out_scaled_mm, out_emulated.to(output_dtype), atol=atol, rtol=rtol)
|
||||
|
||||
# One last check against the full-precision reference, to ensure we
|
||||
# didn't mess up the scaling itself and made the test trivial.
|
||||
cosine_sim = torch.nn.functional.cosine_similarity(
|
||||
out_scaled_mm.flatten().float(), (x @ y.t()).flatten().float(), dim=0
|
||||
)
|
||||
self.assertGreaterEqual(float(cosine_sim), 0.999)
|
||||
|
||||
def _build_constant_scale(t, block, val):
|
||||
M, K = t.shape
|
||||
|
||||
if block == 1:
|
||||
scale_shape = M, K // 128
|
||||
else:
|
||||
scale_shape = M // 128, K // 128
|
||||
|
||||
scale = torch.full(scale_shape, val, device='cuda')
|
||||
|
||||
return scale
|
||||
|
||||
def hp_to_scaled(t, scale, block):
|
||||
if block == 1:
|
||||
return hp_to_1x128(t, scale)
|
||||
else:
|
||||
return hp_to_128x128(t, scale)
|
||||
|
||||
e4m3_type = torch.float8_e4m3fn
|
||||
|
||||
if test_case == "x_eye_b_eye":
|
||||
if M != K or M != N:
|
||||
return unittest.skip("a_eye_b_eye only defined for M = N = K")
|
||||
x = torch.eye(M, device='cuda')
|
||||
y = torch.eye(M, device='cuda')
|
||||
|
||||
x_hp, x_recipe, x_fp8, x_scales, x_scales_original = _build_lhs(x, lhs_block)
|
||||
y_hp, y_recipe, y_fp8, y_scales, y_scales_original = _build_lhs(y, rhs_block)
|
||||
elif test_case == "x_ones_y_ones_calc_scales":
|
||||
x = torch.full((M, K), 1.0, device='cuda')
|
||||
y = torch.full((N, K), 1.0, device='cuda')
|
||||
|
||||
x_hp, x_recipe, x_fp8, x_scales, x_scales_original = _build_lhs(x, lhs_block)
|
||||
y_hp, y_recipe, y_fp8, y_scales, y_scales_original = _build_lhs(y, rhs_block)
|
||||
elif test_case in ["x_ones_y_ones_set_scales", "x_ones_y_ones_modify_scales"]:
|
||||
x = torch.full((M, K), 1.0, device='cuda')
|
||||
y = torch.full((N, K), 1.0, device='cuda')
|
||||
|
||||
x_scales = _build_constant_scale(x, lhs_block, 1.)
|
||||
y_scales = _build_constant_scale(y, rhs_block, 1.)
|
||||
|
||||
if "modify" in test_case:
|
||||
x_scales[0, 0] = 4.
|
||||
y_scales[-1, -1] = 4.
|
||||
|
||||
x_fp8 = hp_to_scaled(x, x_scales, lhs_block)
|
||||
y_fp8 = hp_to_scaled(y, y_scales, rhs_block)
|
||||
|
||||
x_hp, x_recipe, x_scales, x_scales_original = _adjust_lhs_scale(x_fp8, x_scales, lhs_block)
|
||||
y_hp, y_recipe, y_scales, y_scales_original = _adjust_rhs_scale(y_fp8, y_scales, rhs_block)
|
||||
elif test_case == "data_random_scales_one":
|
||||
x = torch.randint(0, 255, (M, K), device='cuda', dtype=torch.uint8).to(torch.bfloat16)
|
||||
y = torch.randint(0, 255, (N, K), device='cuda', dtype=torch.uint8).to(torch.bfloat16)
|
||||
|
||||
x_scales = _build_constant_scale(x, lhs_block, 1.)
|
||||
y_scales = _build_constant_scale(y, rhs_block, 1.)
|
||||
|
||||
x_fp8 = hp_to_scaled(x, x_scales, lhs_block)
|
||||
y_fp8 = hp_to_scaled(y, y_scales, rhs_block)
|
||||
|
||||
x_hp, x_recipe, x_scales, x_scales_original = _adjust_lhs_scale(x_fp8, x_scales, lhs_block)
|
||||
y_hp, y_recipe, y_scales, y_scales_original = _adjust_rhs_scale(y_fp8, y_scales, rhs_block)
|
||||
elif test_case == "data_random_calc_scales":
|
||||
# Note: Old test_scaled_mm_vs_emulated_block_wise test case
|
||||
x = torch.randn(M, K, device="cuda", dtype=output_dtype)
|
||||
y = torch.randn(N, K, device="cuda", dtype=output_dtype) * 1e-3
|
||||
|
||||
x_hp, x_recipe, x_fp8, x_scales, x_scales_original = _build_lhs(x, lhs_block)
|
||||
y_hp, y_recipe, y_fp8, y_scales, y_scales_original = _build_lhs(y, rhs_block)
|
||||
else:
|
||||
lhs_recipe = ScalingType.BlockWise128x128
|
||||
if rhs_block == 1:
|
||||
y_scales = y_scales.t().contiguous().t()
|
||||
rhs_recipe = ScalingType.BlockWise1x128
|
||||
else:
|
||||
rhs_recipe = ScalingType.BlockWise128x128
|
||||
raise ValueError("Unknown test-case passed")
|
||||
|
||||
_run_test(x_hp, x_recipe, x_fp8, x_scales, x_scales_original,
|
||||
y_hp, y_recipe, y_fp8, y_scales, y_scales_original)
|
||||
|
||||
# Calculate actual F8 mm
|
||||
out_scaled_mm = scaled_mm_wrap(
|
||||
x_fp8, y_fp8.t(), scale_a=x_scales.reciprocal(), scale_b=y_scales.reciprocal().t(), out_dtype=output_dtype,
|
||||
scale_recipe_a=lhs_recipe, scale_recipe_b=rhs_recipe
|
||||
)
|
||||
|
||||
# Calculate emulated F8 mm
|
||||
out_emulated = mm_float8_emulated_block(
|
||||
x_fp8, x_scales, y_fp8.t(), y_scales.t(), output_dtype
|
||||
)
|
||||
|
||||
cosine_sim = torch.nn.functional.cosine_similarity(
|
||||
out_scaled_mm.flatten().float(), out_emulated.flatten().float(), dim=0
|
||||
)
|
||||
self.assertGreaterEqual(float(cosine_sim), 0.999)
|
||||
|
||||
if output_dtype in {torch.bfloat16, torch.float16}:
|
||||
atol, rtol = 6e-1, 7e-2
|
||||
else:
|
||||
atol, rtol = 7e-1, 2e-3
|
||||
|
||||
self.assertEqual(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
|
||||
|
||||
# One last check against the full-precision reference, to ensure we
|
||||
# didn't mess up the scaling itself and made the test trivial.
|
||||
cosine_sim = torch.nn.functional.cosine_similarity(
|
||||
out_scaled_mm.flatten().float(), (x @ y.t()).flatten().float(), dim=0
|
||||
)
|
||||
self.assertGreaterEqual(float(cosine_sim), 0.999)
|
||||
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
|
||||
@unittest.skipIf(not IS_SM90, "cuBLAS blockwise scaling requires sm90+")
|
||||
@ -1335,18 +1558,30 @@ class TestFP8Matmul(TestCase):
|
||||
x_fp8, x_scales = tensor_to_scale_block(x, e4m3_type, lhs_block, 128)
|
||||
y_fp8, y_scales = tensor_to_scale_block(y, e4m3_type, rhs_block, 128)
|
||||
|
||||
x_scales_original = x_scales
|
||||
y_scales_original = y_scales
|
||||
# 1x128 blocks need scales to be outer-dim-major
|
||||
if lhs_block == 1:
|
||||
x_scales = x_scales.t().contiguous().t()
|
||||
lhs_recipe = ScalingType.BlockWise1x128
|
||||
assert (x_scales.shape[0] == M and x_scales.shape[1] == K // 128), f"{x_scales.shape=}"
|
||||
assert (x_scales.stride(0) == 1 and x_scales.stride(1) in [1, M]), f"{x_scales.stride=}"
|
||||
else:
|
||||
lhs_recipe = ScalingType.BlockWise128x128
|
||||
x_scales, pad_amount = _pad_128x128_scales(x_scales)
|
||||
# scales in [M // 128, L4] -> [L4, M // 128]
|
||||
x_scales = x_scales.t()
|
||||
|
||||
if rhs_block == 1:
|
||||
y_scales = y_scales.t().contiguous().t()
|
||||
rhs_recipe = ScalingType.BlockWise1x128
|
||||
assert (y_scales.shape[0] == N and y_scales.shape[1] == K // 128), f"{y_scales.shape=}"
|
||||
assert (y_scales.stride(0) == 1 and y_scales.stride(1) in [1, N]), f"{y_scales.stride=}"
|
||||
else:
|
||||
rhs_recipe = ScalingType.BlockWise128x128
|
||||
y_scales, pad_amount = _pad_128x128_scales(y_scales)
|
||||
# Scale in [N // 128, L4] -> [L4, N // 128]
|
||||
y_scales = y_scales.t()
|
||||
|
||||
# Verify that actual F8 mm doesn't error
|
||||
scaled_mm_wrap(
|
||||
@ -1354,13 +1589,20 @@ class TestFP8Matmul(TestCase):
|
||||
y_fp8.t(),
|
||||
scale_a=x_scales,
|
||||
scale_recipe_a=lhs_recipe,
|
||||
scale_b=y_scales.t(),
|
||||
# Note: No more .t() on scale_b, not necessary.
|
||||
scale_b=y_scales,
|
||||
scale_recipe_b=rhs_recipe,
|
||||
out_dtype=output_dtype,
|
||||
)
|
||||
|
||||
# Verify that emulated F8 mm doesn't error
|
||||
mm_float8_emulated_block(x_fp8, x_scales, y_fp8.t(), y_scales.t(), output_dtype)
|
||||
mm_float8_emulated_block(
|
||||
x_fp8,
|
||||
x_scales_original,
|
||||
y_fp8.t(),
|
||||
y_scales_original.t(),
|
||||
output_dtype
|
||||
)
|
||||
|
||||
@skipIfRocm
|
||||
@onlyCUDA
|
||||
@ -1620,7 +1862,7 @@ class TestFP8Matmul(TestCase):
|
||||
(127, 96, 1024),
|
||||
(1025, 128, 96)
|
||||
], name_fn=lambda mkn: f"{mkn[0]}_{mkn[1]}_{mkn[2]}")
|
||||
@parametrize("recipe", ["mxfp8", "mxfp4" if torch.version.hip else "nvfp4"])
|
||||
@parametrize("recipe", ["mxfp8", "mxfp4", "nvfp4"])
|
||||
def test_blockwise_mxfp8_nvfp4_mxfp4_numerics(self, test_case_name, fast_accum, mkn, recipe) -> None:
|
||||
if (recipe == "nvfp4" or recipe == "mxfp4") and fast_accum:
|
||||
raise unittest.SkipTest("fast_accum not supported in nvfp4/mxfp4 cublas gemm, skipping")
|
||||
@ -1634,8 +1876,12 @@ class TestFP8Matmul(TestCase):
|
||||
if not (M % 16 == 0 and K % 128 == 0 and N % 16 == 0):
|
||||
raise unittest.SkipTest("M and N must be multiples of 16 and K must be multiple of 128 on ROCm, skipping")
|
||||
|
||||
fp4_scaling_dtype = torch.float8_e8m0fnu if torch.version.hip else torch.float8_e4m3fn
|
||||
BLOCK_SIZE = 32 if torch.version.hip else (16 if recipe == "nvfp4" else 32)
|
||||
fp4_scaling_dtype = torch.float8_e8m0fnu if recipe == "mxfp4" else torch.float8_e4m3fn
|
||||
BLOCK_SIZE = 16 if recipe == "nvfp4" else 32
|
||||
|
||||
if K % BLOCK_SIZE != 0:
|
||||
raise unittest.SkipTest(f"K ({K}) must be divisible by BLOCK_SIZE ({BLOCK_SIZE}), skipping")
|
||||
|
||||
require_exact_match = True
|
||||
approx_match_sqnr_target = 22.0
|
||||
|
||||
@ -1813,7 +2059,7 @@ class TestFP8Matmul(TestCase):
|
||||
B = B.clamp(min=min_val, max=max_val)
|
||||
B = _bfloat16_to_float4_e2m1fn_x2(B)
|
||||
|
||||
approx_match_sqnr_target = 15 if torch.version.hip else 15.8
|
||||
approx_match_sqnr_target = 15 if recipe == "mxfp4" else 15.8
|
||||
|
||||
C_ref = A_ref @ B_ref.t()
|
||||
|
||||
|
||||
@ -47,11 +47,18 @@ def get_all_examples():
|
||||
"import io",
|
||||
"import itertools",
|
||||
"",
|
||||
"from typing import Any, ClassVar, Generic, List, Tuple, Union",
|
||||
"from typing_extensions import Literal, get_origin, TypeAlias",
|
||||
"T: TypeAlias = object",
|
||||
"",
|
||||
"import numpy",
|
||||
"",
|
||||
"import torch",
|
||||
"import torch.nn.functional as F",
|
||||
"",
|
||||
"from typing_extensions import ParamSpec as _ParamSpec",
|
||||
"ParamSpec = _ParamSpec",
|
||||
"",
|
||||
# for requires_grad_ example
|
||||
# NB: We are parsing this file as Python 2, so we must use
|
||||
# Python 2 type annotation syntax
|
||||
|
||||
115
test/test_xpu.py
115
test/test_xpu.py
@ -14,10 +14,8 @@ from torch.testing import make_tensor
|
||||
from torch.testing._internal.autocast_test_lists import AutocastTestLists, TestAutocast
|
||||
from torch.testing._internal.common_device_type import (
|
||||
instantiate_device_type_tests,
|
||||
onlyXPU,
|
||||
OpDTypes,
|
||||
ops,
|
||||
skipXPUIf,
|
||||
)
|
||||
from torch.testing._internal.common_methods_invocations import ops_and_refs
|
||||
from torch.testing._internal.common_utils import (
|
||||
@ -74,6 +72,8 @@ _xpu_computation_ops = [
|
||||
|
||||
@unittest.skipIf(not TEST_XPU, "XPU not available, skipping tests")
|
||||
class TestXpu(TestCase):
|
||||
expandable_segments = False
|
||||
|
||||
def test_device_behavior(self):
|
||||
current_device = torch.xpu.current_device()
|
||||
torch.xpu.set_device(current_device)
|
||||
@ -385,56 +385,6 @@ if __name__ == "__main__":
|
||||
torch.xpu.set_rng_state(g_state0)
|
||||
self.assertEqual(2024, torch.xpu.initial_seed())
|
||||
|
||||
@onlyXPU
|
||||
@suppress_warnings
|
||||
@ops(_xpu_computation_ops, dtypes=any_common_cpu_xpu_one)
|
||||
def test_compare_cpu(self, device, dtype, op):
|
||||
def to_cpu(arg):
|
||||
if isinstance(arg, torch.Tensor):
|
||||
return arg.to(device="cpu")
|
||||
return arg
|
||||
|
||||
samples = op.reference_inputs(device, dtype)
|
||||
|
||||
for sample in samples:
|
||||
cpu_sample = sample.transform(to_cpu)
|
||||
xpu_results = op(sample.input, *sample.args, **sample.kwargs)
|
||||
cpu_results = op(cpu_sample.input, *cpu_sample.args, **cpu_sample.kwargs)
|
||||
|
||||
xpu_results = sample.output_process_fn_grad(xpu_results)
|
||||
cpu_results = cpu_sample.output_process_fn_grad(cpu_results)
|
||||
|
||||
# Lower tolerance because we are running this as a `@slowTest`
|
||||
# Don't want the periodic tests to fail frequently
|
||||
self.assertEqual(xpu_results, cpu_results, atol=1e-4, rtol=1e-4)
|
||||
|
||||
@onlyXPU
|
||||
@ops(_xpu_computation_ops, allowed_dtypes=(torch.bool,))
|
||||
def test_non_standard_bool_values(self, device, dtype, op):
|
||||
# Test boolean values other than 0x00 and 0x01 (gh-54789)
|
||||
def convert_boolean_tensors(x):
|
||||
if not isinstance(x, torch.Tensor) or x.dtype != torch.bool:
|
||||
return x
|
||||
|
||||
# Map False -> 0 and True -> Random value in [2, 255]
|
||||
true_vals = torch.randint(
|
||||
2, 255, x.shape, dtype=torch.uint8, device=x.device
|
||||
)
|
||||
false_vals = torch.zeros((), dtype=torch.uint8, device=x.device)
|
||||
x_int = torch.where(x, true_vals, false_vals)
|
||||
|
||||
ret = x_int.view(torch.bool)
|
||||
self.assertEqual(ret, x)
|
||||
return ret
|
||||
|
||||
for sample in op.sample_inputs(device, dtype):
|
||||
expect = op(sample.input, *sample.args, **sample.kwargs)
|
||||
|
||||
transformed = sample.transform(convert_boolean_tensors)
|
||||
actual = op(transformed.input, *transformed.args, **transformed.kwargs)
|
||||
|
||||
self.assertEqual(expect, actual)
|
||||
|
||||
def test_serialization_array_with_storage(self):
|
||||
x = torch.randn(5, 5).xpu()
|
||||
y = torch.zeros(2, 5, dtype=torch.int, device="xpu")
|
||||
@ -470,6 +420,8 @@ if __name__ == "__main__":
|
||||
self.assertEqual(copy.get_device(), original.get_device())
|
||||
|
||||
def test_out_of_memory(self):
|
||||
if self.expandable_segments:
|
||||
self.skipTest("Skipping OOM test for expandable segments allocator.")
|
||||
tensor = torch.zeros(1024, device="xpu") # noqa: F841
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Tried to allocate 800000000.00 GiB"):
|
||||
@ -479,6 +431,8 @@ if __name__ == "__main__":
|
||||
torch.empty(1024 * 1024 * 1024 * 8000000000, dtype=torch.int8, device="xpu")
|
||||
|
||||
def test_raises_oom(self):
|
||||
if self.expandable_segments:
|
||||
self.skipTest("Skipping OOM test for expandable segments allocator.")
|
||||
torch.xpu.memory.empty_cache()
|
||||
with self.assertRaises(torch.OutOfMemoryError):
|
||||
torch.empty(1024 * 1024 * 1024 * 1024, device="xpu")
|
||||
@ -591,7 +545,7 @@ if __name__ == "__main__":
|
||||
self.assertEqual(torch.accelerator.max_memory_allocated(), prev_max_allocated)
|
||||
self.assertEqual(torch.accelerator.max_memory_reserved(), prev_max_reserved)
|
||||
|
||||
@skipXPUIf(
|
||||
@unittest.skipIf(
|
||||
int(torch.version.xpu) < 20250000,
|
||||
"Test requires SYCL compiler version 2025.0.0 or newer.",
|
||||
)
|
||||
@ -639,6 +593,8 @@ if __name__ == "__main__":
|
||||
self.assertTrue(b"libsycl.so" in result)
|
||||
|
||||
def test_dlpack_conversion(self):
|
||||
if self.expandable_segments:
|
||||
self.skipTest("Skipping DLPack test for expandable segments allocator.")
|
||||
x = make_tensor((5,), dtype=torch.float32, device="xpu")
|
||||
if IS_WINDOWS and int(torch.version.xpu) < 20250000:
|
||||
with self.assertRaisesRegex(
|
||||
@ -652,7 +608,58 @@ if __name__ == "__main__":
|
||||
self.assertEqual(z, x)
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestXpu, globals(), only_for="xpu", allow_xpu=True)
|
||||
@unittest.skipIf(not TEST_XPU, "XPU not available, skipping tests")
|
||||
class TestXpuOps(TestCase):
|
||||
@suppress_warnings
|
||||
@ops(_xpu_computation_ops, dtypes=any_common_cpu_xpu_one)
|
||||
def test_compare_cpu(self, device, dtype, op):
|
||||
def to_cpu(arg):
|
||||
if isinstance(arg, torch.Tensor):
|
||||
return arg.to(device="cpu")
|
||||
return arg
|
||||
|
||||
samples = op.reference_inputs(device, dtype)
|
||||
|
||||
for sample in samples:
|
||||
cpu_sample = sample.transform(to_cpu)
|
||||
xpu_results = op(sample.input, *sample.args, **sample.kwargs)
|
||||
cpu_results = op(cpu_sample.input, *cpu_sample.args, **cpu_sample.kwargs)
|
||||
|
||||
xpu_results = sample.output_process_fn_grad(xpu_results)
|
||||
cpu_results = cpu_sample.output_process_fn_grad(cpu_results)
|
||||
|
||||
# Lower tolerance because we are running this as a `@slowTest`
|
||||
# Don't want the periodic tests to fail frequently
|
||||
self.assertEqual(xpu_results, cpu_results, atol=1e-4, rtol=1e-4)
|
||||
|
||||
@ops(_xpu_computation_ops, allowed_dtypes=(torch.bool,))
|
||||
def test_non_standard_bool_values(self, device, dtype, op):
|
||||
# Test boolean values other than 0x00 and 0x01 (gh-54789)
|
||||
def convert_boolean_tensors(x):
|
||||
if not isinstance(x, torch.Tensor) or x.dtype != torch.bool:
|
||||
return x
|
||||
|
||||
# Map False -> 0 and True -> Random value in [2, 255]
|
||||
true_vals = torch.randint(
|
||||
2, 255, x.shape, dtype=torch.uint8, device=x.device
|
||||
)
|
||||
false_vals = torch.zeros((), dtype=torch.uint8, device=x.device)
|
||||
x_int = torch.where(x, true_vals, false_vals)
|
||||
|
||||
ret = x_int.view(torch.bool)
|
||||
self.assertEqual(ret, x)
|
||||
return ret
|
||||
|
||||
for sample in op.sample_inputs(device, dtype):
|
||||
expect = op(sample.input, *sample.args, **sample.kwargs)
|
||||
|
||||
transformed = sample.transform(convert_boolean_tensors)
|
||||
actual = op(transformed.input, *transformed.args, **transformed.kwargs)
|
||||
|
||||
self.assertEqual(expect, actual)
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestXpuOps, globals(), only_for="xpu", allow_xpu=True)
|
||||
|
||||
|
||||
@unittest.skipIf(not TEST_XPU, "XPU not available, skipping tests")
|
||||
|
||||
26
test/test_xpu_expandable_segments.py
Normal file
26
test/test_xpu_expandable_segments.py
Normal file
@ -0,0 +1,26 @@
|
||||
# Owner(s): ["module: intel"]
|
||||
import pathlib
|
||||
import sys
|
||||
|
||||
from test_xpu import TestXpu, TestXpuOpsXPU # noqa: F401
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import IS_WINDOWS, run_tests
|
||||
|
||||
|
||||
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent
|
||||
sys.path.insert(0, str(REPO_ROOT))
|
||||
|
||||
from tools.stats.import_test_stats import get_disabled_tests
|
||||
|
||||
|
||||
sys.path.remove(str(REPO_ROOT))
|
||||
|
||||
if __name__ == "__main__":
|
||||
if torch.xpu.is_available() and not IS_WINDOWS:
|
||||
get_disabled_tests(".")
|
||||
|
||||
torch._C._accelerator_setAllocatorSettings("expandable_segments:True")
|
||||
TestXpu.expandable_segments = True
|
||||
|
||||
run_tests()
|
||||
@ -1,356 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
STABLE_SHIM_VERSION: Ensures that function declarations in stable/c/shim.h
|
||||
are properly wrapped in TORCH_FEATURE_VERSION macros corresponding to the
|
||||
current TORCH_ABI_VERSION.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import sys
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import NamedTuple
|
||||
|
||||
|
||||
LINTER_CODE = "STABLE_SHIM_VERSION"
|
||||
|
||||
|
||||
class LintSeverity(str, Enum):
|
||||
ERROR = "error"
|
||||
WARNING = "warning"
|
||||
ADVICE = "advice"
|
||||
DISABLED = "disabled"
|
||||
|
||||
|
||||
class LintMessage(NamedTuple):
|
||||
path: str | None
|
||||
line: int | None
|
||||
char: int | None
|
||||
code: str
|
||||
severity: LintSeverity
|
||||
name: str
|
||||
original: str | None
|
||||
replacement: str | None
|
||||
description: str | None
|
||||
|
||||
|
||||
def parse_version(version: str) -> tuple[int, int, int]:
|
||||
"""
|
||||
Parses a version string into (major, minor, patch) version numbers.
|
||||
This function is copied from tools/setup_helpers/gen_version_header.py
|
||||
to ensure consistency with how PyTorch parses its version.
|
||||
|
||||
Args:
|
||||
version: Full version number string, possibly including revision / commit hash.
|
||||
|
||||
Returns:
|
||||
A tuple of (major, minor, patch) version numbers.
|
||||
"""
|
||||
# Extract version number part (i.e. toss any revision / hash parts).
|
||||
version_number_str = version
|
||||
for i in range(len(version)):
|
||||
c = version[i]
|
||||
if not (c.isdigit() or c == "."):
|
||||
version_number_str = version[:i]
|
||||
break
|
||||
|
||||
return tuple([int(n) for n in version_number_str.split(".")]) # type: ignore[return-value]
|
||||
|
||||
|
||||
def get_current_version() -> tuple[int, int]:
|
||||
"""
|
||||
Get the current PyTorch version from version.txt.
|
||||
This uses the same logic as tools/setup_helpers/gen_version_header.py
|
||||
which is used to generate torch/headeronly/version.h from version.h.in.
|
||||
|
||||
Returns (major, minor) tuple or None if not found.
|
||||
"""
|
||||
repo_root = Path(__file__).resolve().parents[3]
|
||||
version_file = repo_root / "version.txt"
|
||||
|
||||
if not version_file.exists():
|
||||
raise RuntimeError(
|
||||
"Could not find version.txt. This linter require version.txt to run"
|
||||
)
|
||||
|
||||
with open(version_file) as f:
|
||||
version = f.read().strip()
|
||||
major, minor, patch = parse_version(version)
|
||||
|
||||
return (major, minor)
|
||||
|
||||
|
||||
def get_added_lines(filename: str) -> set[int]:
|
||||
"""
|
||||
Get the line numbers of added lines in:
|
||||
1. Current uncommitted changes (git diff HEAD)
|
||||
2. The most recent commit (git diff HEAD~1..HEAD)
|
||||
|
||||
This ensures that even if someone commits locally before running the linter,
|
||||
we still catch version macro issues in their recent changes.
|
||||
|
||||
Returns:
|
||||
Set of line numbers (1-indexed) that are new additions.
|
||||
"""
|
||||
import subprocess
|
||||
|
||||
added_lines = set()
|
||||
|
||||
def parse_diff(diff_output: str) -> set[int]:
|
||||
"""Parse git diff output and return line numbers of added lines."""
|
||||
lines = set()
|
||||
current_line = 0
|
||||
for line in diff_output.split("\n"):
|
||||
# Unified diff format: @@ -old_start,old_count +new_start,new_count @@
|
||||
if line.startswith("@@"):
|
||||
match = re.search(r"\+(\d+)", line)
|
||||
if match:
|
||||
current_line = int(match.group(1))
|
||||
elif line.startswith("+") and not line.startswith("+++"):
|
||||
# This is an added line
|
||||
lines.add(current_line)
|
||||
current_line += 1
|
||||
elif not line.startswith("-"):
|
||||
# Context line or unchanged line
|
||||
current_line += 1
|
||||
return lines
|
||||
|
||||
try:
|
||||
# Check uncommitted changes (working directory vs HEAD)
|
||||
result = subprocess.run(
|
||||
["git", "diff", "HEAD", filename],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
)
|
||||
if result.returncode == 0:
|
||||
added_lines.update(parse_diff(result.stdout))
|
||||
|
||||
# Also check the most recent commit (HEAD vs HEAD~1)
|
||||
# This catches cases where someone commits before running the linter
|
||||
result = subprocess.run(
|
||||
["git", "diff", "HEAD~1..HEAD", filename],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
)
|
||||
if result.returncode == 0:
|
||||
added_lines.update(parse_diff(result.stdout))
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"Failed to get git diff information for {filename}. Error: {e}"
|
||||
) from e
|
||||
|
||||
return added_lines
|
||||
|
||||
|
||||
def parse_shim_file(filename: str) -> list[LintMessage]:
|
||||
"""
|
||||
Parse the stable/c/shim.h file and check that:
|
||||
1. All function declarations are within TORCH_FEATURE_VERSION blocks
|
||||
2. New functions added in this commit use the current version macro
|
||||
"""
|
||||
lint_messages: list[LintMessage] = []
|
||||
|
||||
# Get current version
|
||||
current_version = get_current_version()
|
||||
major, minor = current_version
|
||||
expected_version_macro = f"TORCH_VERSION_{major}_{minor}_0"
|
||||
expected_version_check = f"#if TORCH_FEATURE_VERSION >= {expected_version_macro}"
|
||||
|
||||
# Get lines that are uncommitted or added in the most recent commit
|
||||
added_lines = get_added_lines(filename)
|
||||
|
||||
with open(filename) as f:
|
||||
lines = f.readlines()
|
||||
|
||||
# Track state
|
||||
inside_version_block = False
|
||||
current_version_macro = None
|
||||
inside_extern_c = False
|
||||
|
||||
# Track ALL preprocessor conditional blocks to properly match #if/#endif pairs
|
||||
# Each element is (is_version_block, version_macro_or_none)
|
||||
preprocessor_stack: list[tuple[bool, str | None]] = []
|
||||
|
||||
# Patterns
|
||||
version_start_pattern = re.compile(
|
||||
r"#if\s+TORCH_FEATURE_VERSION\s*>=\s*(TORCH_VERSION_\d+_\d+_\d+)"
|
||||
)
|
||||
extern_c_pattern = re.compile(r'extern\s+"C"\s*{')
|
||||
extern_c_end_pattern = re.compile(r'}\s*//\s*extern\s+"C"')
|
||||
|
||||
# Function declaration patterns - looking for AOTI_TORCH_EXPORT or typedef
|
||||
function_decl_patterns = [
|
||||
re.compile(r"^\s*AOTI_TORCH_EXPORT\s+\w+"), # AOTI_TORCH_EXPORT functions
|
||||
re.compile(r"^\s*typedef\s+.*\(\*\w+\)"), # typedef function pointers
|
||||
re.compile(r"^\s*using\s+\w+\s*="), # using declarations
|
||||
]
|
||||
|
||||
for line_num, line in enumerate(lines, 1):
|
||||
stripped = line.strip()
|
||||
|
||||
# Skip empty lines and comments
|
||||
if not stripped or stripped.startswith("//"):
|
||||
continue
|
||||
|
||||
# Check for TORCH_FEATURE_VERSION block start
|
||||
version_match = version_start_pattern.match(stripped)
|
||||
if version_match:
|
||||
version_macro = version_match.group(1)
|
||||
preprocessor_stack.append((True, version_macro))
|
||||
inside_version_block = True
|
||||
current_version_macro = version_macro
|
||||
continue
|
||||
|
||||
# Track any other #if/#ifdef/#ifndef directives
|
||||
if stripped.startswith(("#if", "#ifdef", "#ifndef")) and not version_match:
|
||||
# Not a TORCH_FEATURE_VERSION block, just a regular conditional
|
||||
preprocessor_stack.append((False, None))
|
||||
continue
|
||||
|
||||
# Track #endif directives
|
||||
if stripped.startswith("#endif"):
|
||||
if preprocessor_stack:
|
||||
is_version_block, _ = preprocessor_stack.pop()
|
||||
# If we just closed a version block, check if we're still in one
|
||||
if is_version_block:
|
||||
# Look for any remaining version blocks in the stack
|
||||
inside_version_block = False
|
||||
current_version_macro = None
|
||||
for is_ver, ver_macro in reversed(preprocessor_stack):
|
||||
if is_ver:
|
||||
inside_version_block = True
|
||||
current_version_macro = ver_macro
|
||||
break
|
||||
continue
|
||||
|
||||
# Track #else and #elif (they don't change the stack depth, but exit version blocks)
|
||||
if stripped.startswith(("#else", "#elif")):
|
||||
# If we're in a version block, exit it (the #else branch is not versioned)
|
||||
if inside_version_block and preprocessor_stack:
|
||||
# Check if the topmost block is a version block
|
||||
if preprocessor_stack[-1][0]:
|
||||
inside_version_block = False
|
||||
current_version_macro = None
|
||||
continue
|
||||
|
||||
# Skip other preprocessor directives
|
||||
if stripped.startswith("#"):
|
||||
continue
|
||||
|
||||
# Track extern "C" blocks
|
||||
if extern_c_pattern.search(stripped):
|
||||
inside_extern_c = True
|
||||
continue
|
||||
if extern_c_end_pattern.search(stripped):
|
||||
inside_extern_c = False
|
||||
continue
|
||||
|
||||
# Check for function declarations
|
||||
if inside_extern_c:
|
||||
is_function_decl = any(
|
||||
pattern.match(stripped) for pattern in function_decl_patterns
|
||||
)
|
||||
|
||||
if is_function_decl:
|
||||
# Check if this is a newly added line
|
||||
is_new_line = line_num in added_lines
|
||||
|
||||
if not inside_version_block:
|
||||
# Function declaration outside of version block
|
||||
lint_messages.append(
|
||||
LintMessage(
|
||||
path=filename,
|
||||
line=line_num,
|
||||
char=None,
|
||||
code=LINTER_CODE,
|
||||
severity=LintSeverity.ERROR,
|
||||
name="unversioned-function-declaration",
|
||||
original=None,
|
||||
replacement=None,
|
||||
description=(
|
||||
f"Function declaration found outside of TORCH_FEATURE_VERSION block. "
|
||||
f"All function declarations must be wrapped in:\n"
|
||||
f"{expected_version_check}\n"
|
||||
f"// ... your declarations ...\n"
|
||||
f"#endif // TORCH_FEATURE_VERSION >= {expected_version_macro}"
|
||||
),
|
||||
)
|
||||
)
|
||||
elif is_new_line and current_version_macro != expected_version_macro:
|
||||
# New function declaration using wrong version macro
|
||||
lint_messages.append(
|
||||
LintMessage(
|
||||
path=filename,
|
||||
line=line_num,
|
||||
char=None,
|
||||
code=LINTER_CODE,
|
||||
severity=LintSeverity.ERROR,
|
||||
name="wrong-version-for-new-function",
|
||||
original=None,
|
||||
replacement=None,
|
||||
description=(
|
||||
f"New function declaration should use {expected_version_macro}, "
|
||||
f"but is wrapped in {current_version_macro}. "
|
||||
f"New additions in this commit must use the current version:\n"
|
||||
f"{expected_version_check}\n"
|
||||
f"// ... your declarations ...\n"
|
||||
f"#endif // TORCH_FEATURE_VERSION >= {expected_version_macro}"
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
return lint_messages
|
||||
|
||||
|
||||
def check_file(filename: str) -> list[LintMessage]:
|
||||
"""
|
||||
Check if the file is stable/c/shim.h and lint it.
|
||||
"""
|
||||
# Only lint the specific file
|
||||
if not filename.endswith("torch/csrc/stable/c/shim.h"):
|
||||
return []
|
||||
|
||||
return parse_shim_file(filename)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="stable shim version linter",
|
||||
fromfile_prefix_chars="@",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verbose",
|
||||
action="store_true",
|
||||
)
|
||||
parser.add_argument(
|
||||
"filenames",
|
||||
nargs="+",
|
||||
help="paths to lint",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(
|
||||
format="<%(threadName)s:%(levelname)s> %(message)s",
|
||||
level=logging.NOTSET
|
||||
if args.verbose
|
||||
else logging.DEBUG
|
||||
if len(args.filenames) < 1000
|
||||
else logging.INFO,
|
||||
stream=sys.stderr,
|
||||
)
|
||||
|
||||
lint_messages = []
|
||||
for filename in args.filenames:
|
||||
lint_messages.extend(check_file(filename))
|
||||
|
||||
for lint_message in lint_messages:
|
||||
print(json.dumps(lint_message._asdict()), flush=True)
|
||||
@ -663,6 +663,9 @@ class SymFloat:
|
||||
def __float__(self):
|
||||
return self.node.guard_float("", 0)
|
||||
|
||||
def __int__(self):
|
||||
return self.__trunc__().__int__()
|
||||
|
||||
# Symbolic power does NOT work with negative base, this is to avoid
|
||||
# potential complex outputs
|
||||
def __pow__(self, other):
|
||||
@ -811,6 +814,15 @@ class SymBool:
|
||||
# Force specialization
|
||||
return hash(builtins.bool(self))
|
||||
|
||||
def __sym_float__(self):
|
||||
"""
|
||||
Provides a SymFloat representation (0.0 or 1.0) for this SymBool.
|
||||
Called by torch.sym_float() when casting SymBool to float.
|
||||
"""
|
||||
from torch.fx.experimental.sym_node import wrap_node
|
||||
|
||||
return wrap_node(self.node.sym_float())
|
||||
|
||||
|
||||
def sym_not(a):
|
||||
r"""SymInt-aware utility for logical negation.
|
||||
|
||||
@ -739,6 +739,12 @@ enable_aot_compile = False
|
||||
# HACK: this is for testing custom ops profiling only
|
||||
_custom_ops_profile: Optional[Any] = None
|
||||
|
||||
# Experimental: If True, graph module will register fx metadata during recompile()
|
||||
enrich_profiler_metadata: bool = Config( # type: ignore[var-annotated]
|
||||
default=False,
|
||||
env_name_default="TORCH_ENRICH_RPOFILER_STACK_TRACE",
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch.utils._config_typing import * # noqa: F401, F403
|
||||
|
||||
|
||||
@ -1000,7 +1000,10 @@ def get_traced_fn(mod: Any) -> tuple[FunctionType, Optional[object]]:
|
||||
import inspect
|
||||
|
||||
if isinstance(mod, torch.nn.Module):
|
||||
mod = mod.forward
|
||||
if len(mod._forward_pre_hooks) == 0 and len(mod._forward_hooks) == 0:
|
||||
mod = mod.forward
|
||||
else:
|
||||
mod = mod.__call__
|
||||
if hasattr(mod, "__self__"):
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
return mod.__func__, mod.__self__
|
||||
|
||||
@ -42,7 +42,7 @@ import weakref
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from os.path import dirname, join
|
||||
from typing import Any, NamedTuple, Optional, TYPE_CHECKING, Union
|
||||
from typing import Any, NamedTuple, Optional, Sized, TYPE_CHECKING, Union
|
||||
from unittest.mock import patch
|
||||
|
||||
import sympy
|
||||
@ -395,6 +395,13 @@ class OptimizedModule(torch.nn.Module):
|
||||
self._initialize()
|
||||
self.training = self._orig_mod.training
|
||||
|
||||
def __len__(self) -> int:
|
||||
# Proxy the len call to the original module
|
||||
if isinstance(self._orig_mod, Sized):
|
||||
return len(self._orig_mod)
|
||||
# Mimic python's default behavior for objects without a length
|
||||
raise TypeError(f"{type(self._orig_mod).__name__} does not support len()")
|
||||
|
||||
def _initialize(self) -> None:
|
||||
# Do this stuff in constructor to lower overhead slightly
|
||||
if isinstance(self.dynamo_ctx, DisableContext):
|
||||
|
||||
@ -16,6 +16,7 @@ from torch._dynamo.eval_frame import argument_names, check_user_input_output
|
||||
from torch._dynamo.exc import UserErrorType
|
||||
from torch._dynamo.utils import dynamo_timed, get_metrics_context
|
||||
from torch._export.utils import _compiling_state_context
|
||||
from torch._guards import TracingContext
|
||||
from torch.export.dynamic_shapes import _RelaxedConstraint, Constraint
|
||||
from torch.fx import Node
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
@ -449,6 +450,14 @@ def _suggest_or_raise_constraint_violation(
|
||||
raise constraint_violation_error
|
||||
|
||||
|
||||
def _normalize_shuffle_graph(shuffle_gm: torch.fx.GraphModule) -> None:
|
||||
shuffle_gm.graph.eliminate_dead_code()
|
||||
shuffle_gm.recompile()
|
||||
for name, buffer in list(shuffle_gm.named_buffers()):
|
||||
delattr(shuffle_gm, name)
|
||||
setattr(shuffle_gm, name, buffer)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PyTreeifyOutput:
|
||||
graph_module: torch.fx.GraphModule
|
||||
@ -525,6 +534,7 @@ def pytreeify(
|
||||
in_shuffle_graph = make_fx(
|
||||
InShuffle(), tracing_mode="symbolic", proxy_module_inputs=True
|
||||
)(*flat_real_args)
|
||||
_normalize_shuffle_graph(in_shuffle_graph)
|
||||
|
||||
output_node = next(iter(reversed(backend_input.graph_module.graph.nodes)))
|
||||
|
||||
@ -572,6 +582,7 @@ def pytreeify(
|
||||
out_shuffle_graph = make_fx(
|
||||
out_shuffle, tracing_mode="symbolic", proxy_module_inputs=True
|
||||
)(*flat_out_shuffle_args)
|
||||
_normalize_shuffle_graph(out_shuffle_graph)
|
||||
|
||||
assert out_shuffle.out_spec is not None
|
||||
return PyTreeifyOutput(
|
||||
@ -650,6 +661,10 @@ def dynamo_graph_capture_for_export(
|
||||
)
|
||||
assert out.backend_input is not None
|
||||
graph_module.meta["fake_mode"] = out.backend_input.fake_mode # type: ignore[attr-defined]
|
||||
graph_module.meta["fake_mode"].allow_non_fake_inputs = True
|
||||
tracing_context = TracingContext(graph_module.meta["fake_mode"])
|
||||
tracing_context.tensor_to_context = out.backend_input.tensor_to_context # type: ignore[attr-defined]
|
||||
graph_module.meta["tracing_context"] = tracing_context
|
||||
return graph_module
|
||||
|
||||
return inner
|
||||
|
||||
@ -1734,6 +1734,14 @@
|
||||
}
|
||||
],
|
||||
"GB0175": [
|
||||
{
|
||||
"Gb_type": "builtin isinstance() cannot determine type of argument",
|
||||
"Context": "isinstance({arg}, {isinstance_type_var})",
|
||||
"Explanation": "Dynamo doesn't have a rule to determine the type of argument {arg}",
|
||||
"Hints": [
|
||||
"This is likely to be a Dynamo bug. Please report an issue to PyTorch."
|
||||
]
|
||||
},
|
||||
{
|
||||
"Gb_type": "builtin isinstance() cannot determine type of argument",
|
||||
"Context": "isinstance({arg}, {isinstance_type})",
|
||||
@ -2915,5 +2923,19 @@
|
||||
"Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled."
|
||||
]
|
||||
}
|
||||
],
|
||||
"GB0287": [
|
||||
{
|
||||
"Gb_type": "unsupported type.__dict__['__annotations__'].__get__ call",
|
||||
"Context": "call_function {self}, args: {args}, kwargs: {kwargs}",
|
||||
"Explanation": "`torch.compile` only supports calling type.__dict__['__annotations__'].__get__ on a single constant argument (i.e. a type).",
|
||||
"Hints": [
|
||||
"Make sure your call to type.__dict__['__annotations__'] only has ",
|
||||
"one positional argument (no keyword arguments).",
|
||||
"Make sure the argument to type.__dict__['__annotations__'] is a constant ",
|
||||
"(i.e. type). For example, `object`, `int`, `MyCustomClass`.",
|
||||
"It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
119
torch/_dynamo/polyfills/heapq.py
Normal file
119
torch/_dynamo/polyfills/heapq.py
Normal file
@ -0,0 +1,119 @@
|
||||
"""
|
||||
Python polyfills for heapq
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import heapq
|
||||
import importlib
|
||||
import sys
|
||||
from typing import TYPE_CHECKING, TypeVar
|
||||
|
||||
from ..decorators import substitute_in_graph
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import ModuleType
|
||||
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
# Partially copied from CPython test/support/import_helper.py
|
||||
# https://github.com/python/cpython/blob/bb8791c0b75b5970d109e5557bfcca8a578a02af/Lib/test/support/import_helper.py
|
||||
def _save_and_remove_modules(names: set[str]) -> dict[str, ModuleType]:
|
||||
orig_modules = {}
|
||||
prefixes = tuple(name + "." for name in names)
|
||||
for modname in list(sys.modules):
|
||||
if modname in names or modname.startswith(prefixes):
|
||||
orig_modules[modname] = sys.modules.pop(modname)
|
||||
return orig_modules
|
||||
|
||||
|
||||
def import_fresh_module(name: str, blocked: list[str]) -> ModuleType:
|
||||
# Keep track of modules saved for later restoration as well
|
||||
# as those which just need a blocking entry removed
|
||||
names = {name, *blocked}
|
||||
orig_modules = _save_and_remove_modules(names)
|
||||
for modname in blocked:
|
||||
sys.modules[modname] = None # type: ignore[assignment]
|
||||
|
||||
try:
|
||||
return importlib.import_module(name)
|
||||
finally:
|
||||
_save_and_remove_modules(names)
|
||||
sys.modules.update(orig_modules)
|
||||
|
||||
|
||||
# Import the pure Python heapq module, blocking the C extension
|
||||
py_heapq = import_fresh_module("heapq", blocked=["_heapq"])
|
||||
|
||||
|
||||
__all__ = [
|
||||
"_heapify_max",
|
||||
"_heappop_max",
|
||||
"_heapreplace_max",
|
||||
"heapify",
|
||||
"heappop",
|
||||
"heappush",
|
||||
"heappushpop",
|
||||
"heapreplace",
|
||||
"merge",
|
||||
"nlargest",
|
||||
"nsmallest",
|
||||
]
|
||||
|
||||
|
||||
@substitute_in_graph(heapq._heapify_max)
|
||||
def _heapify_max(heap: list[_T], /) -> None:
|
||||
return py_heapq._heapify_max(heap)
|
||||
|
||||
|
||||
@substitute_in_graph(heapq._heappop_max) # type: ignore[attr-defined]
|
||||
def _heappop_max(heap: list[_T]) -> _T:
|
||||
return py_heapq._heappop_max(heap)
|
||||
|
||||
|
||||
@substitute_in_graph(heapq._heapreplace_max) # type: ignore[attr-defined]
|
||||
def _heapreplace_max(heap: list[_T], item: _T) -> _T:
|
||||
return py_heapq._heapreplace_max(heap, item)
|
||||
|
||||
|
||||
@substitute_in_graph(heapq.heapify)
|
||||
def heapify(heap: list[_T], /) -> None:
|
||||
return py_heapq.heapify(heap)
|
||||
|
||||
|
||||
@substitute_in_graph(heapq.heappop)
|
||||
def heappop(heap: list[_T], /) -> _T:
|
||||
return py_heapq.heappop(heap)
|
||||
|
||||
|
||||
@substitute_in_graph(heapq.heappush)
|
||||
def heappush(heap: list[_T], item: _T) -> None:
|
||||
return py_heapq.heappush(heap, item)
|
||||
|
||||
|
||||
@substitute_in_graph(heapq.heappushpop)
|
||||
def heappushpop(heap: list[_T], item: _T) -> _T:
|
||||
return py_heapq.heappushpop(heap, item)
|
||||
|
||||
|
||||
@substitute_in_graph(heapq.heapreplace)
|
||||
def heapreplace(heap: list[_T], item: _T) -> _T:
|
||||
return py_heapq.heapreplace(heap, item)
|
||||
|
||||
|
||||
@substitute_in_graph(heapq.merge) # type: ignore[arg-type]
|
||||
def merge(*iterables, key=None, reverse=False): # type: ignore[no-untyped-def]
|
||||
return py_heapq.merge(*iterables, key=key, reverse=reverse)
|
||||
|
||||
|
||||
@substitute_in_graph(heapq.nlargest) # type: ignore[arg-type]
|
||||
def nlargest(n, iterable, key=None): # type: ignore[no-untyped-def]
|
||||
return py_heapq.nlargest(n, iterable, key=key)
|
||||
|
||||
|
||||
@substitute_in_graph(heapq.nsmallest) # type: ignore[arg-type]
|
||||
def nsmallest(n, iterable, key=None): # type: ignore[no-untyped-def]
|
||||
return py_heapq.nsmallest(n, iterable, key=key)
|
||||
@ -405,6 +405,7 @@ isolate_fails_code_str = None
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
kernel._fn_name
|
||||
if isinstance(kernel, JITFunction)
|
||||
# pyrefly: ignore # missing-attribute
|
||||
else kernel.fn._fn_name
|
||||
)
|
||||
fn_name = fn_name.split(".")[-1]
|
||||
|
||||
@ -218,7 +218,7 @@ class CPythonTestCase(TestCase):
|
||||
if m:
|
||||
test_py_ver = tuple(map(int, m.group().removeprefix(prefix).split("_")))
|
||||
py_ver = sys.version_info[:2]
|
||||
if py_ver < test_py_ver:
|
||||
if py_ver != test_py_ver:
|
||||
expected = ".".join(map(str, test_py_ver))
|
||||
got = ".".join(map(str, py_ver))
|
||||
raise unittest.SkipTest(
|
||||
|
||||
@ -2707,6 +2707,7 @@ def to_subclass(t: Any, cls: type) -> Any:
|
||||
dict_getitem = dict.__getitem__
|
||||
|
||||
|
||||
@torch.fx.wrap
|
||||
def dict_keys_getitem(d: dict[Any, Any], n: int) -> Any:
|
||||
# Call dict(d) to prevent calling overridden __iter__/keys
|
||||
dict_class = dict
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -61,7 +61,11 @@ from ..utils import (
|
||||
raise_args_mismatch,
|
||||
tuple_methods,
|
||||
)
|
||||
from .base import raise_type_error_exc, VariableTracker
|
||||
from .base import (
|
||||
AsPythonConstantNotImplementedError,
|
||||
raise_type_error_exc,
|
||||
VariableTracker,
|
||||
)
|
||||
from .constant import ConstantVariable
|
||||
from .functions import NestedUserFunctionVariable, UserFunctionVariable
|
||||
from .user_defined import call_random_fn, is_standard_setattr, UserDefinedObjectVariable
|
||||
@ -1260,6 +1264,38 @@ class MethodWrapperVariable(VariableTracker):
|
||||
return variables.BuiltinVariable(object).call_method(
|
||||
tx, wrapper_name, [self_obj, *args], kwargs
|
||||
)
|
||||
elif (
|
||||
sys.version_info >= (3, 14)
|
||||
# for some reason, even if the below check passes,
|
||||
# self.method_wrapper may not be the same as type.__dict__["__annotations__"].__get__
|
||||
and self_obj is type.__dict__["__annotations__"]
|
||||
and wrapper_name == "__get__"
|
||||
):
|
||||
from .builder import SourcelessBuilder
|
||||
|
||||
if len(args) == 1 and not kwargs:
|
||||
try:
|
||||
return SourcelessBuilder.create(
|
||||
tx, self.method_wrapper(args[0].as_python_constant())
|
||||
)
|
||||
except AttributeError:
|
||||
raise_observed_exception(AttributeError, tx)
|
||||
except AsPythonConstantNotImplementedError:
|
||||
pass
|
||||
|
||||
unimplemented_v2(
|
||||
gb_type="unsupported type.__dict__['__annotations__'].__get__ call",
|
||||
context=f"call_function {self}, args: {args}, kwargs: {kwargs}",
|
||||
explanation="`torch.compile` only supports calling type.__dict__['__annotations__'].__get__ "
|
||||
"on a single constant argument (i.e. a type).",
|
||||
hints=[
|
||||
"Make sure your call to type.__dict__['__annotations__'] only has "
|
||||
"one positional argument (no keyword arguments).",
|
||||
"Make sure the argument to type.__dict__['__annotations__'] is a constant "
|
||||
"(i.e. type). For example, `object`, `int`, `MyCustomClass`.",
|
||||
*graph_break_hints.SUPPORTABLE,
|
||||
],
|
||||
)
|
||||
|
||||
return super().call_function(tx, args, kwargs)
|
||||
|
||||
|
||||
@ -23,6 +23,7 @@ import operator
|
||||
import textwrap
|
||||
import traceback
|
||||
import types
|
||||
from collections.abc import Sequence
|
||||
from contextlib import nullcontext
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
@ -631,7 +632,7 @@ class TensorVariable(VariableTracker):
|
||||
self,
|
||||
tx,
|
||||
name,
|
||||
args: "list[VariableTracker]",
|
||||
args: Sequence[VariableTracker],
|
||||
kwargs: "dict[str, VariableTracker]",
|
||||
) -> "VariableTracker":
|
||||
from .builder import SourcelessBuilder, VariableBuilder
|
||||
|
||||
@ -834,12 +834,13 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
|
||||
@register(torch.full)
|
||||
def handle_full(self, tx, size, fill_value, **kwargs):
|
||||
if isinstance(fill_value, TensorVariable):
|
||||
result = TorchInGraphFunctionVariable(
|
||||
torch.ops.aten._local_scalar_dense
|
||||
).call_function(tx, [fill_value], {})
|
||||
return TorchInGraphFunctionVariable(torch.full).call_function(
|
||||
tx, [size, result], kwargs
|
||||
# Decompose: create empty tensor and fill it
|
||||
# This avoids the scalar extraction at compile time
|
||||
empty_result = TorchInGraphFunctionVariable(torch.empty).call_function(
|
||||
tx, [size], kwargs
|
||||
)
|
||||
# Call fill_ method on the empty tensor
|
||||
return empty_result.call_method(tx, "fill_", [fill_value], {})
|
||||
|
||||
@register(torch._foreach_lerp_)
|
||||
def handle_inplace_foreach_lerp_scalar(
|
||||
|
||||
@ -29,6 +29,7 @@ import contextlib
|
||||
import functools
|
||||
import inspect
|
||||
import operator
|
||||
from collections.abc import Sequence
|
||||
from types import TracebackType
|
||||
from typing import Any, Generator, Iterable, Optional, TYPE_CHECKING
|
||||
|
||||
@ -722,12 +723,12 @@ class TensorWithTFOverrideVariable(TensorVariable):
|
||||
self,
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: "list[VariableTracker]",
|
||||
args: Sequence[VariableTracker],
|
||||
kwargs: "dict[str, VariableTracker]",
|
||||
) -> "VariableTracker":
|
||||
# This code block implements inlining the __torch_function__ override
|
||||
# of `call_method`.
|
||||
tf_args = [self] + args
|
||||
tf_args = [self] + list(args)
|
||||
if can_dispatch_torch_function(tx, tf_args, kwargs):
|
||||
import torch
|
||||
|
||||
|
||||
@ -7,9 +7,9 @@
|
||||
|
||||
import contextlib
|
||||
import functools
|
||||
from collections.abc import Callable, Generator
|
||||
from collections.abc import Callable, Generator, Sequence
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Optional, Sequence, TypeAlias
|
||||
from typing import Any, Optional, TypeAlias
|
||||
|
||||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
|
||||
@ -264,6 +264,7 @@ def generate_ttir(
|
||||
|
||||
assert isinstance(kernel, JITFunction)
|
||||
|
||||
# pyrefly: ignore # missing-attribute
|
||||
context = triton._C.libtriton.ir.context()
|
||||
target = triton.runtime.driver.active.get_current_target()
|
||||
backend = triton.compiler.compiler.make_backend(target)
|
||||
@ -305,6 +306,7 @@ def generate_ttir(
|
||||
base_tensor = torch.empty(
|
||||
[elements_per_dim] * len(block_shape), dtype=a.dtype
|
||||
)
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
ordered_args[name] = TensorDescriptor.from_tensor(base_tensor, block_shape)
|
||||
elif isinstance(a, (FakeTensor, torch._inductor.ir.TensorBox)):
|
||||
with torch._C._DisableTorchDispatch():
|
||||
@ -368,6 +370,7 @@ def generate_ttir(
|
||||
|
||||
target = triton.runtime.driver.active.get_current_target()
|
||||
backend_ = triton.compiler.compiler.make_backend(target)
|
||||
# pyrefly: ignore # missing-attribute
|
||||
return backend_.get_attrs_descriptor(args, kernel.params)
|
||||
else:
|
||||
assert (
|
||||
@ -384,6 +387,7 @@ def generate_ttir(
|
||||
except TypeError: # Unknown arg `specialize_extra`
|
||||
# Older versions of Triton take specialize_extra as an arg to specialize_impl
|
||||
specialize_impl = functools.partial(
|
||||
# pyrefly: ignore # missing-argument
|
||||
triton.runtime.jit.create_specialize_impl(),
|
||||
specialize_extra=backend.get_arg_specialization,
|
||||
)
|
||||
@ -468,6 +472,7 @@ def generate_ttir(
|
||||
if i not in constexprs
|
||||
}
|
||||
|
||||
# pyrefly: ignore # missing-attribute
|
||||
triton._C.libtriton.ir.load_dialects(context)
|
||||
backend.load_dialects(context)
|
||||
|
||||
@ -477,22 +482,29 @@ def generate_ttir(
|
||||
# backward compatibility here.
|
||||
make_ir_sig_params = len(inspect.signature(src.make_ir).parameters)
|
||||
get_codegen_implementation_sig_params = len(
|
||||
# pyrefly: ignore # missing-attribute
|
||||
inspect.signature(backend.get_codegen_implementation).parameters
|
||||
)
|
||||
if make_ir_sig_params == 2:
|
||||
# pyrefly: ignore # missing-argument
|
||||
ttir_module = src.make_ir(options, context)
|
||||
elif make_ir_sig_params == 3:
|
||||
# pyrefly: ignore # missing-attribute
|
||||
codegen_fns = backend.get_codegen_implementation()
|
||||
# pyrefly: ignore # missing-argument
|
||||
ttir_module = src.make_ir(options, codegen_fns, context)
|
||||
elif make_ir_sig_params == 4:
|
||||
codegen_args = [options] if get_codegen_implementation_sig_params == 1 else []
|
||||
# pyrefly: ignore # missing-attribute
|
||||
codegen_fns = backend.get_codegen_implementation(*codegen_args)
|
||||
module_map = backend.get_module_map()
|
||||
ttir_module = src.make_ir(options, codegen_fns, module_map, context)
|
||||
else:
|
||||
codegen_args = [options] if get_codegen_implementation_sig_params == 1 else []
|
||||
# pyrefly: ignore # missing-attribute
|
||||
codegen_fns = backend.get_codegen_implementation(*codegen_args)
|
||||
module_map = backend.get_module_map()
|
||||
# pyrefly: ignore # bad-argument-count
|
||||
ttir_module = src.make_ir(target, options, codegen_fns, module_map, context)
|
||||
if not ttir_module.verify():
|
||||
raise RuntimeError("Verification for TTIR module has failed")
|
||||
@ -1102,6 +1114,7 @@ def triton_kernel_wrapper_mutation_dense(
|
||||
from triton.tools.tensor_descriptor import TensorDescriptor
|
||||
|
||||
block_shape = stable_meta[0]
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
kwargs[k] = TensorDescriptor.from_tensor(tensor, block_shape)
|
||||
|
||||
# move as many positional arguments from dicts to args as we
|
||||
@ -1658,6 +1671,7 @@ class TritonHOPifier:
|
||||
"Passing multiple @triton.autotune decorators is not supported. "
|
||||
"Please use a single @triton.autotune decorator instead."
|
||||
)
|
||||
# pyrefly: ignore # missing-attribute
|
||||
iter_kernel = iter_kernel.fn
|
||||
|
||||
# Process the @triton.heuristics decorator:
|
||||
@ -1868,6 +1882,7 @@ class TritonHOPifier:
|
||||
|
||||
# Both for grid's meta as well as for the kernel, we need combined
|
||||
# args and kwargs combined and normalized
|
||||
# pyrefly: ignore # missing-attribute
|
||||
combined_args_raw = {**dict(zip(variable.kernel.arg_names, args)), **kwargs}
|
||||
|
||||
# precompute the grid for the kernel
|
||||
@ -2061,6 +2076,7 @@ class TraceableTritonKernelWrapper:
|
||||
kernel_idx: Optional[int],
|
||||
grid: Optional["TritonGridType"],
|
||||
) -> None:
|
||||
# pyrefly: ignore # bad-assignment
|
||||
self.kernel = None
|
||||
self.grid = None
|
||||
tracing_triton_hopifier_singleton.init_variable(self, kernel, kernel_idx, grid)
|
||||
|
||||
@ -2,8 +2,9 @@ import json
|
||||
import logging
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch._inductor.analysis.device_info import DeviceInfo, lookup_device_info
|
||||
@ -75,7 +76,9 @@ def _slow_conv2d_adapter(
|
||||
return conv_adapter(tuple(tmp), tuple(tmp2))
|
||||
|
||||
|
||||
@register_adapter(["convolution", "_convolution", "cudnn_convolution"])
|
||||
@register_adapter(
|
||||
["convolution", "_convolution", "cudnn_convolution", "convolution_overrideable"]
|
||||
)
|
||||
def conv_adapter(
|
||||
shapes: tuple[Any, ...], concrete: tuple[Any, ...]
|
||||
) -> tuple[tuple[Any], dict[Any, Any]]:
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user