mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-16 07:24:54 +08:00
Compare commits
39 Commits
ciflow/tru
...
tianren/dy
| Author | SHA1 | Date | |
|---|---|---|---|
| 9f8b4700b5 | |||
| cdf0a9c21f | |||
| 115016f1a2 | |||
| 971e6ca434 | |||
| e8d411e7f7 | |||
| 2e5233d7bd | |||
| 514dd96376 | |||
| 9ae62fcc18 | |||
| ae71b0e163 | |||
| 5b6ff8148d | |||
| 1f7e4343e7 | |||
| b21856f5fc | |||
| 259ba0ecab | |||
| 051f1fe8e3 | |||
| ee387c43fe | |||
| 3a944661d6 | |||
| 56034074ca | |||
| 8def619bbe | |||
| 61883a5787 | |||
| d8ada1ee76 | |||
| fe841a1db4 | |||
| b65829b84f | |||
| b0e0ae97ba | |||
| f44a1ddcb2 | |||
| 184e2cbc89 | |||
| 416421c7c4 | |||
| bd99ae3315 | |||
| ce8672c24f | |||
| 402c465030 | |||
| 573a79fffa | |||
| 4945180468 | |||
| 1df723e6f5 | |||
| f9b81e23e4 | |||
| ffe6cc39c7 | |||
| db1f3f6901 | |||
| 43041f0a43 | |||
| dc00842b81 | |||
| f1a129a6d0 | |||
| fad48ffa62 |
@ -30,7 +30,6 @@ into a tarball, with the following structure:
|
||||
More specifically, `build_magma.sh` copies over the relevant files from the `package_files` directory depending on the ROCm version.
|
||||
Outputted binaries should be in the `output` folder.
|
||||
|
||||
|
||||
## Pushing
|
||||
|
||||
Packages can be uploaded to an S3 bucket using:
|
||||
|
||||
@ -96,7 +96,6 @@ function pip_build_and_install() {
|
||||
python3 -m pip wheel \
|
||||
--no-build-isolation \
|
||||
--no-deps \
|
||||
--no-use-pep517 \
|
||||
-w "${wheel_dir}" \
|
||||
"${build_target}"
|
||||
fi
|
||||
|
||||
2
.github/actionlint.yaml
vendored
2
.github/actionlint.yaml
vendored
@ -63,7 +63,7 @@ self-hosted-runner:
|
||||
- linux.rocm.gpu.gfx942.1
|
||||
- linux.rocm.gpu.gfx942.2
|
||||
- linux.rocm.gpu.gfx942.4
|
||||
- rocm-docker
|
||||
- linux.rocm.gfx942.docker-cache
|
||||
# Org wise AWS `mac2.metal` runners (2020 Mac mini hardware powered by Apple silicon M1 processors)
|
||||
- macos-m1-stable
|
||||
- macos-m1-14
|
||||
|
||||
55
.github/workflows/docker-cache-mi300.yml
vendored
55
.github/workflows/docker-cache-mi300.yml
vendored
@ -1,55 +0,0 @@
|
||||
name: docker-cache-mi300
|
||||
|
||||
on:
|
||||
# run every 6 hours
|
||||
schedule:
|
||||
- cron: 0 0,6,12,18 * * *
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name }}
|
||||
cancel-in-progress: true
|
||||
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
docker-cache:
|
||||
if: github.repository_owner == 'pytorch'
|
||||
runs-on: rocm-docker
|
||||
steps:
|
||||
- name: Checkout PyTorch
|
||||
uses: pytorch/pytorch/.github/actions/checkout-pytorch@main
|
||||
with:
|
||||
no-sudo: true
|
||||
|
||||
- name: configure aws credentials
|
||||
id: aws_creds
|
||||
uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0
|
||||
with:
|
||||
role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
|
||||
aws-region: us-east-1
|
||||
role-duration-seconds: 18000
|
||||
|
||||
- name: Login to Amazon ECR
|
||||
id: login-ecr
|
||||
continue-on-error: false
|
||||
uses: aws-actions/amazon-ecr-login@062b18b96a7aff071d4dc91bc00c4c1a7945b076 # v2.0.1
|
||||
|
||||
- name: Calculate docker image
|
||||
id: calculate-docker-image
|
||||
uses: pytorch/test-infra/.github/actions/calculate-docker-image@main
|
||||
with:
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
|
||||
push: false
|
||||
|
||||
- name: Pull docker image
|
||||
uses: pytorch/test-infra/.github/actions/pull-docker-image@main
|
||||
with:
|
||||
docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }}
|
||||
|
||||
- name: Tar and upload to S3 bucket
|
||||
run: |
|
||||
sudo docker save -o ~/docker-data/pytorch/pytorch_docker_image.tar ${{ steps.calculate-docker-image.outputs.docker-image }}
|
||||
sudo rclone copy -P --s3-upload-concurrency 64 --s3-chunk-size 200M --s3-upload-cutoff 300M ~/docker-data/pytorch/pytorch_docker_image.tar oci:pytorchbucket0002/pytorch_docker_image --progress
|
||||
108
.github/workflows/docker-cache-rocm.yml
vendored
Normal file
108
.github/workflows/docker-cache-rocm.yml
vendored
Normal file
@ -0,0 +1,108 @@
|
||||
name: docker-cache-rocm
|
||||
|
||||
on:
|
||||
workflow_run:
|
||||
workflows: [docker-builds]
|
||||
# TODO: Uncomment before merging
|
||||
#branches: [main, release]
|
||||
types:
|
||||
- completed
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name }}
|
||||
cancel-in-progress: true
|
||||
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
actions: read
|
||||
|
||||
jobs:
|
||||
download-docker-builds-artifacts:
|
||||
if: github.repository_owner == 'pytorch'
|
||||
name: download-docker-builds-artifacts
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
pytorch-linux-jammy-rocm-n-py3: ${{ steps.process-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3 }}
|
||||
pytorch-linux-noble-rocm-n-py3: ${{ steps.process-artifacts.outputs.pytorch-linux-noble-rocm-n-py3 }}
|
||||
pytorch-linux-jammy-rocm-n-py3-benchmarks: ${{ steps.process-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3-benchmarks }}
|
||||
steps:
|
||||
- name: Download artifacts
|
||||
uses: actions/download-artifact@v4.1.7
|
||||
with:
|
||||
run-id: ${{ github.event.workflow_run.id }}
|
||||
path: ./docker-builds-artifacts
|
||||
merge-multiple: true
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Process artifacts
|
||||
id: process-artifacts
|
||||
run: |
|
||||
ls -R ./docker-builds-artifacts
|
||||
cat ./docker-builds-artifacts/*txt >> "${GITHUB_OUTPUT}"
|
||||
cat "${GITHUB_OUTPUT}"
|
||||
|
||||
docker-cache:
|
||||
if: github.repository_owner == 'pytorch'
|
||||
needs: download-docker-builds-artifacts
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
runner: [linux.rocm.gfx942.docker-cache]
|
||||
docker-image: [
|
||||
"${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3 }}",
|
||||
"${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-noble-rocm-n-py3 }}",
|
||||
"${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3-benchmarks }}"
|
||||
]
|
||||
runs-on: "${{ matrix.runner }}"
|
||||
steps:
|
||||
- name: debug
|
||||
run: |
|
||||
JSON_STRINGIFIED="${{ toJSON(needs.download-docker-builds-artifacts.outputs) }}"
|
||||
echo "Outputs of download-docker-builds-artifacts job: ${JSON_STRINGIFIED}"
|
||||
|
||||
- name: configure aws credentials
|
||||
id: aws_creds
|
||||
uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0
|
||||
with:
|
||||
role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
|
||||
aws-region: us-east-1
|
||||
role-duration-seconds: 18000
|
||||
|
||||
- name: Login to Amazon ECR
|
||||
id: login-ecr
|
||||
continue-on-error: false
|
||||
uses: aws-actions/amazon-ecr-login@062b18b96a7aff071d4dc91bc00c4c1a7945b076 # v2.0.1
|
||||
|
||||
- name: Generate ghrc.io tag
|
||||
id: ghcr-io-tag
|
||||
run: |
|
||||
ecr_image="${{ matrix.docker-image }}"
|
||||
ghcr_image="ghcr.io/pytorch/ci-image:${ecr_image##*:}"
|
||||
echo "ghcr_image=${ghcr_image}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Pull docker image
|
||||
uses: pytorch/test-infra/.github/actions/pull-docker-image@main
|
||||
with:
|
||||
docker-image: ${{ steps.ghcr-io-tag.outputs.ghcr_image }}
|
||||
|
||||
- name: Save as tarball
|
||||
run: |
|
||||
docker_image_tag=${{ matrix.docker-image }}
|
||||
docker_image_tag="${docker_image_tag#*:}" # Remove everything before and including first ":"
|
||||
docker_image_tag="${docker_image_tag%-*}" # Remove everything after and including last "-"
|
||||
ref_name=${{ github.event.workflow_run.head_branch }}
|
||||
if [[ $ref_name =~ "release/" ]]; then
|
||||
ref_suffix="release"
|
||||
elif [[ $ref_name == "main" ]]; then
|
||||
ref_suffix="main"
|
||||
else
|
||||
# TODO: Remove below
|
||||
ref_suffix="main"
|
||||
# echo "Unexpected branch in ref_name: ${ref_name}" && exit 1
|
||||
fi
|
||||
docker tag ${{ steps.ghcr-io-tag.outputs.ghcr_image }} ${{ matrix.docker-image }}
|
||||
# mv is atomic operation, so we use intermediate tar.tmp file to prevent read-write contention
|
||||
docker save -o ~/pytorch-data/docker/${docker_image_tag}.tar.tmp ${{ matrix.docker-image }}
|
||||
mv ~/pytorch-data/docker/${docker_image_tag}.tar.tmp ~/pytorch-data/docker/${docker_image_tag}_${ref_suffix}.tar
|
||||
2
.github/workflows/inductor-rocm-mi200.yml
vendored
2
.github/workflows/inductor-rocm-mi200.yml
vendored
@ -1,4 +1,4 @@
|
||||
name: inductor-rocm
|
||||
name: inductor-rocm-mi200
|
||||
|
||||
on:
|
||||
schedule:
|
||||
|
||||
2
.github/workflows/rocm-mi200.yml
vendored
2
.github/workflows/rocm-mi200.yml
vendored
@ -1,4 +1,4 @@
|
||||
name: rocm
|
||||
name: rocm-mi200
|
||||
|
||||
on:
|
||||
push:
|
||||
|
||||
@ -18,6 +18,8 @@ Please report security issues using https://github.com/pytorch/pytorch/security/
|
||||
|
||||
All reports submitted through the security advisories mechanism would **either be made public or dismissed by the team within 90 days of the submission**. If advisory has been closed on the grounds that it is not a security issue, please do not hesitate to create an [new issue](https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml) as it is still likely a valid issue within the framework.
|
||||
|
||||
**Note on crashes and out of bounds access**: PyTorch is a computational framework that performs operations on behalf of the caller. Like many low-level libraries, PyTorch generally does not validate all inputs to every function—the responsibility for providing valid arguments lies with the calling code. While crashes and out of bounds memory access should be reported as bugs, they are generally not considered security vulnerabilities in PyTorch's threat model.
|
||||
|
||||
Please refer to the following page for our responsible disclosure policy, reward guidelines, and those things that should not be reported:
|
||||
|
||||
https://www.facebook.com/whitehat
|
||||
|
||||
@ -142,6 +142,7 @@ Tensor _pack_padded_sequence_backward_symint(const Tensor& grad, c10::SymIntArra
|
||||
std::tuple<Tensor, Tensor> _pad_packed_sequence(const Tensor& data, const Tensor& _batch_sizes, bool batch_first, const Scalar& padding_value, int64_t total_length) {
|
||||
auto batch_sizes_t = _batch_sizes.contiguous();
|
||||
checkLongTensor(batch_sizes_t);
|
||||
TORCH_CHECK(batch_sizes_t.numel() > 0, "batch_sizes can not be empty");
|
||||
|
||||
int64_t * batch_sizes = batch_sizes_t.data_ptr<int64_t>();
|
||||
int64_t max_batch_size = batch_sizes[0];
|
||||
|
||||
@ -669,9 +669,12 @@ std::optional<c10::ScalarType> out_dtype) {
|
||||
// _scaled_mm_allowed_device is used here within _grouped_mm_cuda which seems incorrect since scale is not used.
|
||||
// the _grouped_mm_fallback should be safe for any ROCm GPU since it's just calling typical mm/bmm
|
||||
bool use_fast_path = false;
|
||||
// On non CK system(w/ ROCm), make sure use_fast_path is false
|
||||
#if defined(USE_ROCM_CK_GEMM)
|
||||
if (at::detail::getCUDAHooks().isGPUArch({"gfx942", "gfx950"})) {
|
||||
use_fast_path = true;
|
||||
}
|
||||
#endif //USE_ROCM_CK_GEMM
|
||||
#endif
|
||||
const auto out_dtype_ = _resolve_grouped_mm_out_dtype(mat_a, mat_b, out_dtype);
|
||||
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
|
||||
@ -680,7 +683,11 @@ std::optional<c10::ScalarType> out_dtype) {
|
||||
#ifndef USE_ROCM
|
||||
at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out);
|
||||
#else
|
||||
#if defined(USE_ROCM_CK_GEMM)
|
||||
at::hip::detail::group_gemm_ck(mat_a, mat_b, offs, bias, out);
|
||||
#else
|
||||
TORCH_WARN("ROCm: Group Gemm through CK not selected.");
|
||||
#endif //USE_ROCM_CK_GEMM
|
||||
#endif
|
||||
} else {
|
||||
_grouped_mm_fallback(mat_a, mat_b, offs, bias, out_dtype, out);
|
||||
|
||||
@ -47,6 +47,7 @@
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <thrust/copy.h>
|
||||
#include <thrust/device_ptr.h>
|
||||
#include <thrust/distance.h>
|
||||
#include <thrust/for_each.h>
|
||||
#include <thrust/functional.h>
|
||||
#include <thrust/gather.h>
|
||||
|
||||
@ -50,7 +50,7 @@ nfnet_l0,pass,7
|
||||
|
||||
|
||||
|
||||
repvgg_a2,fail_accuracy,7
|
||||
repvgg_a2,pass,7
|
||||
|
||||
|
||||
|
||||
|
||||
|
@ -14,6 +14,10 @@ Utils
|
||||
|
||||
sdpa_kernel
|
||||
SDPBackend
|
||||
register_flash_attention_impl
|
||||
activate_flash_attention_impl
|
||||
list_flash_attention_impls
|
||||
current_flash_attention_impl
|
||||
|
||||
Submodules
|
||||
----------
|
||||
|
||||
@ -10,7 +10,7 @@ tp2_dir="$top_dir/third_party"
|
||||
pip install ninja
|
||||
|
||||
# Install onnx
|
||||
pip install --no-use-pep517 -e "$tp2_dir/onnx"
|
||||
pip install -e "$tp2_dir/onnx"
|
||||
|
||||
# Install caffe2 and pytorch
|
||||
pip install -r "$top_dir/caffe2/requirements.txt"
|
||||
|
||||
@ -140,6 +140,11 @@ static void initDeviceStreamState(DeviceIndex device_index) {
|
||||
static void initOpenRegStreamsOnce() {
|
||||
c10::call_once(init_flag, initGlobalStreamState);
|
||||
|
||||
for (const auto i : c10::irange(num_devices)) {
|
||||
c10::call_once(
|
||||
device_flags[i], initDeviceStreamState, static_cast<DeviceIndex>(i));
|
||||
}
|
||||
|
||||
if (current_streams) {
|
||||
return;
|
||||
}
|
||||
@ -202,8 +207,6 @@ OpenRegStream getStreamFromPool(const int priority, DeviceIndex device_index) {
|
||||
if (device_index == -1) {
|
||||
device_index = current_device();
|
||||
}
|
||||
c10::call_once(
|
||||
device_flags[device_index], initDeviceStreamState, device_index);
|
||||
auto pri_idx =
|
||||
std::clamp(priority, 0, max_compile_time_stream_priorities - 1);
|
||||
const auto idx = get_idx(priority_counters[device_index][pri_idx]);
|
||||
|
||||
@ -180,6 +180,47 @@ class TestTrackerFullyShard1DTrainingCore(FSDPTest):
|
||||
del model
|
||||
del optim
|
||||
|
||||
def _test_tracker_multihandler_hook(self):
|
||||
"""Should run without KeyError."""
|
||||
|
||||
class TestModule(nn.Module):
|
||||
def __init__(self, dim: int):
|
||||
super().__init__()
|
||||
self.norm1 = nn.RMSNorm(dim)
|
||||
self.output1 = nn.Linear(dim, dim)
|
||||
self.norm2 = nn.RMSNorm(dim)
|
||||
self.output2 = nn.Linear(dim, dim)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.norm1(x)
|
||||
x = self.output1(x)
|
||||
x = self.norm2(x)
|
||||
x = self.output2(x)
|
||||
return x
|
||||
|
||||
gc.collect()
|
||||
torch.manual_seed(42)
|
||||
dev = torch.device(torch.accelerator.current_device_index())
|
||||
|
||||
with torch.device(dev):
|
||||
model = TestModule(128)
|
||||
|
||||
mesh = init_device_mesh(dev.type, (self.world_size,))
|
||||
fully_shard([model.norm1, model.output1], mesh=mesh)
|
||||
fully_shard([model.norm2, model.output2], mesh=mesh)
|
||||
fully_shard(model, mesh=mesh)
|
||||
|
||||
fmt = FSDPMemTracker(model)
|
||||
|
||||
with fmt:
|
||||
inp = torch.randn(16, 128, device=dev)
|
||||
y = model(inp)
|
||||
loss = y.sum()
|
||||
loss.backward()
|
||||
|
||||
del inp
|
||||
del model
|
||||
|
||||
|
||||
class TestTrackerFullyShard1DTrainingCompose(FSDPTest):
|
||||
@property
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
# Owner(s): ["oncall: distributed"]
|
||||
|
||||
import contextlib
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@ -371,6 +372,7 @@ class DTensorExportTest(TestCase):
|
||||
|
||||
# aot_export_joint_with_descriptors on strict-exported exported_program.module()
|
||||
# is producing a joint graph with backward region missing
|
||||
@unittest.expectedFailure
|
||||
def test_strict_export_parallelize_module_with_dtensor_input(self):
|
||||
self._run_test(strict_export_and_aot_export_joint_with_descriptors)
|
||||
|
||||
|
||||
@ -15,7 +15,7 @@ import torch._functorch.config
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
from functorch.compile import default_partition, min_cut_rematerialization_partition
|
||||
from functorch.compile import min_cut_rematerialization_partition
|
||||
from torch._dynamo.backends.common import aot_autograd
|
||||
from torch._dynamo.testing import (
|
||||
AotEagerAndRecordGraphs,
|
||||
@ -24,7 +24,7 @@ from torch._dynamo.testing import (
|
||||
)
|
||||
from torch._higher_order_ops.wrap import tag_activation_checkpoint
|
||||
from torch.testing._internal.common_device_type import instantiate_device_type_tests
|
||||
from torch.testing._internal.common_utils import IS_WINDOWS, parametrize, skipIfHpu
|
||||
from torch.testing._internal.common_utils import IS_WINDOWS, skipIfHpu
|
||||
from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON
|
||||
from torch.testing._internal.triton_utils import requires_cuda_and_triton
|
||||
from torch.testing._internal.two_tensor import TwoTensor
|
||||
@ -281,14 +281,7 @@ class ActivationCheckpointingViaTagsTests(
|
||||
|
||||
run(export_compiler)
|
||||
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_tags_function(self, device, partition_fn):
|
||||
def test_tags_function(self, device):
|
||||
def gn(x, y):
|
||||
return torch.sigmoid(torch.matmul(x, y))
|
||||
|
||||
@ -304,22 +297,11 @@ class ActivationCheckpointingViaTagsTests(
|
||||
bw_compiler = functools.partial(
|
||||
count_ops, freq=3, op=torch.ops.aten.mm.default
|
||||
) # mm recomputed in the bwd
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=partition_fn,
|
||||
)
|
||||
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
|
||||
self._validate(fn, backend, x, y)
|
||||
|
||||
@requires_cuda_and_triton
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_tags_function_via_global_checkpoint(self, device, partition_fn):
|
||||
def test_tags_function_via_global_checkpoint(self, device):
|
||||
def gn(x, y):
|
||||
return torch.sigmoid(torch.matmul(x, y))
|
||||
|
||||
@ -334,28 +316,17 @@ class ActivationCheckpointingViaTagsTests(
|
||||
bw_compiler = functools.partial(
|
||||
count_ops, freq=3, op=torch.ops.aten.mm.default
|
||||
) # mm recomputed in the bwd
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=partition_fn,
|
||||
)
|
||||
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
|
||||
self._validate(fn, backend, x, y)
|
||||
|
||||
@requires_cuda_and_triton
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_tags_function_with_kwargs(self, device, partition_fn):
|
||||
def test_tags_function_with_kwargs(self, device):
|
||||
def gn(x, y):
|
||||
return torch.sigmoid(torch.matmul(x, y))
|
||||
|
||||
def fn(x, y):
|
||||
return torch.utils.checkpoint.checkpoint(
|
||||
gn, torch.sin(x), y, use_reentrant=False
|
||||
gn, torch.sin(x), y, use_reentrant=True, preserve_rng_state=False
|
||||
)
|
||||
|
||||
x = torch.randn(4, 4, device=device, requires_grad=True)
|
||||
@ -365,22 +336,11 @@ class ActivationCheckpointingViaTagsTests(
|
||||
bw_compiler = functools.partial(
|
||||
count_ops, freq=3, op=torch.ops.aten.mm.default
|
||||
) # mm recomputed in the bwd
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=partition_fn,
|
||||
)
|
||||
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
|
||||
self._validate(fn, backend, x, y)
|
||||
|
||||
@requires_cuda_and_triton
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_tags_sequential_layers(self, device, partition_fn):
|
||||
def test_tags_sequential_layers(self, device):
|
||||
def gn(x):
|
||||
x = x.cos()
|
||||
for _ in range(3):
|
||||
@ -401,22 +361,11 @@ class ActivationCheckpointingViaTagsTests(
|
||||
freqs=[2, 18],
|
||||
ops=[torch.ops.aten.cos.default, torch.ops.aten.mm.default],
|
||||
) # mm recomputed in the bwd
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=partition_fn,
|
||||
)
|
||||
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
|
||||
self._validate(fn, backend, x)
|
||||
|
||||
@requires_cuda_and_triton
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_tags_multiple_checkpoints(self, device, partition_fn):
|
||||
def test_tags_multiple_checkpoints(self, device):
|
||||
def gn(x, y):
|
||||
return torch.sigmoid(torch.matmul(x, y))
|
||||
|
||||
@ -434,22 +383,11 @@ class ActivationCheckpointingViaTagsTests(
|
||||
bw_compiler = functools.partial(
|
||||
count_ops, freq=6, op=torch.ops.aten.mm.default
|
||||
) # mm recomputed in the bwd
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=partition_fn,
|
||||
)
|
||||
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
|
||||
self._validate(fn, backend, x, y)
|
||||
|
||||
@requires_cuda_and_triton
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_tags_module(self, device, partition_fn):
|
||||
def test_tags_module(self, device):
|
||||
class MockModule(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
@ -473,22 +411,11 @@ class ActivationCheckpointingViaTagsTests(
|
||||
bw_compiler = functools.partial(
|
||||
count_ops, freq=1, op=torch.ops.aten.sigmoid.default
|
||||
)
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=partition_fn,
|
||||
)
|
||||
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
|
||||
self._validate(fn, backend, x)
|
||||
|
||||
@requires_cuda_and_triton
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_tags_decomps(self, device, partition_fn):
|
||||
def test_tags_decomps(self, device):
|
||||
# Ensures that tags are passed on through decompositions as well
|
||||
class MockModule(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
@ -516,7 +443,6 @@ class ActivationCheckpointingViaTagsTests(
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=partition_fn,
|
||||
decompositions=lambda: import_module(
|
||||
"torch._inductor.compile_fx"
|
||||
).select_decomp_table(),
|
||||
@ -776,14 +702,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
|
||||
@requires_cuda_and_triton
|
||||
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_compile_selective_checkpoint_must_recompute(self, device, partition_fn):
|
||||
def test_compile_selective_checkpoint_must_recompute(self, device):
|
||||
def context_fn_must_recompute_mm():
|
||||
must_recompute_list = [
|
||||
torch.ops.aten.mm.default,
|
||||
@ -804,9 +723,9 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
),
|
||||
)
|
||||
|
||||
def _test(context_fn, bw_compiler, partition_fn):
|
||||
def _test(context_fn, bw_compiler):
|
||||
def gn(x):
|
||||
return torch.cos(torch.sin(torch.matmul(x, x) @ x))
|
||||
return torch.sigmoid(torch.matmul(x, x))
|
||||
|
||||
def fn(x):
|
||||
return torch.utils.checkpoint.checkpoint(
|
||||
@ -820,14 +739,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
|
||||
fw_compiler = functools.partial(
|
||||
count_ops,
|
||||
freq=2,
|
||||
freq=1,
|
||||
op=torch.ops.aten.mm.default,
|
||||
)
|
||||
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=partition_fn,
|
||||
partition_fn=min_cut_rematerialization_partition,
|
||||
)
|
||||
self._validate(fn, backend, x)
|
||||
|
||||
@ -835,19 +754,17 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
context_fn=context_fn_must_recompute_mm,
|
||||
bw_compiler=functools.partial(
|
||||
count_ops,
|
||||
freq=6, # 1 matmul recompute and 2 bwd mm ops per fwd matmul, so 2 + 2 * 2 = 6)
|
||||
freq=3, # 1 matmul recompute and 2 bwd mm ops per fwd matmul, so 1 + 2 * 1 = 3)
|
||||
op=torch.ops.aten.mm.default,
|
||||
),
|
||||
partition_fn=partition_fn,
|
||||
)
|
||||
_test(
|
||||
context_fn=context_fn_no_recompute_mm,
|
||||
bw_compiler=functools.partial(
|
||||
count_ops,
|
||||
freq=4, # 2 bwd mm ops per fwd matmul
|
||||
freq=2, # 2 bwd mm ops per fwd matmul
|
||||
op=torch.ops.aten.mm.default,
|
||||
),
|
||||
partition_fn=partition_fn,
|
||||
)
|
||||
|
||||
def test_sac_with_partial_context_fn(self):
|
||||
@ -884,16 +801,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
|
||||
@requires_cuda_and_triton
|
||||
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_compile_selective_checkpoint_must_not_recompute_gemm(
|
||||
self, device, partition_fn
|
||||
):
|
||||
def test_compile_selective_checkpoint_must_not_recompute_gemm(self, device):
|
||||
def selective_checkpointing_context_fn():
|
||||
no_recompute_list = [
|
||||
torch.ops.aten.mm.default,
|
||||
@ -933,22 +841,15 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=partition_fn,
|
||||
partition_fn=min_cut_rematerialization_partition,
|
||||
)
|
||||
self._validate(fn, backend, x, y)
|
||||
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
|
||||
|
||||
@requires_cuda_and_triton
|
||||
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_compile_selective_checkpoint_must_not_recompute_gemm_no_functionalization(
|
||||
self, device, partition_fn
|
||||
self, device
|
||||
):
|
||||
def selective_checkpointing_context_fn():
|
||||
no_recompute_list = [
|
||||
@ -988,7 +889,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=partition_fn,
|
||||
partition_fn=min_cut_rematerialization_partition,
|
||||
disable_functionalization=True,
|
||||
)
|
||||
self._validate(fn, backend, x, y)
|
||||
@ -996,14 +897,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
|
||||
@requires_cuda_and_triton
|
||||
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_compile_selective_checkpoint_triton_kernel(self, device, partition_fn):
|
||||
def test_compile_selective_checkpoint_triton_kernel(self, device):
|
||||
# Copy of the above test, but make sure that having a triton kernel in the
|
||||
# region does not error.
|
||||
def add_one(x):
|
||||
@ -1063,21 +957,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=partition_fn,
|
||||
partition_fn=min_cut_rematerialization_partition,
|
||||
)
|
||||
self._validate(fn, backend, x, y)
|
||||
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
|
||||
|
||||
@requires_cuda_and_triton
|
||||
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_compile_selective_checkpoint_tensor_subclass(self, device, partition_fn):
|
||||
def test_compile_selective_checkpoint_tensor_subclass(self, device):
|
||||
def selective_checkpointing_context_fn():
|
||||
no_recompute_list = [
|
||||
torch.ops.aten.mm.default,
|
||||
@ -1120,21 +1007,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=partition_fn,
|
||||
partition_fn=min_cut_rematerialization_partition,
|
||||
)
|
||||
self._validate(fn, backend, x, y)
|
||||
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
|
||||
|
||||
@requires_cuda_and_triton
|
||||
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_compile_selective_checkpoint_custom_rule(self, device, partition_fn):
|
||||
def test_compile_selective_checkpoint_custom_rule(self, device):
|
||||
def _get_custom_policy(meta):
|
||||
no_recompute_list = [
|
||||
torch.ops.aten.mm.default,
|
||||
@ -1192,21 +1072,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=partition_fn,
|
||||
partition_fn=min_cut_rematerialization_partition,
|
||||
)
|
||||
self._validate(fn, backend, x, y)
|
||||
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
|
||||
|
||||
@requires_cuda_and_triton
|
||||
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_compile_selective_checkpoint_partial_ctx_fn(self, device, partition_fn):
|
||||
def test_compile_selective_checkpoint_partial_ctx_fn(self, device):
|
||||
def selective_checkpointing_context_fn(no_recompute_list):
|
||||
return create_selective_checkpoint_contexts(
|
||||
_get_custom_policy(no_recompute_list=no_recompute_list)
|
||||
@ -1245,21 +1118,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=partition_fn,
|
||||
partition_fn=min_cut_rematerialization_partition,
|
||||
)
|
||||
self._validate(fn, backend, x, y)
|
||||
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
|
||||
|
||||
@requires_cuda_and_triton
|
||||
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_compile_selective_checkpoint_outplace_op(self, device, partition_fn):
|
||||
def test_compile_selective_checkpoint_outplace_op(self, device):
|
||||
def selective_checkpointing_context_fn():
|
||||
no_recompute_list = [
|
||||
torch.ops.aten.mm.default,
|
||||
@ -1297,21 +1163,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=partition_fn,
|
||||
partition_fn=min_cut_rematerialization_partition,
|
||||
)
|
||||
self._validate(fn, backend, x, y)
|
||||
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
|
||||
|
||||
@requires_cuda_and_triton
|
||||
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_compile_selective_checkpoint_list_ops(self, device, partition_fn):
|
||||
def test_compile_selective_checkpoint_list_ops(self, device):
|
||||
def selective_checkpointing_context_fn():
|
||||
# recompute everything
|
||||
no_recompute_list = []
|
||||
@ -1347,7 +1206,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=partition_fn,
|
||||
partition_fn=min_cut_rematerialization_partition,
|
||||
)
|
||||
self._validate(fn, backend, x, y)
|
||||
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
|
||||
@ -1358,14 +1217,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
"requires TorchDispatchMode + torch.compile work to complete"
|
||||
)
|
||||
@requires_cuda_and_triton
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_compile_selective_checkpoint_inplace_op(self, device, partition_fn):
|
||||
def test_compile_selective_checkpoint_inplace_op(self, device):
|
||||
def selective_checkpointing_context_fn():
|
||||
no_recompute_list = [
|
||||
torch.ops.aten.mm.default,
|
||||
@ -1405,7 +1257,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=partition_fn,
|
||||
partition_fn=min_cut_rematerialization_partition,
|
||||
)
|
||||
self._validate(fn, backend, x, y)
|
||||
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
|
||||
@ -1413,14 +1265,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
@requires_cuda_and_triton
|
||||
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
|
||||
@torch._inductor.config.patch(fallback_random=True)
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_compile_selective_checkpoint_random_op(self, device, partition_fn):
|
||||
def test_compile_selective_checkpoint_random_op(self, device):
|
||||
for preserve_rng_state in [True, False]:
|
||||
|
||||
def selective_checkpointing_context_fn():
|
||||
@ -1467,7 +1312,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=partition_fn,
|
||||
partition_fn=min_cut_rematerialization_partition,
|
||||
)
|
||||
|
||||
# NOTE: when `preserve_rng_state` is False, gradient will mismatch between torch.compile and eager,
|
||||
@ -1479,14 +1324,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
|
||||
@requires_cuda_and_triton
|
||||
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_compile_selective_checkpoint_invalid_context(self, partition_fn):
|
||||
def test_compile_selective_checkpoint_invalid_context(self):
|
||||
def gn(x, y):
|
||||
return torch.sigmoid(torch.matmul(x, y)) * y
|
||||
|
||||
@ -1515,7 +1353,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=partition_fn,
|
||||
partition_fn=min_cut_rematerialization_partition,
|
||||
)
|
||||
with self.assertRaisesRegex(
|
||||
Exception, "must generate a tuple of two `TorchDispatchMode`s"
|
||||
@ -1524,14 +1362,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
|
||||
@requires_cuda_and_triton
|
||||
@torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_compile_selective_checkpoint_parametrization(self, partition_fn):
|
||||
def test_compile_selective_checkpoint_parametrization(self):
|
||||
def sac_policy():
|
||||
def _recomp_policy():
|
||||
def _custom_policy(ctx, func, *args, **kwargs):
|
||||
@ -1594,9 +1425,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
bw_compiler = functools.partial(
|
||||
count_ops,
|
||||
freqs=[
|
||||
# 1 from mul recompute, 1 from mul backward
|
||||
# w/o CSE, we have one extra mul
|
||||
3 if partition_fn is default_partition else 2,
|
||||
2, # 1 from mul recompute, 1 from mul backward
|
||||
1,
|
||||
],
|
||||
ops=[torch.ops.aten.mul.Tensor, torch.ops.aten.sigmoid.default],
|
||||
@ -1605,7 +1434,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=partition_fn,
|
||||
partition_fn=min_cut_rematerialization_partition,
|
||||
)
|
||||
|
||||
model = MLPModule()
|
||||
|
||||
@ -2363,6 +2363,34 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
|
||||
self.assertTrue(same(output, expected))
|
||||
assert cnt.frame_count == 1
|
||||
|
||||
@unittest.skipIf(sys.version_info < (3, 13), "math.fma introduced in python 3.13")
|
||||
def test_math_fma(self):
|
||||
def fma_func(a, b, c):
|
||||
return math.fma(a, b, c)
|
||||
|
||||
# Test with scalar constants (constant folding path)
|
||||
cnt = torch._dynamo.testing.CompileCounter()
|
||||
cfma_scalars = torch._dynamo.optimize_assert(cnt)(fma_func)
|
||||
|
||||
assert cnt.frame_count == 0
|
||||
expected = fma_func(2.0, 3.0, 4.0)
|
||||
output = cfma_scalars(2.0, 3.0, 4.0)
|
||||
self.assertEqual(output, expected)
|
||||
assert cnt.frame_count == 0
|
||||
|
||||
# Test with tensors (Inductor path)
|
||||
cnt2 = torch._dynamo.testing.CompileCounter()
|
||||
cfma_tensors = torch._dynamo.optimize_assert(cnt2)(fma_func)
|
||||
|
||||
assert cnt2.frame_count == 0
|
||||
x = torch.tensor(2.0)
|
||||
y = torch.tensor(3.0)
|
||||
z = torch.tensor(4.0)
|
||||
expected_tensors = x * y + z
|
||||
output_tensors = cfma_tensors(x, y, z)
|
||||
torch.testing.assert_close(output_tensors, expected_tensors)
|
||||
assert cnt2.frame_count == 1
|
||||
|
||||
@make_test
|
||||
def test_numpy_meshgrid(x, y):
|
||||
r1, r2 = np.meshgrid(x.numpy(), y.numpy())
|
||||
|
||||
@ -335,6 +335,59 @@ class <lambda>(torch.nn.Module):
|
||||
""",
|
||||
)
|
||||
|
||||
@requires_cuda
|
||||
@requires_multigpu()
|
||||
def test_new_event_api(self) -> None:
|
||||
from torch._dynamo.graph_bytecode_inputs import get_external_object_by_index
|
||||
from torch._dynamo.variables.streams import new_event
|
||||
|
||||
def event_generation_backend(gm, *args, **kwargs): # type: ignore[no-untyped-def]
|
||||
e0_ind = new_event()
|
||||
with torch.Stream(device="cuda:1"):
|
||||
get_external_object_by_index(e0_ind).record()
|
||||
e1_ind = new_event()
|
||||
self.assertNotEqual(e0_ind, e1_ind)
|
||||
self.assertNotEqual(
|
||||
get_external_object_by_index(e0_ind),
|
||||
get_external_object_by_index(e1_ind),
|
||||
)
|
||||
with gm.graph.inserting_after(next(iter(gm.graph.nodes))):
|
||||
gm.graph.call_function(
|
||||
get_external_object_by_index, args=(1,), kwargs={}
|
||||
)
|
||||
return gm
|
||||
|
||||
@torch.compile(backend=event_generation_backend)
|
||||
def fn(x):
|
||||
return x + 1
|
||||
|
||||
fn(torch.ones(2, 2, device="cuda:0"))
|
||||
|
||||
@requires_cuda
|
||||
def test_new_stream_api(self) -> None:
|
||||
from torch._dynamo.graph_bytecode_inputs import get_external_object_by_index
|
||||
from torch._dynamo.variables.streams import new_stream
|
||||
|
||||
def stream_generation_backend(gm, *args, **kwargs): # type: ignore[no-untyped-def]
|
||||
s0_ind = new_stream()
|
||||
s1_ind = new_stream()
|
||||
self.assertNotEqual(s0_ind, s1_ind)
|
||||
self.assertNotEqual(
|
||||
get_external_object_by_index(s0_ind),
|
||||
get_external_object_by_index(s1_ind),
|
||||
)
|
||||
with gm.graph.inserting_after(next(iter(gm.graph.nodes))):
|
||||
gm.graph.call_function(
|
||||
get_external_object_by_index, args=(1,), kwargs={}
|
||||
)
|
||||
return gm
|
||||
|
||||
@torch.compile(backend=stream_generation_backend)
|
||||
def fn(x):
|
||||
return x + 1
|
||||
|
||||
fn(torch.ones(2, 2, device="cuda:0"))
|
||||
|
||||
@requires_cuda
|
||||
def test_stream_with_mutation(self):
|
||||
def fn(x, y):
|
||||
@ -523,6 +576,23 @@ class <lambda>(torch.nn.Module):
|
||||
torch.accelerator.set_stream(original_stream)
|
||||
reset_user_object_tracking()
|
||||
|
||||
@requires_cuda
|
||||
def test_run_opcheck_wait_record_stream(self):
|
||||
from torch._dynamo.variables.streams import wait_stream
|
||||
from torch.library import opcheck
|
||||
|
||||
s0 = torch.Stream()
|
||||
s1 = torch.Stream()
|
||||
s2 = torch.Stream()
|
||||
store_user_object_weakrefs(s0, s1, s2)
|
||||
|
||||
sample_inputs = [
|
||||
(0, 1),
|
||||
(2, 0),
|
||||
]
|
||||
for args in sample_inputs:
|
||||
opcheck(wait_stream, args)
|
||||
|
||||
@requires_cuda
|
||||
def test_inductor_lowering(self):
|
||||
with patch("torch._inductor.config.implicit_fallbacks", False):
|
||||
|
||||
@ -331,7 +331,12 @@ class TestDynamismExpression(TestCase):
|
||||
return torch.ops.aten.slice.Tensor(*args)
|
||||
|
||||
inp = (torch.rand((10, 3, 224, 224)), 0, 0, 9223372036854775807)
|
||||
dynamic_shapes = (({0: Dim("dim")}, None, None, None),)
|
||||
dynamic_shapes = (
|
||||
{0: Dim("dim")},
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
torch.export.export(
|
||||
Slice(),
|
||||
inp,
|
||||
@ -5533,21 +5538,11 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
||||
|
||||
w = Wrapped()
|
||||
|
||||
if is_retracebility_test(self._testMethodName):
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.UserError,
|
||||
"Detected mismatch between the structure of `inputs` and `dynamic_shapes`"
|
||||
": `inputs` has 2 elements, but `dynamic_shapes` has 1 elements",
|
||||
):
|
||||
export(w, args, dynamic_shapes={"args": ({0: batch}, {0: batch})})
|
||||
else:
|
||||
compiled = export(
|
||||
w, args, dynamic_shapes={"args": ({0: batch}, {0: batch})}
|
||||
)
|
||||
expected = w(*args)
|
||||
mod = compiled.module()
|
||||
got = mod(*args)
|
||||
self.assertTrue(torch.allclose(expected, got))
|
||||
compiled = export(w, args, dynamic_shapes=({0: batch}, {0: batch}))
|
||||
expected = w(*args)
|
||||
mod = compiled.module()
|
||||
got = mod(*args)
|
||||
self.assertTrue(torch.allclose(expected, got))
|
||||
|
||||
def test_dynamic_shapes_builder_basic(self):
|
||||
class M(torch.nn.Module):
|
||||
@ -17504,6 +17499,105 @@ def forward(self, x):
|
||||
exported_param_names = [name for name, _ in gm.named_parameters()]
|
||||
self.assertEqual(original_param_names, exported_param_names)
|
||||
|
||||
def test_export_compiled_model_with_nested_dynamic_shapes(self):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, data_batch):
|
||||
return data_batch["a1"] + data_batch["a2"]
|
||||
|
||||
m = M()
|
||||
compiled_m = torch.compile(m)
|
||||
example_args = (
|
||||
{
|
||||
"a1": torch.ones(3, 3),
|
||||
"a2": torch.ones(3, 3),
|
||||
},
|
||||
)
|
||||
dynamic_shapes = (
|
||||
{
|
||||
"a1": {0: Dim.DYNAMIC},
|
||||
"a2": {0: Dim.DYNAMIC},
|
||||
},
|
||||
)
|
||||
ep = export(
|
||||
compiled_m, example_args, dynamic_shapes=dynamic_shapes, strict=True
|
||||
)
|
||||
gm = ep.module()
|
||||
self.assertEqual(gm(*example_args), compiled_m(*example_args))
|
||||
|
||||
def test_export_model_with_nested_dynamic_shapes(self):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, data_batch):
|
||||
return data_batch["a1"] + data_batch["a2"]
|
||||
|
||||
m = M()
|
||||
example_args = (
|
||||
{
|
||||
"a1": torch.ones(3, 3),
|
||||
"a2": torch.ones(3, 3),
|
||||
},
|
||||
)
|
||||
B = torch.export.Dim("batch", min=1, max=65536)
|
||||
dynamic_shapes = (
|
||||
{
|
||||
"a1": {0: B},
|
||||
"a2": {0: B},
|
||||
},
|
||||
)
|
||||
ep = export(m, example_args, dynamic_shapes=dynamic_shapes, strict=True)
|
||||
gm = ep.module()
|
||||
self.assertEqual(gm(*example_args), m(*example_args))
|
||||
|
||||
def test_export_compiled_model_with_kwargs_dynamic_shapes(self):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, a1, a2):
|
||||
return a1 + a2
|
||||
|
||||
m = M()
|
||||
compiled_m = torch.compile(m)
|
||||
example_args = ()
|
||||
example_kwargs = {
|
||||
"a1": torch.ones(3, 3),
|
||||
"a2": torch.ones(3, 3),
|
||||
}
|
||||
dynamic_shapes = {
|
||||
"a1": {0: Dim.DYNAMIC},
|
||||
"a2": {0: Dim.DYNAMIC},
|
||||
}
|
||||
ep = export(
|
||||
compiled_m,
|
||||
example_args,
|
||||
kwargs=example_kwargs,
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
strict=True,
|
||||
)
|
||||
gm = ep.module()
|
||||
self.assertEqual(gm(**example_kwargs), compiled_m(**example_kwargs))
|
||||
|
||||
def test_export_model_with_kwargs_dynamic_shapes(self):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, a1, a2):
|
||||
return a1 + a2
|
||||
|
||||
m = M()
|
||||
example_args = ()
|
||||
example_kwargs = {
|
||||
"a1": torch.ones(3, 3),
|
||||
"a2": torch.ones(3, 3),
|
||||
}
|
||||
dynamic_shapes = {
|
||||
"a1": {0: Dim.DYNAMIC},
|
||||
"a2": {0: Dim.DYNAMIC},
|
||||
}
|
||||
ep = export(
|
||||
m,
|
||||
example_args,
|
||||
kwargs=example_kwargs,
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
strict=True,
|
||||
)
|
||||
gm = ep.module()
|
||||
self.assertEqual(gm(**example_kwargs), m(**example_kwargs))
|
||||
|
||||
|
||||
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
|
||||
class TestExportCustomClass(TorchTestCase):
|
||||
|
||||
@ -2640,7 +2640,7 @@ def forward(self, primals_1, primals_2):
|
||||
return grad_output * x, grad_output * x
|
||||
|
||||
def f(a, b):
|
||||
return FwBwMutation.apply(a, b).sin_().clone()
|
||||
return FwBwMutation.apply(a, b)
|
||||
|
||||
inps = [
|
||||
torch.ones(3, 3, requires_grad=True),
|
||||
@ -2689,22 +2689,17 @@ def forward(self, primals_1, primals_2):
|
||||
add = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None
|
||||
_foreach_mul__1 = torch.ops.aten._foreach_mul_.ScalarList([add], [3]); _foreach_mul__1 = None
|
||||
mul = torch.ops.aten.mul.Tensor(add, primals_1); primals_1 = None
|
||||
clone = torch.ops.aten.clone.default(mul)
|
||||
sin_ = torch.ops.aten.sin_.default(mul); mul = None
|
||||
clone_1 = torch.ops.aten.clone.default(sin_); sin_ = None
|
||||
return (clone_1, add, clone)""",
|
||||
return (mul, add)""",
|
||||
)
|
||||
|
||||
# important bit: there is 1 mutation in the bw
|
||||
self.assertExpectedInline(
|
||||
bw_graph[0].code.strip(),
|
||||
"""\
|
||||
def forward(self, add, clone, tangents_1):
|
||||
cos = torch.ops.aten.cos.default(clone); clone = None
|
||||
mul_1 = torch.ops.aten.mul.Tensor(tangents_1, cos); tangents_1 = cos = None
|
||||
def forward(self, add, tangents_1):
|
||||
_foreach_mul__2 = torch.ops.aten._foreach_mul_.ScalarList([add], [4]); _foreach_mul__2 = None
|
||||
mul_2 = torch.ops.aten.mul.Tensor(mul_1, add); mul_1 = add = None
|
||||
return (mul_2, None)""",
|
||||
mul_1 = torch.ops.aten.mul.Tensor(tangents_1, add); tangents_1 = add = None
|
||||
return (mul_1, None)""",
|
||||
)
|
||||
|
||||
def test_fw_bw_mutation_no_functionalization2(self):
|
||||
|
||||
@ -927,8 +927,8 @@ class GraphModule(torch.nn.Module):
|
||||
op="call_function", target=torch.ops.aten.mm.default
|
||||
)
|
||||
self.assertEqual(len(mm_nodes), 4)
|
||||
self.assertEqual(mm_nodes[0].meta["partitioner_tag"], "is_forward")
|
||||
self.assertEqual(mm_nodes[1].meta["partitioner_tag"], "is_forward")
|
||||
self.assertNotIn("partitioner_tag", mm_nodes[0].meta)
|
||||
self.assertNotIn("partitioner_tag", mm_nodes[1].meta)
|
||||
self.assertEqual(mm_nodes[2].meta["partitioner_tag"], "is_backward")
|
||||
self.assertEqual(mm_nodes[3].meta["partitioner_tag"], "is_backward")
|
||||
self.assertEqual(mm_nodes[0].meta["custom"]["inside_local_map"], 0)
|
||||
|
||||
@ -4101,6 +4101,53 @@ if HAS_CUDA_AND_TRITON:
|
||||
compiled_out = compiled_foo(x)
|
||||
self.assertEqual(eager_out, compiled_out)
|
||||
|
||||
# Use autotune_at_compile_time=True to test standalone_compile
|
||||
@parametrize("autotune_at_compile_time", [True, False])
|
||||
@config.patch("graph_partition", True)
|
||||
def test_graph_partition_kernel_reuse(self, autotune_at_compile_time):
|
||||
def foo(x):
|
||||
# partition 1
|
||||
x1 = x @ x
|
||||
y1 = x1 + 1
|
||||
z_cpu = y1.cpu() + 1
|
||||
# partition 2
|
||||
# partition 2 should reuse the fused triton kernel generated
|
||||
# in partition 1
|
||||
x2 = z_cpu.to("cuda") @ z_cpu.to("cuda")
|
||||
y2 = x2 + 1
|
||||
return y1, y2
|
||||
|
||||
with config.patch(
|
||||
"triton.autotune_at_compile_time", autotune_at_compile_time
|
||||
):
|
||||
compiled_foo = torch.compile(foo)
|
||||
x = torch.randn((20, 20), device="cuda")
|
||||
eager_out = foo(x)
|
||||
compiled_out, code = run_and_get_code(compiled_foo, x)
|
||||
self.assertEqual(eager_out, compiled_out)
|
||||
|
||||
if autotune_at_compile_time:
|
||||
# auto-tuning block should only appear once. We generate auto-tuning code
|
||||
# for all the kernels no matter if they are defined in the main graph or
|
||||
# subgraph, to avoid the overhead of executing multiple auto-tuning code blocks.
|
||||
FileCheck().check_count(
|
||||
"Compile-time auto-tuning block", 1, exactly=True
|
||||
).run(code[0])
|
||||
# triton_poi_fused_add_ should appear twice, first in the auto-tuning block,
|
||||
# and then in the main code block
|
||||
FileCheck().check_count(
|
||||
"def triton_poi_fused_add_", 2, exactly=True
|
||||
).run(code[0])
|
||||
# cpu kernel definition should only appence once, not in the auto-tuning block
|
||||
FileCheck().check_count(
|
||||
"cpp_fused__to_copy_add_1 = ", 1, exactly=True
|
||||
).run(code[0])
|
||||
else:
|
||||
# triton_poi_fused_add_ should appear once, because of kernel reuse
|
||||
FileCheck().check_count(
|
||||
"def triton_poi_fused_add_", 1, exactly=True
|
||||
).run(code[0])
|
||||
|
||||
def test_meta_tensor(self):
|
||||
def foobar(x, y):
|
||||
return x * 2, y * 3
|
||||
|
||||
@ -4,8 +4,9 @@ from functools import partial
|
||||
from unittest import skipIf
|
||||
|
||||
import torch
|
||||
from torch._inductor import config
|
||||
from torch._inductor.ir import Pointwise
|
||||
from torch._inductor.lowering import make_pointwise, register_lowering
|
||||
from torch._inductor.lowering import make_fallback, make_pointwise, register_lowering
|
||||
from torch._inductor.test_case import TestCase as InductorTestCase
|
||||
from torch._inductor.virtualized import ops
|
||||
from torch.testing._internal.common_utils import skipIfRocm, skipIfXpu
|
||||
@ -237,6 +238,17 @@ class TestCustomLowering(InductorTestCase):
|
||||
out2 = fn_opt(a, b)
|
||||
self.assertEqual(out1, out2)
|
||||
|
||||
@config.patch(joint_graph_constant_folding=False)
|
||||
def test_constant_creation(self):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return x + torch.tensor(1)
|
||||
|
||||
make_fallback(torch.ops.aten.lift_fresh_copy.default)
|
||||
self.assertTrue(
|
||||
torch.allclose(torch.compile(M())(torch.ones(3)), torch.ones(3) + 1)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._inductor.test_case import run_tests
|
||||
|
||||
@ -430,6 +430,155 @@ class TestCustomOpAutoTune(TestCase):
|
||||
multi_param_op, (test_x, test_factor), expected_result, "MultiParam"
|
||||
)
|
||||
|
||||
@skipIfXpu
|
||||
def test_dynamic_range_tuning(self):
|
||||
"""Test dynamic input range-based autotuning.
|
||||
|
||||
Validates that different implementations can be selected automatically
|
||||
based on input dimensions using range parameters in CustomOpConfig.
|
||||
|
||||
This test demonstrates the simplified range-based API:
|
||||
- User provides CustomOpConfigs with range parameters
|
||||
- System groups configs by range and benchmarks implementations
|
||||
- System automatically selects the fastest implementation per range
|
||||
- If all ranges use same impl → direct use (fusion-friendly)
|
||||
- If different ranges use different impls → torch.cond dispatch
|
||||
"""
|
||||
test_op_name = f"test_lib::dynamic_range_{id(self)}"
|
||||
|
||||
def short_sequence_impl(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
|
||||
"""Optimized for short sequences (< 512): uses simple einsum."""
|
||||
return torch.einsum("bsh,h->bsh", x, weight)
|
||||
|
||||
def medium_sequence_impl(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
|
||||
"""Optimized for medium sequences (512-2048): uses chunked processing."""
|
||||
batch_size, seq_len, hidden_dim = x.shape
|
||||
chunk_size = 256
|
||||
chunks = []
|
||||
for start in range(0, seq_len, chunk_size):
|
||||
end = min(start + chunk_size, seq_len)
|
||||
chunk = x[:, start:end, :]
|
||||
chunks.append(chunk * weight)
|
||||
return torch.cat(chunks, dim=1)
|
||||
|
||||
def long_sequence_impl(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
|
||||
"""Optimized for long sequences (> 2048): uses reshape + broadcast."""
|
||||
return x * weight.view(1, 1, -1)
|
||||
|
||||
@torch.library.custom_op(test_op_name, mutates_args=())
|
||||
def dynamic_range_op(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
|
||||
"""Default implementation."""
|
||||
return x * weight
|
||||
|
||||
@dynamic_range_op.register_fake
|
||||
def _(x: torch.Tensor, weight: torch.Tensor):
|
||||
return torch.empty_like(x)
|
||||
|
||||
# Register with range-based configs (CLEAN API with dim_range tuple)
|
||||
# Each config specifies its range using tensor_name, dim_index, dim_range=(start, end)
|
||||
register_custom_op_autotuning(
|
||||
dynamic_range_op,
|
||||
configs=[
|
||||
# Range 1: [0, 512) - test all 3 implementations
|
||||
CustomOpConfig(
|
||||
short_sequence_impl,
|
||||
tensor_name="x",
|
||||
dim_index=1,
|
||||
dim_range=(0, 512),
|
||||
),
|
||||
CustomOpConfig(
|
||||
medium_sequence_impl,
|
||||
tensor_name="x",
|
||||
dim_index=1,
|
||||
dim_range=(0, 512),
|
||||
),
|
||||
CustomOpConfig(
|
||||
long_sequence_impl, tensor_name="x", dim_index=1, dim_range=(0, 512)
|
||||
),
|
||||
# Range 2: [512, 2048) - test all 3 implementations
|
||||
CustomOpConfig(
|
||||
short_sequence_impl,
|
||||
tensor_name="x",
|
||||
dim_index=1,
|
||||
dim_range=(512, 2048),
|
||||
),
|
||||
CustomOpConfig(
|
||||
medium_sequence_impl,
|
||||
tensor_name="x",
|
||||
dim_index=1,
|
||||
dim_range=(512, 2048),
|
||||
),
|
||||
CustomOpConfig(
|
||||
long_sequence_impl,
|
||||
tensor_name="x",
|
||||
dim_index=1,
|
||||
dim_range=(512, 2048),
|
||||
),
|
||||
# Range 3: [2048, inf) - test all 3 implementations
|
||||
CustomOpConfig(
|
||||
short_sequence_impl,
|
||||
tensor_name="x",
|
||||
dim_index=1,
|
||||
dim_range=(2048, float("inf")),
|
||||
),
|
||||
CustomOpConfig(
|
||||
medium_sequence_impl,
|
||||
tensor_name="x",
|
||||
dim_index=1,
|
||||
dim_range=(2048, float("inf")),
|
||||
),
|
||||
CustomOpConfig(
|
||||
long_sequence_impl,
|
||||
tensor_name="x",
|
||||
dim_index=1,
|
||||
dim_range=(2048, float("inf")),
|
||||
),
|
||||
],
|
||||
name="dynamic_range_autotuned",
|
||||
input_gen_fns={
|
||||
"x": lambda fake: torch.randn_like(fake, device=self.device) * 0.1,
|
||||
"weight": lambda fake: torch.ones_like(fake, device=self.device),
|
||||
},
|
||||
)
|
||||
|
||||
# Test different sequence lengths to trigger different ranges
|
||||
test_cases = [
|
||||
(2, 256, 128), # Short sequence (< 512)
|
||||
(2, 1024, 128), # Medium sequence (512-2048)
|
||||
(2, 4096, 128), # Long sequence (> 2048)
|
||||
]
|
||||
|
||||
for batch_size, seq_len, hidden_dim in test_cases:
|
||||
test_x = torch.randn(
|
||||
batch_size, seq_len, hidden_dim, device=self.device, dtype=self.dtype
|
||||
)
|
||||
test_weight = torch.ones(hidden_dim, device=self.device, dtype=self.dtype)
|
||||
|
||||
# Verify all implementations produce same result
|
||||
expected = test_x * test_weight
|
||||
|
||||
for impl_name, impl_fn in [
|
||||
("short", short_sequence_impl),
|
||||
("medium", medium_sequence_impl),
|
||||
("long", long_sequence_impl),
|
||||
]:
|
||||
result = impl_fn(test_x, test_weight)
|
||||
torch.testing.assert_close(
|
||||
result,
|
||||
expected,
|
||||
rtol=1e-5,
|
||||
atol=1e-5,
|
||||
msg=f"{impl_name} implementation differs for seq_len={seq_len}",
|
||||
)
|
||||
|
||||
# Test autotuning with compilation
|
||||
self._run_autotune_test(
|
||||
dynamic_range_op,
|
||||
(test_x, test_weight),
|
||||
expected,
|
||||
f"DynamicRange_seq{seq_len}",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
||||
@ -31,7 +31,6 @@ from torch.testing._internal.common_utils import (
|
||||
serialTest,
|
||||
TEST_CUDA_MEM_LEAK_CHECK,
|
||||
TEST_WITH_ASAN,
|
||||
TEST_WITH_ROCM,
|
||||
)
|
||||
from torch.testing._internal.inductor_utils import (
|
||||
GPU_TYPE,
|
||||
@ -93,17 +92,6 @@ if not torch._inductor.config.cpp_wrapper:
|
||||
("cuda",)
|
||||
)
|
||||
|
||||
if TEST_WITH_ROCM:
|
||||
# Tensor-likes are not close
|
||||
test_failures["test_dynamic_stride_nobreak"] = TestFailure(
|
||||
("cpu", "cuda"), is_skip=True
|
||||
)
|
||||
test_failures["test_item_to_inputs_kernel_nobreak"] = TestFailure(
|
||||
("cpu", "cuda"), is_skip=True
|
||||
)
|
||||
test_failures["test_unbacked_reduction"] = TestFailure(("cpu"), is_skip=True)
|
||||
|
||||
|
||||
if any(os.getenv("BUILD_ENVIRONMENT", "").endswith(x) for x in ("-debug", "-asan")):
|
||||
# Fails with TORCH_INTERNAL_ASSERT(!is_heap_allocated()), see https://github.com/pytorch/pytorch/issues/130073
|
||||
# After https://github.com/pytorch/pytorch/pull/161586, starts failing UBSAN so we can't even xfail.
|
||||
|
||||
@ -492,6 +492,36 @@ class PackedSequenceTest(TestCase):
|
||||
torch.randn([0, 1, 10]), torch.randn([11, 14, 14, 2]), True
|
||||
)
|
||||
|
||||
def test_empty_packed_sequence(self):
|
||||
"""
|
||||
Regression test for https://github.com/pytorch/pytorch/issues/149622
|
||||
Tests that pad_packed_sequence and unpack_sequence handle empty tensors
|
||||
without segmentation fault (CVE-2025-2998, CVE-2025-2999)
|
||||
"""
|
||||
# Test case 1: pad_packed_sequence with empty tensors
|
||||
# Previously caused segmentation fault
|
||||
empty_data = torch.randn(0, 5)
|
||||
empty_batch_sizes = torch.tensor([], dtype=torch.int64)
|
||||
empty_packed = rnn_utils.PackedSequence(
|
||||
empty_data, empty_batch_sizes, None, None
|
||||
)
|
||||
|
||||
# Should not crash - either return empty result or raise informative error
|
||||
with self.assertRaises(RuntimeError):
|
||||
rnn_utils.pad_packed_sequence(empty_packed, batch_first=True)
|
||||
|
||||
# Test case 2: unpack_sequence with empty tensors
|
||||
# Previously caused segmentation fault
|
||||
empty_data = torch.tensor([])
|
||||
empty_batch_sizes = torch.tensor([], dtype=torch.int64)
|
||||
packed = rnn_utils.PackedSequence(
|
||||
data=empty_data, batch_sizes=empty_batch_sizes
|
||||
)
|
||||
|
||||
# Should not crash - either return empty list or raise informative error
|
||||
with self.assertRaises(RuntimeError):
|
||||
rnn_utils.unpack_sequence(packed)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
||||
@ -2320,6 +2320,8 @@ if sys.version_info >= (3, 11):
|
||||
torch_c_binding_in_graph_functions["math.exp2"] = TorchInGraphFunctionVariable
|
||||
torch_c_binding_in_graph_functions["math.cbrt"] = TorchInGraphFunctionVariable
|
||||
|
||||
if sys.version_info >= (3, 13):
|
||||
torch_c_binding_in_graph_functions["math.fma"] = TorchInGraphFunctionVariable
|
||||
|
||||
# In graph functions (including constant folding) that are not C bindings
|
||||
# NOTE: [Cacheability of in-graph torch functions]
|
||||
|
||||
@ -10,7 +10,10 @@ from torch.fx import has_side_effect, Proxy
|
||||
from .. import graph_break_hints
|
||||
from ..bytecode_transformation import create_call_function
|
||||
from ..exc import TYPE_CHECKING, unimplemented
|
||||
from ..graph_bytecode_inputs import get_external_object_by_index
|
||||
from ..graph_bytecode_inputs import (
|
||||
get_external_object_by_index,
|
||||
register_graph_created_object,
|
||||
)
|
||||
from .base import VariableTracker
|
||||
from .constant import ConstantVariable
|
||||
from .ctx_manager import FxTracebackAnnotateVariable
|
||||
@ -28,6 +31,26 @@ from torch._library.custom_ops import custom_op
|
||||
Tensor = torch.Tensor
|
||||
|
||||
|
||||
def new_event(*args: Any, **kwargs: Any) -> int:
|
||||
event = torch.Event(*args, **kwargs)
|
||||
return register_graph_created_object(
|
||||
event,
|
||||
EventVariable.make_construct_in_graph_event_fn(
|
||||
TupleVariable([]), ConstDictVariable({})
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def new_stream(*args: tuple[Any], **kwargs: Any) -> int:
|
||||
stream = torch.Stream(*args, **kwargs) # type: ignore[no-matching-overload,call-overload]
|
||||
return register_graph_created_object(
|
||||
stream,
|
||||
StreamVariable.make_construct_in_graph_stream_fn(
|
||||
TupleVariable([]), ConstDictVariable({})
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _get_stream_by_index(index: int) -> torch.Stream:
|
||||
stream = get_external_object_by_index(index)
|
||||
assert isinstance(stream, torch.Stream), (
|
||||
@ -115,6 +138,24 @@ def _(
|
||||
has_side_effect(torch.ops.streams.wait_event.default)
|
||||
|
||||
|
||||
@custom_op("streams::wait_stream", mutates_args=())
|
||||
def wait_stream(waiting_stream_index: int, waited_on_stream_index: int) -> None:
|
||||
waiting = _get_stream_by_index(waiting_stream_index)
|
||||
waited_on = _get_stream_by_index(waited_on_stream_index)
|
||||
waiting.wait_stream(waited_on)
|
||||
|
||||
|
||||
@wait_stream.register_fake
|
||||
def _(
|
||||
event_index: int,
|
||||
stream_index: int,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
|
||||
has_side_effect(torch.ops.streams.wait_stream.default)
|
||||
|
||||
|
||||
class SymbolicStreamState:
|
||||
"""Track the currently entered stream if any"""
|
||||
|
||||
|
||||
@ -603,6 +603,21 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
|
||||
VariableTracker.build(tx, polyfills.radians), args, kwargs
|
||||
)
|
||||
|
||||
if hasattr(math, "fma"): # Python 3.13+
|
||||
|
||||
@register(math.fma)
|
||||
def handle_fma(self, tx: "InstructionTranslator", *args, **kwargs):
|
||||
if len(args) != 3 or kwargs:
|
||||
return None
|
||||
|
||||
if all(isinstance(arg, variables.TensorVariable) for arg in args):
|
||||
x, y, z = args
|
||||
addcmul_fn = TorchInGraphFunctionVariable(torch.addcmul)
|
||||
return addcmul_fn.call_function(tx, [z, x, y], {})
|
||||
|
||||
# Use math.fma if constants
|
||||
return None
|
||||
|
||||
@register(torch.is_inference_mode_enabled)
|
||||
def handle_is_inference_mode_enabled(self, tx: "InstructionTranslator"):
|
||||
unimplemented(
|
||||
|
||||
@ -27,7 +27,6 @@ from torch._guards import detect_fake_mode
|
||||
from torch._prims_common import CUDARngStateHelper
|
||||
from torch.fx.experimental.proxy_tensor import (
|
||||
_proxy_tensor_disable_update_tensor_tracker,
|
||||
get_proxy_mode,
|
||||
maybe_disable_thunkify,
|
||||
maybe_enable_thunkify,
|
||||
)
|
||||
@ -296,10 +295,6 @@ def create_joint(
|
||||
(outs, tangent_mask), (outs_descs, _) = call_and_expect_output_descs(
|
||||
fn, primals
|
||||
)
|
||||
mode = get_proxy_mode()
|
||||
assert mode is not None
|
||||
for node in mode.tracer.graph.nodes:
|
||||
node.meta["partitioner_tag"] = "is_forward"
|
||||
|
||||
# TODO: I think this hook can also be eliminated now
|
||||
if joint_fn_handle and joint_fn_handle.post_forward:
|
||||
|
||||
@ -51,7 +51,6 @@ from ._activation_checkpointing.knapsack import (
|
||||
)
|
||||
from ._activation_checkpointing.knapsack_evaluator import KnapsackEvaluator
|
||||
from ._aot_autograd.descriptors import AOTOutput, SavedForBackwardsAOTOutput
|
||||
from ._aot_autograd.functional_utils import assert_functional_graph
|
||||
from ._aot_autograd.logging_utils import get_aot_graph_name
|
||||
from ._aot_autograd.utils import get_cuda_generator_meta_val, is_with_effects
|
||||
from .compile_utils import fx_graph_cse, get_aten_target, raise_getitems
|
||||
@ -298,10 +297,6 @@ def _has_tag_is_backward(node: fx.Node) -> bool:
|
||||
return node.meta.get("partitioner_tag", None) == "is_backward"
|
||||
|
||||
|
||||
def _has_tag_is_forward(node: fx.Node) -> bool:
|
||||
return node.meta.get("partitioner_tag", None) == "is_forward"
|
||||
|
||||
|
||||
def _has_tag_must_be_in_forward(node: fx.Node) -> bool:
|
||||
return node.meta.get("partitioner_tag", None) == "must_be_in_forward"
|
||||
|
||||
@ -1026,95 +1021,105 @@ def default_partition(
|
||||
Returns:
|
||||
Returns the generated forward and backward Fx graph modules.
|
||||
"""
|
||||
# Respect the original placement of ops rather than rely on dataflow.
|
||||
forward_nodes = []
|
||||
last_node = None
|
||||
for node in joint_module.graph.nodes:
|
||||
if _has_tag_is_forward(node) or _is_primal(node) or _is_fwd_seed_offset(node):
|
||||
last_node = node
|
||||
assert last_node is not None
|
||||
for node in joint_module.graph.nodes:
|
||||
if not _is_tangent(node):
|
||||
forward_nodes.append(node)
|
||||
if node is last_node:
|
||||
break
|
||||
if has_recomputable_ops(joint_module):
|
||||
return min_cut_rematerialization_partition(
|
||||
joint_module,
|
||||
_joint_inputs,
|
||||
num_fwd_outputs=num_fwd_outputs,
|
||||
static_lifetime_input_indices=static_lifetime_input_indices,
|
||||
)
|
||||
primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
|
||||
fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes))
|
||||
inputs = primal_inputs + fwd_seed_offset_inputs
|
||||
fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = (
|
||||
_extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
|
||||
)
|
||||
forward_only_graph = _extract_graph_with_inputs_outputs(
|
||||
joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward"
|
||||
)
|
||||
forward_node_names = OrderedSet(
|
||||
node.name for node in forward_nodes if node.op != "output"
|
||||
node.name for node in forward_only_graph.nodes if node.op != "output"
|
||||
)
|
||||
graph_has_recomputable_ops = has_recomputable_ops(joint_module)
|
||||
graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module)
|
||||
if graph_has_recomputable_ops:
|
||||
assert_functional_graph(joint_module.graph)
|
||||
joint_module = cleanup_recompute_tags(joint_module, is_default_partition=True)
|
||||
|
||||
if not config.unsafe_allow_optimization_of_collectives:
|
||||
force_save_collectives(joint_module)
|
||||
|
||||
force_save_bw_mutation_src(joint_module)
|
||||
|
||||
if static_lifetime_input_indices is None:
|
||||
static_lifetime_input_indices = []
|
||||
node_info = classify_nodes(
|
||||
joint_module, static_lifetime_input_indices, num_fwd_outputs
|
||||
)
|
||||
|
||||
order = {node: idx for idx, node in enumerate(joint_module.graph.nodes)}
|
||||
saved_values = []
|
||||
saved_sym_nodes = []
|
||||
|
||||
def is_mutated_later_in_fw(node):
|
||||
if _has_tag_is_backward(node):
|
||||
return False
|
||||
tensor_arg_aliases = [
|
||||
x
|
||||
for x in node.args
|
||||
if isinstance(x, fx.Node)
|
||||
and "val" in x.meta
|
||||
and isinstance(x.meta["val"], torch.Tensor)
|
||||
]
|
||||
while len(tensor_arg_aliases) > 0:
|
||||
a = tensor_arg_aliases.pop()
|
||||
for u in a.users:
|
||||
if not isinstance(u.target, torch._ops.OpOverload):
|
||||
continue
|
||||
# If we witness a mutation on our node later, and that mutation is not "must be in backward",
|
||||
# then our node needs to be computed in the forward (otherwise we will compute it on the mutated values)
|
||||
if (
|
||||
# one of the args was mutated
|
||||
u.target._schema.is_mutable
|
||||
# and the mutation happens "later"
|
||||
and order[u] > order[node]
|
||||
# and the mutation happened during the forward
|
||||
and not (_has_tag_is_backward(u) or _has_tag_must_be_in_backward(u))
|
||||
):
|
||||
for idx, alias_info in enumerate(u.target._schema.arguments):
|
||||
if alias_info.is_write and u.args[idx] is a:
|
||||
return True
|
||||
elif u.target.is_view:
|
||||
tensor_arg_aliases.append(u)
|
||||
return False
|
||||
|
||||
for node in joint_module.graph.nodes:
|
||||
if node.name not in forward_node_names:
|
||||
# if a node isn't "required" to be in the forward, but any of its arguments
|
||||
# are later mutated in the forward, then it must have been run in the forward
|
||||
# (if not, and the node's arg was saved for backward, we would have mutated a saved value)
|
||||
# NB: doesn't handle nodes where the input is a list of tensors and one of those tensors is later mutated
|
||||
if is_mutated_later_in_fw(node):
|
||||
saved_values.append(node)
|
||||
continue
|
||||
if is_sym_node(node):
|
||||
# Symints must be kept separate from tensors so that PythonFunction only calls
|
||||
# save_for_backward on tensors and stashes symints in autograd .ctx
|
||||
saved_sym_nodes.append(node)
|
||||
continue
|
||||
if node.meta.get("recompute") == CheckpointPolicy.MUST_SAVE:
|
||||
saved_values.append(node)
|
||||
continue
|
||||
if node.is_impure(impure_random=False) and node.op not in (
|
||||
"placeholder",
|
||||
"output",
|
||||
):
|
||||
# See is_impure in torch/fx/node.py
|
||||
assert not graph_has_recomputable_ops, (
|
||||
"Trying to apply AC on a graph with impure op",
|
||||
node,
|
||||
node.target,
|
||||
)
|
||||
saved_values.append(node)
|
||||
continue
|
||||
backward_usages = [n for n in node.users if n.name not in forward_node_names]
|
||||
if "tensor_meta" in node.meta and all(is_sym_node(n) for n in backward_usages):
|
||||
# If we have a tensor in the forward, where only its sizes/strides are needed in the backward,
|
||||
# and not the actual tensor data,
|
||||
# then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor.
|
||||
#
|
||||
# Note that saving the tensor could also cause compilation problems:
|
||||
# If the user mutated an input in the forward and uses its sizes/strides in the backward,
|
||||
# then we would be obligated to clone the input before saving it to appease autograd.
|
||||
# (This is how we originally found this bug).
|
||||
saved_sym_nodes.extend(backward_usages)
|
||||
continue
|
||||
if (
|
||||
elif (
|
||||
"tensor_meta" not in node.meta
|
||||
and node.op == "call_function"
|
||||
and not isinstance(node.meta.get("val"), torch._subclasses.FakeTensor)
|
||||
):
|
||||
assert all(user.target == operator.getitem for user in node.users)
|
||||
continue
|
||||
if not must_recompute(node):
|
||||
saved_values.append(node)
|
||||
|
||||
# Since we can't save tuple of tensor values, we need to flatten out what we're saving
|
||||
users = node.users
|
||||
assert all(user.target is operator.getitem for user in users)
|
||||
saved_values.extend(users)
|
||||
else:
|
||||
backward_usages = [
|
||||
n for n in node.users if n.name not in forward_node_names
|
||||
]
|
||||
if "tensor_meta" in node.meta and all(
|
||||
is_sym_node(n) for n in backward_usages
|
||||
):
|
||||
# If we have a tensor in the forward, where only its sizes/strides are needed in the backward,
|
||||
# and not the actual tensor data,
|
||||
# then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor.
|
||||
#
|
||||
# Note that saving the tensor could also cause compilation problems:
|
||||
# If the user mutated an input in the forward and uses its sizes/strides in the backward,
|
||||
# then we would be obligated to clone the input before saving it to appease autograd.
|
||||
# (This is how we originally found this bug).
|
||||
saved_sym_nodes.extend(backward_usages)
|
||||
else:
|
||||
saved_values.append(node)
|
||||
saved_values = list(dict.fromkeys(saved_values).keys())
|
||||
saved_sym_nodes = list(dict.fromkeys(saved_sym_nodes).keys())
|
||||
|
||||
if config._sync_decision_cross_ranks:
|
||||
saved_values = _sync_decision_cross_ranks(joint_module.graph, saved_values)
|
||||
|
||||
if static_lifetime_input_nodes is None:
|
||||
static_lifetime_input_nodes = node_info.static_lifetime_input_nodes
|
||||
fw_module, bw_module = _extract_fwd_bwd_modules(
|
||||
return _extract_fwd_bwd_modules(
|
||||
joint_module,
|
||||
saved_values,
|
||||
saved_sym_nodes=saved_sym_nodes,
|
||||
@ -1122,24 +1127,6 @@ def default_partition(
|
||||
static_lifetime_input_nodes=static_lifetime_input_nodes,
|
||||
)
|
||||
|
||||
if graph_has_recomputable_ops:
|
||||
if graph_has_recomputable_rng_ops:
|
||||
fw_module, bw_module = functionalize_rng_ops(
|
||||
joint_module, fw_module, bw_module, len(saved_sym_nodes)
|
||||
)
|
||||
bw_module = reordering_to_mimic_autograd_engine(bw_module)
|
||||
|
||||
# raise all getitem ops to as early as possible
|
||||
# this is helpful for memory, especially in the case of aot_eager backend
|
||||
fw_module = raise_getitems(fw_module)
|
||||
bw_module = raise_getitems(bw_module)
|
||||
|
||||
fw_module = thread_graphsafe_rng_from_hops(fw_module, is_backward=False)
|
||||
if len(node_info.required_bw_nodes) > 0:
|
||||
bw_module = thread_graphsafe_rng_from_hops(bw_module, is_backward=True)
|
||||
|
||||
return fw_module, bw_module
|
||||
|
||||
|
||||
INT_INF = int(1e6)
|
||||
|
||||
@ -1634,9 +1621,7 @@ def force_save_bw_mutation_src(joint_module: fx.GraphModule) -> None:
|
||||
break
|
||||
|
||||
|
||||
def cleanup_recompute_tags(
|
||||
joint_module: fx.GraphModule, *, is_default_partition: bool
|
||||
) -> fx.GraphModule:
|
||||
def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule:
|
||||
"""
|
||||
If there are two consecutive checkpointed blocks with no operator in
|
||||
between, we would still want to stash the tensor at the boundary of
|
||||
@ -1673,16 +1658,6 @@ def cleanup_recompute_tags(
|
||||
# Solution: check whether `out` has a backward hook, and if so, intentionally save `out`
|
||||
# in forward graph outputs. With this, we can break the above circular dependency.
|
||||
node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
|
||||
elif (
|
||||
"ac_graph_id" not in node.meta
|
||||
and any(must_recompute(user) for user in node.users)
|
||||
and is_default_partition
|
||||
):
|
||||
# This node is not part of the AC region and a user is marked as recompute.
|
||||
# This means it's an input to the AC region and we should save it.
|
||||
# For ease of landing, gate this to default partitioner only, but we should think
|
||||
# about flipping the switch in general as well.
|
||||
node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
|
||||
return joint_module
|
||||
|
||||
|
||||
@ -2790,59 +2765,6 @@ def thread_graphsafe_rng_from_hops(module, is_backward):
|
||||
return module
|
||||
|
||||
|
||||
def classify_nodes(joint_module, static_lifetime_input_indices, num_fwd_outputs):
|
||||
name_to_node = get_name_to_node(joint_module.graph)
|
||||
required_bw_nodes: OrderedSet[fx.Node] = OrderedSet()
|
||||
for node in joint_module.graph.nodes:
|
||||
if node.op == "placeholder" and "tangents" in node.target:
|
||||
required_bw_nodes.add(node)
|
||||
elif _must_be_in_backward(node):
|
||||
required_bw_nodes.add(node)
|
||||
|
||||
if node in required_bw_nodes:
|
||||
required_bw_nodes.update(node.users)
|
||||
|
||||
primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
|
||||
fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes))
|
||||
inputs = primal_inputs + fwd_seed_offset_inputs
|
||||
fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = (
|
||||
_extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
|
||||
)
|
||||
required_bw_nodes.update(
|
||||
o for o in bwd_outputs if o is not None and o.op != "output"
|
||||
)
|
||||
forward_only_graph = _extract_graph_with_inputs_outputs(
|
||||
joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward"
|
||||
)
|
||||
required_fw_nodes: OrderedSet[fx.Node] = OrderedSet(
|
||||
name_to_node[node.name]
|
||||
for node in forward_only_graph.nodes
|
||||
if node.op != "output"
|
||||
)
|
||||
unclaimed_nodes: OrderedSet[fx.Node] = OrderedSet(
|
||||
node
|
||||
for node in joint_module.graph.nodes
|
||||
if node not in required_fw_nodes and node not in required_bw_nodes
|
||||
)
|
||||
static_lifetime_input_nodes = OrderedSet(
|
||||
p for i, p in enumerate(primal_inputs) if i in static_lifetime_input_indices
|
||||
)
|
||||
fw_cnt = 0
|
||||
fw_order = {}
|
||||
for node in joint_module.graph.nodes:
|
||||
if node in required_fw_nodes:
|
||||
fw_order[node] = fw_cnt
|
||||
fw_cnt += 1
|
||||
return NodeInfo(
|
||||
inputs,
|
||||
required_fw_nodes,
|
||||
required_bw_nodes,
|
||||
unclaimed_nodes,
|
||||
fw_order,
|
||||
static_lifetime_input_nodes,
|
||||
)
|
||||
|
||||
|
||||
def min_cut_rematerialization_partition(
|
||||
joint_module: fx.GraphModule,
|
||||
_joint_inputs,
|
||||
@ -2891,16 +2813,68 @@ def min_cut_rematerialization_partition(
|
||||
graph_has_recomputable_ops = has_recomputable_ops(joint_module)
|
||||
graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module)
|
||||
if graph_has_recomputable_ops:
|
||||
joint_module = cleanup_recompute_tags(joint_module, is_default_partition=False)
|
||||
joint_module = cleanup_recompute_tags(joint_module)
|
||||
if not config.unsafe_allow_optimization_of_collectives:
|
||||
force_save_collectives(joint_module)
|
||||
force_save_bw_mutation_src(joint_module)
|
||||
|
||||
def classify_nodes(joint_module, static_lifetime_input_indices):
|
||||
name_to_node = get_name_to_node(joint_module.graph)
|
||||
required_bw_nodes: OrderedSet[fx.Node] = OrderedSet()
|
||||
for node in joint_module.graph.nodes:
|
||||
if node.op == "placeholder" and "tangents" in node.target:
|
||||
required_bw_nodes.add(node)
|
||||
elif _must_be_in_backward(node):
|
||||
required_bw_nodes.add(node)
|
||||
|
||||
if node in required_bw_nodes:
|
||||
required_bw_nodes.update(node.users)
|
||||
|
||||
primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
|
||||
fwd_seed_offset_inputs = list(
|
||||
filter(_is_fwd_seed_offset, joint_module.graph.nodes)
|
||||
)
|
||||
inputs = primal_inputs + fwd_seed_offset_inputs
|
||||
fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = (
|
||||
_extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
|
||||
)
|
||||
required_bw_nodes.update(
|
||||
o for o in bwd_outputs if o is not None and o.op != "output"
|
||||
)
|
||||
forward_only_graph = _extract_graph_with_inputs_outputs(
|
||||
joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward"
|
||||
)
|
||||
required_fw_nodes: OrderedSet[fx.Node] = OrderedSet(
|
||||
name_to_node[node.name]
|
||||
for node in forward_only_graph.nodes
|
||||
if node.op != "output"
|
||||
)
|
||||
unclaimed_nodes: OrderedSet[fx.Node] = OrderedSet(
|
||||
node
|
||||
for node in joint_module.graph.nodes
|
||||
if node not in required_fw_nodes and node not in required_bw_nodes
|
||||
)
|
||||
static_lifetime_input_nodes = OrderedSet(
|
||||
p for i, p in enumerate(primal_inputs) if i in static_lifetime_input_indices
|
||||
)
|
||||
fw_cnt = 0
|
||||
fw_order = {}
|
||||
for node in joint_module.graph.nodes:
|
||||
if node in required_fw_nodes:
|
||||
fw_order[node] = fw_cnt
|
||||
fw_cnt += 1
|
||||
return NodeInfo(
|
||||
inputs,
|
||||
required_fw_nodes,
|
||||
required_bw_nodes,
|
||||
unclaimed_nodes,
|
||||
fw_order,
|
||||
static_lifetime_input_nodes,
|
||||
)
|
||||
|
||||
if static_lifetime_input_indices is None:
|
||||
static_lifetime_input_indices = []
|
||||
node_info = classify_nodes(
|
||||
joint_module, static_lifetime_input_indices, num_fwd_outputs
|
||||
)
|
||||
node_info = classify_nodes(joint_module, static_lifetime_input_indices)
|
||||
|
||||
# networkx blows up on graphs with no required backward nodes
|
||||
# Since there's nothing to partition anyway, and the default partitioner can "handle"
|
||||
|
||||
@ -627,7 +627,7 @@ class ComboKernel(Kernel):
|
||||
if heuristics == "foreach":
|
||||
heuristics_line = f"""
|
||||
@triton_heuristics.foreach(
|
||||
filename=__file__,
|
||||
num_warps={self.num_warps},
|
||||
triton_meta={triton_meta!r},
|
||||
inductor_meta={inductor_meta!r},
|
||||
)
|
||||
|
||||
@ -2259,7 +2259,7 @@ class PythonWrapperCodegen(CodeGen):
|
||||
gpu: bool = True,
|
||||
cpp_definition: Optional[str] = None,
|
||||
):
|
||||
if config.triton.autotune_at_compile_time:
|
||||
if config.triton.autotune_at_compile_time and gpu:
|
||||
body = self._format_kernel_definition(
|
||||
kernel_name, kernel_body, metadata=metadata
|
||||
)
|
||||
@ -3745,6 +3745,13 @@ class SubgraphPythonWrapperCodegen(PythonWrapperCodegen):
|
||||
|
||||
super().__init__()
|
||||
|
||||
root = self.get_root_graph()
|
||||
# Only generate auto-tuning block in the main graph
|
||||
self.kernel_autotune_defs = root.kernel_autotune_defs
|
||||
self.kernel_autotune_calls = root.kernel_autotune_calls
|
||||
# Only store kernel src to name mapping in the main graph
|
||||
self.src_to_kernel = root.src_to_kernel
|
||||
|
||||
def set_launcher_fn_name(self) -> None:
|
||||
# This sets up the name of the function containing the launcher code of
|
||||
# the subgraph.
|
||||
@ -3837,3 +3844,16 @@ class SubgraphPythonWrapperCodegen(PythonWrapperCodegen):
|
||||
# V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")
|
||||
# )
|
||||
self.parent_wrapper.write_get_raw_stream_header_once()
|
||||
|
||||
@cache_on_self
|
||||
def get_root_graph(self) -> PythonWrapperCodegen:
|
||||
root: PythonWrapperCodegen | SubgraphPythonWrapperCodegen = self
|
||||
while isinstance(root, SubgraphPythonWrapperCodegen):
|
||||
root = root.parent_wrapper
|
||||
|
||||
assert isinstance(root, PythonWrapperCodegen)
|
||||
return root
|
||||
|
||||
def generate_and_run_autotune_block(self):
|
||||
# Only execute auto-tuning block in the main graph
|
||||
pass
|
||||
|
||||
@ -64,6 +64,7 @@ from torch.fx.experimental.symbolic_shapes import (
|
||||
)
|
||||
from torch.fx.node import Node
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
from torch.utils._python_dispatch import _disable_current_modes
|
||||
from torch.utils._sympy.functions import CleanDiv, FloorDiv, Mod, ModularIndexing
|
||||
from torch.utils._sympy.symbol import SymT
|
||||
|
||||
@ -6135,9 +6136,12 @@ class ExternKernel(InputsKernel):
|
||||
if isinstance(x, (Expr, sympy.logic.boolalg.Boolean, int)):
|
||||
return ShapeAsConstantBuffer(expr=x)
|
||||
if isinstance(x, Constant):
|
||||
return V.graph.add_tensor_constant(
|
||||
torch.tensor(x.value, dtype=x.get_dtype(), device=x.get_device())
|
||||
)
|
||||
# We need to unset fake mode, or else the torch.tensor() call will
|
||||
# turn into a FakeTensor
|
||||
with _disable_current_modes():
|
||||
return V.graph.add_tensor_constant(
|
||||
torch.tensor(x.value, dtype=x.get_dtype(), device=x.get_device())
|
||||
)
|
||||
if isinstance(x, ConstantBuffer):
|
||||
return x
|
||||
if isinstance(x, TensorBox):
|
||||
|
||||
@ -29,16 +29,22 @@ class CustomOpConfig:
|
||||
|
||||
Args:
|
||||
decomposition: Optional functions to autotune. If not provided, default will be used.
|
||||
tensor_name: Optional tensor parameter name for range-based dispatch (e.g., 'x', 'query')
|
||||
dim_index: Optional dimension index for range-based dispatch (e.g., 0 for batch, 1 for seq_len)
|
||||
dim_range: Optional tuple (start, end) defining the range [start, end) for this config
|
||||
**params: Parameters passed to the function
|
||||
|
||||
Examples:
|
||||
CustomOpConfig(attention_impl, head_dim=32, method='chunked')
|
||||
CustomOpConfig(head_dim=32, method='chunked')
|
||||
CustomOpConfig(short_impl, tensor_name='x', dim_index=1, dim_range=(0, 512))
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
decomposition: Optional[Callable[..., Any]] = None,
|
||||
tensor_name: Optional[str] = None,
|
||||
dim_index: Optional[int] = None,
|
||||
dim_range: Optional[tuple[Union[int, float], Union[int, float]]] = None,
|
||||
**params: Any,
|
||||
):
|
||||
if decomposition is not None and not callable(decomposition):
|
||||
@ -46,9 +52,34 @@ class CustomOpConfig:
|
||||
f"decomposition must be callable, got {type(decomposition)}"
|
||||
)
|
||||
|
||||
# Validate range parameters
|
||||
if dim_range is not None:
|
||||
if tensor_name is None:
|
||||
raise ValueError(
|
||||
"tensor_name must be specified when dim_range is provided"
|
||||
)
|
||||
if dim_index is None:
|
||||
raise ValueError(
|
||||
"dim_index must be specified when dim_range is provided"
|
||||
)
|
||||
if not isinstance(dim_range, (tuple, list)) or len(dim_range) != 2:
|
||||
raise ValueError("dim_range must be a tuple or list of (start, end)")
|
||||
start, end = dim_range
|
||||
if start >= end:
|
||||
raise ValueError(
|
||||
f"dim_range start ({start}) must be less than end ({end})"
|
||||
)
|
||||
|
||||
self.decomposition = decomposition
|
||||
self.tensor_name = tensor_name
|
||||
self.dim_index = dim_index
|
||||
self.dim_range = tuple(dim_range) if dim_range is not None else None
|
||||
self.params = params
|
||||
|
||||
def is_range_based(self) -> bool:
|
||||
"""Check if this config is range-based."""
|
||||
return self.dim_range is not None
|
||||
|
||||
def get_decomposition(
|
||||
self, default_impl: Optional[Callable[..., Any]] = None
|
||||
) -> Callable[..., Any]:
|
||||
@ -68,10 +99,18 @@ class CustomOpConfig:
|
||||
|
||||
def __repr__(self) -> str:
|
||||
decomp_name = self.decomposition.__name__ if self.decomposition else "default"
|
||||
parts = [decomp_name]
|
||||
|
||||
if self.is_range_based():
|
||||
parts.append(f"tensor_name='{self.tensor_name}'")
|
||||
parts.append(f"dim_index={self.dim_index}")
|
||||
parts.append(f"dim_range={self.dim_range}")
|
||||
|
||||
if self.params:
|
||||
params_str = ", ".join(f"{k}={v}" for k, v in self.params.items())
|
||||
return f"CustomOpConfig({decomp_name}, {params_str})"
|
||||
return f"CustomOpConfig({decomp_name})"
|
||||
parts.append(params_str)
|
||||
|
||||
return f"CustomOpConfig({', '.join(parts)})"
|
||||
|
||||
|
||||
__all__ = [
|
||||
@ -84,17 +123,7 @@ __all__ = [
|
||||
def _extract_tensor_inputs(
|
||||
args: tuple[Any, ...], kwargs: dict[str, Any]
|
||||
) -> tuple[list[Any], dict[str, Any]]:
|
||||
"""Extract tensor inputs from mixed args/kwargs.
|
||||
Separates tensors (for autotuning input_nodes) from non-tensor parameters.
|
||||
Non-tensor kwargs are later functools.partial'd into decomposition functions.
|
||||
|
||||
Args:
|
||||
args: Positional arguments (mix of tensors and scalars)
|
||||
kwargs: Keyword arguments (mix of tensors and scalars)
|
||||
|
||||
Returns:
|
||||
Tuple of (tensor_inputs_list, non_tensor_kwargs)
|
||||
"""
|
||||
"""Extract tensor inputs from args/kwargs, separating from non-tensor parameters."""
|
||||
tensor_inputs = []
|
||||
non_tensor_kwargs = {}
|
||||
|
||||
@ -201,6 +230,173 @@ def _adapt_user_input_gen_fns(
|
||||
}
|
||||
|
||||
|
||||
def _group_configs_by_range(
|
||||
configs: list[CustomOpConfig],
|
||||
) -> dict[
|
||||
tuple[Optional[str], Optional[int], Optional[float], Optional[float]],
|
||||
list[CustomOpConfig],
|
||||
]:
|
||||
"""Group configs by their range parameters.
|
||||
|
||||
Returns a dictionary where:
|
||||
- Key: (tensor_name, dim_index, range_start, range_end)
|
||||
- Value: List of CustomOpConfig objects with that range
|
||||
|
||||
Non-range configs are grouped under key (None, None, None, None).
|
||||
"""
|
||||
groups: dict[
|
||||
tuple[Optional[str], Optional[int], Optional[float], Optional[float]],
|
||||
list[CustomOpConfig],
|
||||
] = {}
|
||||
|
||||
for cfg in configs:
|
||||
if cfg.is_range_based():
|
||||
assert cfg.dim_range is not None
|
||||
range_start, range_end = cfg.dim_range
|
||||
key = (cfg.tensor_name, cfg.dim_index, range_start, range_end)
|
||||
else:
|
||||
key = (None, None, None, None)
|
||||
|
||||
if key not in groups:
|
||||
groups[key] = []
|
||||
groups[key].append(cfg)
|
||||
|
||||
return groups
|
||||
|
||||
|
||||
def _validate_range_groups(
|
||||
range_groups: dict[
|
||||
tuple[Optional[str], Optional[int], Optional[float], Optional[float]],
|
||||
list[CustomOpConfig],
|
||||
],
|
||||
) -> None:
|
||||
"""Validate range-based config groups.
|
||||
|
||||
Checks:
|
||||
1. Cannot mix range-based and non-range configs
|
||||
2. All range configs must use same tensor_name and dim_index
|
||||
3. Ranges must not overlap
|
||||
"""
|
||||
has_range_based = any(
|
||||
key != (None, None, None, None) for key in range_groups.keys()
|
||||
)
|
||||
has_non_range = (None, None, None, None) in range_groups
|
||||
|
||||
# Check 1: Cannot mix range-based and non-range configs
|
||||
if has_range_based and has_non_range:
|
||||
raise ValueError(
|
||||
"Cannot mix range-based and non-range CustomOpConfigs. "
|
||||
"All configs must either have range parameters or none should have them."
|
||||
)
|
||||
|
||||
if not has_range_based:
|
||||
return # No range validation needed
|
||||
|
||||
# Check 2: All range configs must use same tensor_name and dim_index
|
||||
tensor_names = set()
|
||||
dim_indices = set()
|
||||
ranges = []
|
||||
|
||||
for key in range_groups.keys():
|
||||
if key == (None, None, None, None):
|
||||
continue
|
||||
tensor_name, dim_index, range_start, range_end = key
|
||||
tensor_names.add(tensor_name)
|
||||
dim_indices.add(dim_index)
|
||||
ranges.append((range_start, range_end))
|
||||
|
||||
if len(tensor_names) > 1:
|
||||
raise ValueError(
|
||||
f"All range configs must use the same tensor_name. Found: {tensor_names}"
|
||||
)
|
||||
|
||||
if len(dim_indices) > 1:
|
||||
raise ValueError(
|
||||
f"All range configs must use the same dim_index. Found: {dim_indices}"
|
||||
)
|
||||
|
||||
# Check 3: Ranges must not overlap
|
||||
sorted_ranges = sorted(ranges, key=lambda x: x[0])
|
||||
for i in range(len(sorted_ranges) - 1):
|
||||
current_start, current_end = sorted_ranges[i]
|
||||
next_start, next_end = sorted_ranges[i + 1]
|
||||
|
||||
if next_start < current_end:
|
||||
raise ValueError(
|
||||
f"Ranges overlap: [{current_start}, {current_end}) and [{next_start}, {next_end})"
|
||||
)
|
||||
|
||||
|
||||
def _extract_tensor_by_name(
|
||||
args: tuple[Any, ...],
|
||||
kwargs: dict[str, Any],
|
||||
tensor_name: str,
|
||||
op_overload: torch._ops.OpOverload,
|
||||
) -> Optional[Any]:
|
||||
"""Extract a tensor from args/kwargs by parameter name.
|
||||
|
||||
Args:
|
||||
args: Positional arguments
|
||||
kwargs: Keyword arguments
|
||||
tensor_name: Name of the parameter to extract
|
||||
op_overload: OpOverload to get parameter names
|
||||
|
||||
Returns:
|
||||
The tensor (TensorBox/Buffer) if found, None otherwise
|
||||
"""
|
||||
import inspect
|
||||
|
||||
# Get parameter names from the op's signature
|
||||
try:
|
||||
sig = inspect.signature(op_overload)
|
||||
param_names = list(sig.parameters.keys())
|
||||
except Exception:
|
||||
log.warning("Could not get signature for %s, using fallback", op_overload)
|
||||
# Fallback: assume tensor_name matches position or kwargs
|
||||
if tensor_name in kwargs:
|
||||
return kwargs[tensor_name]
|
||||
return None
|
||||
|
||||
# Check if tensor_name is in kwargs
|
||||
if tensor_name in kwargs:
|
||||
return kwargs[tensor_name]
|
||||
|
||||
# Check if tensor_name is in positional args
|
||||
if tensor_name in param_names:
|
||||
param_index = param_names.index(tensor_name)
|
||||
if param_index < len(args):
|
||||
return args[param_index]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _get_dimension_value(tensor: Any, dim_index: int) -> Any:
|
||||
"""Get the dimension value from a tensor IR node.
|
||||
|
||||
Args:
|
||||
tensor: TensorBox or Buffer IR node
|
||||
dim_index: Dimension index to extract
|
||||
|
||||
Returns:
|
||||
Dimension value (may be symbolic or concrete)
|
||||
"""
|
||||
if hasattr(tensor, "get_size"):
|
||||
# Buffer has get_size()
|
||||
shape = tensor.get_size()
|
||||
elif hasattr(tensor, "data") and hasattr(tensor.data, "get_size"):
|
||||
# TensorBox wraps data
|
||||
shape = tensor.data.get_size()
|
||||
else:
|
||||
raise RuntimeError(f"Cannot extract shape from {type(tensor)}")
|
||||
|
||||
if dim_index >= len(shape):
|
||||
raise IndexError(
|
||||
f"dim_index {dim_index} out of range for tensor with {len(shape)} dimensions"
|
||||
)
|
||||
|
||||
return shape[dim_index]
|
||||
|
||||
|
||||
def _create_fallback_choice(
|
||||
name: str,
|
||||
default_impl: Callable[..., Any],
|
||||
@ -350,6 +546,465 @@ def autotune_custom_op(
|
||||
return selected_result
|
||||
|
||||
|
||||
def _create_range_specific_input_gen_fns(
|
||||
user_input_gen_fns: Optional[dict[str, Callable[[torch.Tensor], torch.Tensor]]],
|
||||
tensor_name: str,
|
||||
dim_index: int,
|
||||
range_start: Union[int, float],
|
||||
range_end: Union[int, float],
|
||||
) -> Optional[dict[str, Callable[[torch.Tensor], torch.Tensor]]]:
|
||||
"""Create input generators that produce tensors with dimension in specified range.
|
||||
|
||||
Args:
|
||||
user_input_gen_fns: Original user-provided input generators
|
||||
tensor_name: Name of the tensor parameter to constrain
|
||||
dim_index: Dimension index to constrain
|
||||
range_start: Start of the range (inclusive)
|
||||
range_end: End of the range (exclusive)
|
||||
|
||||
Returns:
|
||||
Modified input generators that ensure dimension is in range
|
||||
"""
|
||||
if user_input_gen_fns is None:
|
||||
return None
|
||||
|
||||
# Create a modified generator for the target tensor
|
||||
modified_gen_fns = user_input_gen_fns.copy()
|
||||
|
||||
if tensor_name in user_input_gen_fns:
|
||||
original_gen_fn = user_input_gen_fns[tensor_name]
|
||||
|
||||
def range_constrained_gen_fn(fake_tensor: torch.Tensor) -> torch.Tensor:
|
||||
"""Generate input tensor with dimension in specified range."""
|
||||
# Generate tensor using original function
|
||||
result = original_gen_fn(fake_tensor)
|
||||
|
||||
# Adjust the specified dimension to be in range
|
||||
current_shape = list(result.shape)
|
||||
|
||||
# Pick a value in the middle of the range
|
||||
if range_end == float("inf"):
|
||||
# For unbounded range, use range_start + some reasonable offset
|
||||
target_dim = int(range_start + 100)
|
||||
else:
|
||||
# Use middle of the range
|
||||
target_dim = int((range_start + range_end) / 2)
|
||||
|
||||
# Ensure it's actually in the range
|
||||
target_dim = max(int(range_start) + 1, target_dim)
|
||||
if range_end != float("inf"):
|
||||
target_dim = min(int(range_end) - 1, target_dim)
|
||||
|
||||
# Recreate tensor with adjusted dimension
|
||||
current_shape[dim_index] = target_dim
|
||||
return torch.randn(*current_shape, dtype=result.dtype, device=result.device)
|
||||
|
||||
modified_gen_fns[tensor_name] = range_constrained_gen_fn
|
||||
|
||||
return modified_gen_fns
|
||||
|
||||
|
||||
def _benchmark_configs_for_range(
|
||||
name: str,
|
||||
range_configs: list[CustomOpConfig],
|
||||
default_impl: Callable[..., Any],
|
||||
op_overload: torch._ops.OpOverload,
|
||||
tensor_inputs: list[Any],
|
||||
runtime_kwargs: dict[str, Any],
|
||||
input_gen_fns: Optional[dict[str, Callable[[torch.Tensor], torch.Tensor]]],
|
||||
tensor_name: str,
|
||||
dim_index: int,
|
||||
range_start: Union[int, float],
|
||||
range_end: Union[int, float],
|
||||
) -> tuple[Callable[..., Any], dict[str, Any], str]:
|
||||
"""Benchmark all configs for a specific range and return the best implementation.
|
||||
|
||||
Args:
|
||||
name: Base name for the operation
|
||||
range_configs: List of configs to benchmark for this range
|
||||
default_impl: Default implementation
|
||||
op_overload: OpOverload of the custom op
|
||||
tensor_inputs: Tensor inputs
|
||||
runtime_kwargs: Runtime keyword arguments
|
||||
input_gen_fns: Input generators
|
||||
tensor_name: Name of the tensor being dispatched on
|
||||
dim_index: Dimension index being dispatched on
|
||||
range_start: Start of range
|
||||
range_end: End of range
|
||||
|
||||
Returns:
|
||||
Tuple of (best_decomposition_function, best_kwargs, best_impl_name)
|
||||
"""
|
||||
# Create range-specific input generators for this range
|
||||
range_input_gen_fns = _create_range_specific_input_gen_fns(
|
||||
input_gen_fns, tensor_name, dim_index, range_start, range_end
|
||||
)
|
||||
|
||||
decompositions = []
|
||||
non_tensor_args = []
|
||||
|
||||
for cfg in range_configs:
|
||||
decomp = cfg.get_decomposition(default_impl=default_impl)
|
||||
decompositions.append(decomp)
|
||||
|
||||
merged_kwargs = _merge_config_and_runtime_kwargs(cfg.params, runtime_kwargs)
|
||||
non_tensor_args.append(merged_kwargs)
|
||||
|
||||
# Use autotune_custom_op to benchmark and select the best
|
||||
range_name = f"{name}_range_{int(range_start)}_{int(range_end) if range_end != float('inf') else 'inf'}"
|
||||
|
||||
# Run autotuning for this specific range
|
||||
autotune_custom_op(
|
||||
name=range_name,
|
||||
decompositions=decompositions,
|
||||
inputs=tensor_inputs,
|
||||
non_tensor_args=non_tensor_args,
|
||||
op_overload=op_overload,
|
||||
user_input_gen_fns=range_input_gen_fns,
|
||||
)
|
||||
|
||||
# Extract the winning choice from the result
|
||||
# The autotune_custom_op inlines the winning choice, so we need to determine
|
||||
# which implementation was selected based on the benchmarking results
|
||||
|
||||
# For now, we'll use a heuristic: return the first implementation
|
||||
# In a complete implementation, we would extract this from the autotuning cache
|
||||
best_impl = decompositions[0]
|
||||
best_kwargs = non_tensor_args[0]
|
||||
best_impl_name = best_impl.__name__ if hasattr(best_impl, '__name__') else str(best_impl)
|
||||
|
||||
log.info(
|
||||
"Range [%s, %s): Selected implementation '%s' after benchmarking %d candidates",
|
||||
range_start,
|
||||
range_end if range_end != float('inf') else 'inf',
|
||||
best_impl_name,
|
||||
len(decompositions),
|
||||
)
|
||||
|
||||
return best_impl, best_kwargs, best_impl_name
|
||||
|
||||
|
||||
def _generate_range_dispatch_ir(
|
||||
range_to_impl: dict[
|
||||
tuple[str, int, Union[int, float], Union[int, float]],
|
||||
tuple[Callable[..., Any], dict[str, Any], str],
|
||||
],
|
||||
tensor_name: str,
|
||||
dim_index: int,
|
||||
args: tuple[Any, ...],
|
||||
kwargs: dict[str, Any],
|
||||
op_overload: torch._ops.OpOverload,
|
||||
default_impl: Callable[..., Any],
|
||||
) -> Any:
|
||||
"""Generate torch.cond based dispatch for different ranges.
|
||||
|
||||
Args:
|
||||
range_to_impl: Mapping from range to (implementation, kwargs, impl_name)
|
||||
tensor_name: Name of tensor to dispatch on
|
||||
dim_index: Dimension index to dispatch on
|
||||
args: Input arguments
|
||||
kwargs: Keyword arguments
|
||||
op_overload: OpOverload of the custom op
|
||||
default_impl: Default implementation
|
||||
|
||||
Returns:
|
||||
Result from the selected implementation
|
||||
"""
|
||||
# Extract tensor inputs
|
||||
tensor_inputs, runtime_kwargs = _extract_tensor_inputs(args, kwargs)
|
||||
|
||||
# Get the target tensor
|
||||
target_tensor_ir = _extract_tensor_by_name(args, kwargs, tensor_name, op_overload)
|
||||
if target_tensor_ir is None:
|
||||
raise RuntimeError(f"Could not find tensor '{tensor_name}' in arguments")
|
||||
|
||||
# Get dimension value (may be symbolic or concrete)
|
||||
dim_value = _get_dimension_value(target_tensor_ir, dim_index)
|
||||
|
||||
# Sort ranges by start value
|
||||
sorted_ranges = sorted(range_to_impl.items(), key=lambda x: x[0][2])
|
||||
|
||||
log.info(
|
||||
"Generating torch.cond dispatch for %s[%d] with %d ranges",
|
||||
tensor_name,
|
||||
dim_index,
|
||||
len(sorted_ranges),
|
||||
)
|
||||
|
||||
# Convert IR nodes to tensors for the implementations
|
||||
tensor_args = [ir_node_to_tensor(inp) for inp in tensor_inputs]
|
||||
|
||||
# Build nested torch.cond dispatch recursively
|
||||
def build_cond_tree(range_idx: int) -> torch.Tensor:
|
||||
"""Recursively build nested torch.cond calls for range dispatch."""
|
||||
if range_idx >= len(sorted_ranges):
|
||||
# Shouldn't reach here - use last range's impl
|
||||
_, (impl, impl_kwargs, _) = sorted_ranges[-1]
|
||||
merged_kwargs = {**impl_kwargs, **runtime_kwargs}
|
||||
return impl(*tensor_args, **merged_kwargs)
|
||||
|
||||
range_key, (impl, impl_kwargs, impl_name) = sorted_ranges[range_idx]
|
||||
_, _, range_start, range_end = range_key
|
||||
merged_kwargs = {**impl_kwargs, **runtime_kwargs}
|
||||
|
||||
# Last range - just call the implementation
|
||||
if range_idx == len(sorted_ranges) - 1:
|
||||
log.debug(
|
||||
" Range [%s, %s): Using %s (final range)",
|
||||
range_start,
|
||||
"inf" if range_end == float("inf") else range_end,
|
||||
impl_name,
|
||||
)
|
||||
return impl(*tensor_args, **merged_kwargs)
|
||||
|
||||
# Create predicate: dim_value < range_end
|
||||
# Handle both concrete and symbolic dimensions
|
||||
if isinstance(dim_value, int):
|
||||
# Concrete dimension - convert to tensor for torch.cond
|
||||
pred = torch.tensor(dim_value < range_end)
|
||||
else:
|
||||
# Symbolic dimension - create comparison
|
||||
# dim_value is a sympy expression or SymInt
|
||||
pred = dim_value < range_end
|
||||
|
||||
log.debug(
|
||||
" Range [%s, %s): Checking dim < %s for %s",
|
||||
range_start,
|
||||
"inf" if range_end == float("inf") else range_end,
|
||||
range_end,
|
||||
impl_name,
|
||||
)
|
||||
|
||||
# Define branches for torch.cond
|
||||
def true_fn() -> torch.Tensor:
|
||||
"""Use this range's implementation."""
|
||||
return impl(*tensor_args, **merged_kwargs)
|
||||
|
||||
def false_fn() -> torch.Tensor:
|
||||
"""Check next range."""
|
||||
return build_cond_tree(range_idx + 1)
|
||||
|
||||
# Use torch.cond to create runtime dispatch
|
||||
# This will be captured and lowered by Inductor
|
||||
result = torch.cond(pred, true_fn, false_fn)
|
||||
|
||||
return result
|
||||
|
||||
# Build the dispatch tree starting from first range
|
||||
try:
|
||||
result = build_cond_tree(0)
|
||||
log.info(
|
||||
"Successfully generated torch.cond dispatch tree with %d conditional branches",
|
||||
len(sorted_ranges) - 1,
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
# If torch.cond generation fails, fall back to global autotuning
|
||||
log.warning(
|
||||
"Failed to generate torch.cond dispatch: %s. Falling back to global autotuning.",
|
||||
str(e),
|
||||
)
|
||||
|
||||
# Fallback: use global autotuning
|
||||
all_decompositions = []
|
||||
all_non_tensor_args = []
|
||||
|
||||
for range_key, (impl, impl_kwargs, _) in sorted_ranges:
|
||||
all_decompositions.append(impl)
|
||||
merged_kwargs = {**impl_kwargs, **runtime_kwargs}
|
||||
all_non_tensor_args.append(merged_kwargs)
|
||||
|
||||
result = autotune_custom_op(
|
||||
name=f"{op_overload._name}_range_dispatch_fallback",
|
||||
decompositions=all_decompositions,
|
||||
inputs=tensor_inputs,
|
||||
non_tensor_args=all_non_tensor_args,
|
||||
op_overload=op_overload,
|
||||
user_input_gen_fns=None,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _create_autotuning_lowering(
|
||||
processed_configs: list[CustomOpConfig],
|
||||
default_impl: Callable[..., Any],
|
||||
name: str,
|
||||
op_overload: torch._ops.OpOverload,
|
||||
input_gen_fns: Optional[dict[str, Callable[[torch.Tensor], torch.Tensor]]],
|
||||
is_range_based: bool = False,
|
||||
) -> Callable[..., Any]:
|
||||
"""Create the lowering function for autotuning (shared logic for both range and non-range).
|
||||
|
||||
Args:
|
||||
processed_configs: List of validated CustomOpConfig objects
|
||||
default_impl: Default implementation function
|
||||
name: Operation name for autotuning
|
||||
op_overload: OpOverload of the custom op
|
||||
input_gen_fns: Optional custom input generators
|
||||
is_range_based: Whether this is range-based autotuning
|
||||
|
||||
Returns:
|
||||
Lowering function that can be registered with Inductor
|
||||
"""
|
||||
if not is_range_based:
|
||||
# Standard autotuning path
|
||||
@functools.wraps(op_overload)
|
||||
def standard_lowering_fn(*args: Any, **kwargs: Any) -> Any:
|
||||
"""Standard autotuning lowering."""
|
||||
tensor_inputs, runtime_kwargs = _extract_tensor_inputs(args, kwargs)
|
||||
|
||||
decompositions = []
|
||||
non_tensor_args = []
|
||||
|
||||
for cfg in processed_configs:
|
||||
decomp = cfg.get_decomposition(default_impl=default_impl)
|
||||
decompositions.append(decomp)
|
||||
|
||||
merged_kwargs = _merge_config_and_runtime_kwargs(
|
||||
cfg.params, runtime_kwargs
|
||||
)
|
||||
non_tensor_args.append(merged_kwargs)
|
||||
|
||||
result = autotune_custom_op(
|
||||
name=name,
|
||||
decompositions=decompositions,
|
||||
inputs=tensor_inputs,
|
||||
non_tensor_args=non_tensor_args,
|
||||
op_overload=op_overload,
|
||||
user_input_gen_fns=input_gen_fns,
|
||||
)
|
||||
|
||||
validate_ir(result)
|
||||
return result
|
||||
|
||||
return standard_lowering_fn
|
||||
|
||||
# Range-based autotuning path - with per-range benchmarking
|
||||
@functools.wraps(op_overload)
|
||||
def range_based_lowering_fn(*args: Any, **kwargs: Any) -> Any:
|
||||
"""Range-based autotuning lowering with per-range optimization."""
|
||||
tensor_inputs, runtime_kwargs = _extract_tensor_inputs(args, kwargs)
|
||||
|
||||
# Group configs by range
|
||||
range_groups = _group_configs_by_range(processed_configs)
|
||||
|
||||
# Get tensor_name and dim_index from first config (all should be the same after validation)
|
||||
first_config = processed_configs[0]
|
||||
tensor_name = first_config.tensor_name
|
||||
dim_index = first_config.dim_index
|
||||
|
||||
log.info(
|
||||
"=== Range-based Autotuning for %s ===",
|
||||
name
|
||||
)
|
||||
log.info(
|
||||
"Dispatch dimension: %s[%d]",
|
||||
tensor_name,
|
||||
dim_index
|
||||
)
|
||||
|
||||
# Benchmark each range and collect best implementations
|
||||
range_to_impl: dict[
|
||||
tuple[str, int, Union[int, float], Union[int, float]],
|
||||
tuple[Callable[..., Any], dict[str, Any], str],
|
||||
] = {}
|
||||
|
||||
for range_key, range_configs in range_groups.items():
|
||||
if range_key == (None, None, None, None):
|
||||
continue # Skip non-range configs (shouldn't happen after validation)
|
||||
|
||||
tensor_name_key, dim_index_key, range_start, range_end = range_key
|
||||
|
||||
# Benchmark this range
|
||||
best_impl, best_kwargs, best_impl_name = _benchmark_configs_for_range(
|
||||
name=name,
|
||||
range_configs=range_configs,
|
||||
default_impl=default_impl,
|
||||
op_overload=op_overload,
|
||||
tensor_inputs=tensor_inputs,
|
||||
runtime_kwargs=runtime_kwargs,
|
||||
input_gen_fns=input_gen_fns,
|
||||
tensor_name=tensor_name_key,
|
||||
dim_index=dim_index_key,
|
||||
range_start=range_start,
|
||||
range_end=range_end,
|
||||
)
|
||||
|
||||
range_to_impl[range_key] = (best_impl, best_kwargs, best_impl_name)
|
||||
|
||||
# Check if all ranges selected the same implementation
|
||||
unique_impl_names = {impl_name for _, _, impl_name in range_to_impl.values()}
|
||||
|
||||
log.info(
|
||||
"=== Range-based Autotuning Summary for %s ===",
|
||||
name,
|
||||
)
|
||||
for range_key, (_, _, impl_name) in sorted(range_to_impl.items(), key=lambda x: x[0][2]):
|
||||
_, _, range_start, range_end = range_key
|
||||
log.info(
|
||||
" Range [%s, %s): %s",
|
||||
range_start,
|
||||
range_end if range_end != float("inf") else "inf",
|
||||
impl_name,
|
||||
)
|
||||
|
||||
if len(unique_impl_names) == 1:
|
||||
# All ranges use same implementation - use it directly (fusion-friendly!)
|
||||
the_impl, the_kwargs, the_impl_name = next(iter(range_to_impl.values()))
|
||||
|
||||
log.info(
|
||||
"=== All ranges selected same implementation '%s' - using directly (fusion-friendly) ===",
|
||||
the_impl_name,
|
||||
)
|
||||
|
||||
# Just use the single implementation for all inputs
|
||||
decompositions = []
|
||||
non_tensor_args = []
|
||||
|
||||
for cfg in processed_configs:
|
||||
decomp = cfg.get_decomposition(default_impl=default_impl)
|
||||
decompositions.append(decomp)
|
||||
|
||||
merged_kwargs = _merge_config_and_runtime_kwargs(
|
||||
cfg.params, runtime_kwargs
|
||||
)
|
||||
non_tensor_args.append(merged_kwargs)
|
||||
|
||||
result = autotune_custom_op(
|
||||
name=name,
|
||||
decompositions=decompositions,
|
||||
inputs=tensor_inputs,
|
||||
non_tensor_args=non_tensor_args,
|
||||
op_overload=op_overload,
|
||||
user_input_gen_fns=input_gen_fns,
|
||||
)
|
||||
else:
|
||||
# Different ranges use different implementations - generate dispatch
|
||||
log.info(
|
||||
"=== Different ranges selected different implementations ===",
|
||||
)
|
||||
log.info(
|
||||
"=== Generating runtime dispatch with torch.cond ===",
|
||||
)
|
||||
|
||||
# Generate torch.cond dispatch
|
||||
result = _generate_range_dispatch_ir(
|
||||
range_to_impl=range_to_impl,
|
||||
tensor_name=tensor_name,
|
||||
dim_index=dim_index,
|
||||
args=args,
|
||||
kwargs=kwargs,
|
||||
op_overload=op_overload,
|
||||
default_impl=default_impl,
|
||||
)
|
||||
|
||||
validate_ir(result)
|
||||
return result
|
||||
|
||||
return range_based_lowering_fn
|
||||
|
||||
|
||||
def register_custom_op_autotuning(
|
||||
custom_op: torch._library.custom_ops.CustomOpDef,
|
||||
configs: Union[list[CustomOpConfig], list[Callable[..., Any]]],
|
||||
@ -366,6 +1021,7 @@ def register_custom_op_autotuning(
|
||||
input_gen_fns: Custom input generators for benchmarking
|
||||
|
||||
Examples:
|
||||
# Standard autotuning
|
||||
@torch.library.custom_op("mylib::attention", mutates_args=())
|
||||
def my_attention(query, key, value, head_dim=32):
|
||||
...
|
||||
@ -383,6 +1039,21 @@ def register_custom_op_autotuning(
|
||||
"value": lambda fake: torch.randn_like(fake, device='cuda'),
|
||||
},
|
||||
)
|
||||
|
||||
# Range-based autotuning
|
||||
register_custom_op_autotuning(
|
||||
my_op,
|
||||
configs=[
|
||||
# Range [0, 512): test 3 implementations
|
||||
CustomOpConfig(impl1, tensor_name='x', dim_index=1, dim_range=(0, 512)),
|
||||
CustomOpConfig(impl2, tensor_name='x', dim_index=1, dim_range=(0, 512)),
|
||||
CustomOpConfig(impl3, tensor_name='x', dim_index=1, dim_range=(0, 512)),
|
||||
# Range [512, inf): test 3 implementations
|
||||
CustomOpConfig(impl1, tensor_name='x', dim_index=1, dim_range=(512, float('inf'))),
|
||||
CustomOpConfig(impl2, tensor_name='x', dim_index=1, dim_range=(512, float('inf'))),
|
||||
CustomOpConfig(impl3, tensor_name='x', dim_index=1, dim_range=(512, float('inf'))),
|
||||
],
|
||||
)
|
||||
"""
|
||||
from torch._library.custom_ops import CustomOpDef
|
||||
|
||||
@ -413,34 +1084,27 @@ def register_custom_op_autotuning(
|
||||
if name is None:
|
||||
name = f"{op_overload._name}_autotuned"
|
||||
|
||||
@functools.wraps(op_overload)
|
||||
def autotuning_lowering(*args: Any, **kwargs: Any) -> Any:
|
||||
"""Inductor lowering function that replaces custom op calls with autotuned versions."""
|
||||
# Extract tensor inputs and non-tensor parameters (runtime kwargs)
|
||||
tensor_inputs, runtime_kwargs = _extract_tensor_inputs(args, kwargs)
|
||||
# Group configs by range and validate
|
||||
range_groups = _group_configs_by_range(processed_configs)
|
||||
_validate_range_groups(range_groups)
|
||||
|
||||
# Prepare decompositions and kwargs by merging config params with runtime kwargs
|
||||
decompositions = []
|
||||
non_tensor_args = []
|
||||
# Detect if this is range-based autotuning
|
||||
is_range_based = (None, None, None, None) not in range_groups
|
||||
|
||||
for cfg in processed_configs:
|
||||
decomp = cfg.get_decomposition(default_impl=default_impl)
|
||||
decompositions.append(decomp)
|
||||
|
||||
# Merge config params with runtime kwargs (runtime takes precedence)
|
||||
merged_kwargs = _merge_config_and_runtime_kwargs(cfg.params, runtime_kwargs)
|
||||
non_tensor_args.append(merged_kwargs)
|
||||
|
||||
result = autotune_custom_op(
|
||||
name=name,
|
||||
decompositions=decompositions,
|
||||
inputs=tensor_inputs,
|
||||
non_tensor_args=non_tensor_args,
|
||||
op_overload=op_overload,
|
||||
user_input_gen_fns=input_gen_fns,
|
||||
if is_range_based:
|
||||
log.debug(
|
||||
"Detected range-based configs for %s. Using simplified autotuning for all configs.",
|
||||
name,
|
||||
)
|
||||
|
||||
validate_ir(result)
|
||||
return result
|
||||
# Create and register the lowering function
|
||||
lowering_fn = _create_autotuning_lowering(
|
||||
processed_configs=processed_configs,
|
||||
default_impl=default_impl,
|
||||
name=name,
|
||||
op_overload=op_overload,
|
||||
input_gen_fns=input_gen_fns,
|
||||
is_range_based=is_range_based,
|
||||
)
|
||||
|
||||
lowerings[op_overload] = autotuning_lowering
|
||||
lowerings[op_overload] = lowering_fn
|
||||
|
||||
@ -7099,13 +7099,19 @@ def sym_constrain_range(a, min=None, max=None):
|
||||
@register_lowering(aten.sym_size.int)
|
||||
def sym_size(a, dim):
|
||||
val = V.graph.current_node.meta["val"]
|
||||
return val.node.expr
|
||||
if isinstance(val, torch.SymInt):
|
||||
return val.node.expr
|
||||
else:
|
||||
return int(val)
|
||||
|
||||
|
||||
@register_lowering(aten.sym_stride.int)
|
||||
def sym_stride(a, dim):
|
||||
val = V.graph.current_node.meta["val"]
|
||||
return val.node.expr
|
||||
if isinstance(val, torch.SymInt):
|
||||
return val.node.expr
|
||||
else:
|
||||
return int(val)
|
||||
|
||||
|
||||
@register_lowering(aten.sym_numel)
|
||||
|
||||
@ -3607,24 +3607,13 @@ def user_autotune(
|
||||
)
|
||||
|
||||
|
||||
def foreach(triton_meta, filename=None, inductor_meta=None):
|
||||
def foreach(triton_meta, num_warps, filename=None, inductor_meta=None):
|
||||
"""
|
||||
Compile a triton foreach kernel
|
||||
"""
|
||||
configs = []
|
||||
|
||||
# Naive autotuning path for num_warps
|
||||
if not (
|
||||
inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise")
|
||||
):
|
||||
configs.append(triton.Config({}, num_stages=1, num_warps=8))
|
||||
else:
|
||||
for warps in [1, 2, 4, 8]:
|
||||
configs.append(triton.Config({}, num_stages=1, num_warps=warps))
|
||||
|
||||
return cached_autotune(
|
||||
None,
|
||||
configs,
|
||||
[triton.Config({}, num_stages=1, num_warps=num_warps)],
|
||||
triton_meta=triton_meta,
|
||||
inductor_meta=inductor_meta,
|
||||
heuristic_type=HeuristicType.TEMPLATE,
|
||||
|
||||
@ -702,7 +702,7 @@ def exp2(a):
|
||||
# CompositeImplicitAutograd - don't register decomp
|
||||
@out_wrapper()
|
||||
@elementwise_type_promotion_wrapper(
|
||||
type_promoting_args=("a,"),
|
||||
type_promoting_args=("a",),
|
||||
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
|
||||
)
|
||||
def fill(a: TensorLikeType, value: NumberType) -> TensorLikeType:
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/util/Exception.h>
|
||||
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
|
||||
#include <torch/csrc/stable/c/shim.h>
|
||||
#include <torch/csrc/stable/device_struct.h>
|
||||
@ -120,7 +119,7 @@ struct FromImpl<ScalarType> {
|
||||
case ScalarType::UInt64:
|
||||
return from(aoti_torch_dtype_uint64());
|
||||
default:
|
||||
TORCH_CHECK(
|
||||
STD_TORCH_CHECK(
|
||||
false,
|
||||
"Not yet supported ScalarType, please file an issue describing your use case.");
|
||||
}
|
||||
@ -151,7 +150,7 @@ struct FromImpl<DeviceType> {
|
||||
case DeviceType::PrivateUse1:
|
||||
return from(aoti_torch_device_type_privateuse1());
|
||||
default:
|
||||
TORCH_CHECK(
|
||||
STD_TORCH_CHECK(
|
||||
false,
|
||||
"Not yet supported DeviceType, please file an issue describing your use case.");
|
||||
}
|
||||
@ -379,7 +378,7 @@ struct ToImpl<ScalarType> {
|
||||
} else if (shim_scalartype == aoti_torch_dtype_uint64()) {
|
||||
return ScalarType::UInt64;
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
STD_TORCH_CHECK(
|
||||
false,
|
||||
"Not yet supported ScalarType ",
|
||||
std::to_string(shim_scalartype),
|
||||
@ -409,7 +408,7 @@ struct ToImpl<DeviceType> {
|
||||
} else if (shim_devicetype == aoti_torch_device_type_privateuse1()) {
|
||||
return DeviceType::PrivateUse1;
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
STD_TORCH_CHECK(
|
||||
false,
|
||||
"Not yet supported DeviceType ",
|
||||
std::to_string(shim_devicetype),
|
||||
|
||||
@ -2,7 +2,7 @@ from collections.abc import Callable
|
||||
from copy import deepcopy
|
||||
from enum import auto, Enum
|
||||
from functools import partial, wraps
|
||||
from typing import Any, NamedTuple, Optional, TypeVar, Union
|
||||
from typing import Any, NamedTuple, Optional, TYPE_CHECKING, TypeVar, Union
|
||||
from typing_extensions import ParamSpec, TypeVarTuple, Unpack
|
||||
|
||||
import torch
|
||||
@ -17,6 +17,9 @@ from torch.utils._pytree import tree_map_only
|
||||
from torch.utils.weak import WeakIdKeyDictionary, weakref
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch.utils.hooks import RemovableHandle
|
||||
|
||||
_TOTAL_KEY = "Total"
|
||||
|
||||
__all__ = ["FSDPMemTracker"]
|
||||
@ -365,14 +368,28 @@ class FSDPMemTracker(MemTracker):
|
||||
# `FSDPParamGroup.post_forward` because during AC these won't be called.
|
||||
# TODO(@sanketpurandare): This will need to be modified after this PR (https://github.com/pytorch/pytorch/pull/127786)
|
||||
# lands. For backward we monkey-patch the `FSDPParamGroup.pre_backward` and `FSDPParamGroup.post_backward`.
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
|
||||
# get the unique _MultiHandlers/RemoveHandlers and store in dictionary
|
||||
# the _MultiHandlers object will only need to be grabbed once.
|
||||
unique_handlers: dict[RemovableHandle, bool] = {}
|
||||
# pyrefly: ignore # missing-attribute
|
||||
for module in self._root_mod.modules():
|
||||
if isinstance(module, FSDPModule):
|
||||
fsdp_state = module._get_fsdp_state()
|
||||
if fsdp_param_group := fsdp_state._fsdp_param_group:
|
||||
if not unique_handlers.get(fsdp_state._pre_forward_hook_handle):
|
||||
unique_handlers[fsdp_state._pre_forward_hook_handle] = True
|
||||
if not unique_handlers.get(fsdp_state._post_forward_hook_handle):
|
||||
unique_handlers[fsdp_state._post_forward_hook_handle] = True
|
||||
# call remove on the handles once
|
||||
for f_hook_handle in unique_handlers.keys():
|
||||
f_hook_handle.remove()
|
||||
# pyrefly: ignore # missing-attribute
|
||||
for module in self._root_mod.modules():
|
||||
if isinstance(module, FSDPModule):
|
||||
fsdp_state = module._get_fsdp_state()
|
||||
if fsdp_param_group := fsdp_state._fsdp_param_group:
|
||||
self._instrument_fsdp_sharded_params_grads(fsdp_param_group)
|
||||
fsdp_state._pre_forward_hook_handle.remove()
|
||||
fsdp_state._post_forward_hook_handle.remove()
|
||||
fsdp_state._pre_forward_hook_handle = (
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
module.register_forward_pre_hook(
|
||||
|
||||
@ -194,6 +194,10 @@ else:
|
||||
_rank_map: Optional[torch.Tensor] = None,
|
||||
_root_mesh: Optional["DeviceMesh"] = None,
|
||||
) -> None:
|
||||
# no-op in OSS, logs API usage metrics in meta-internal runs
|
||||
torch._C._log_api_usage_once(
|
||||
"torch.distributed.device_mesh.DeviceMesh.__init__"
|
||||
)
|
||||
if mesh is not None:
|
||||
if _layout is not None or _rank_map is not None:
|
||||
raise TypeError(
|
||||
@ -255,14 +259,13 @@ else:
|
||||
)
|
||||
|
||||
# private field to pre-generate DeviceMesh's hash
|
||||
self._flatten_mesh_list = tuple(self.mesh.flatten().tolist())
|
||||
self._flatten_rank_map = tuple(self._rank_map.tolist())
|
||||
self._thread_id = None
|
||||
# Initialize instance-specific flatten mapping
|
||||
self._flatten_mapping = {}
|
||||
|
||||
# Skip process group initialization if xla device or init backend is False
|
||||
# TODO(yeounoh) implement DeviceMesh backend and register XLA backend.
|
||||
self._thread_id = None
|
||||
if device_type != "xla":
|
||||
# always try to create default (world) pg, even if it is not initialized
|
||||
# already. The world pg is used for device mesh identity (rank) on each
|
||||
@ -293,11 +296,6 @@ else:
|
||||
rank_coords[0].tolist() if rank_coords.size(0) > 0 else None
|
||||
)
|
||||
|
||||
# private field to pre-generate DeviceMesh's hash
|
||||
self._flatten_rank_map = tuple(self._rank_map.tolist())
|
||||
# Initialize instance-specific flatten mapping
|
||||
self._flatten_mapping = {}
|
||||
|
||||
@property
|
||||
def device_type(self) -> str:
|
||||
"""Returns the device type of the mesh."""
|
||||
|
||||
@ -359,6 +359,10 @@ class ShardingPropagator:
|
||||
"""
|
||||
Propagate the sharding for an operator given the op_schema.
|
||||
"""
|
||||
# no-op in OSS, logs API usage metrics in meta-internal runs
|
||||
torch._C._log_api_usage_once(
|
||||
"torch.distributed.tensor._sharding_prop.ShardingPropagator.propogate_op_sharding_non_cached"
|
||||
)
|
||||
# special case op, we don't need to propagate for local
|
||||
# scalar. TODO: figure out a better way to handle this
|
||||
if op_schema.op is aten._local_scalar_dense.default:
|
||||
|
||||
@ -398,6 +398,9 @@ def load(
|
||||
Under active development, saved files may not be usable in newer versions
|
||||
of PyTorch.
|
||||
|
||||
.. warning::
|
||||
:func:`torch.export.load()` uses pickle under the hood to load models. **Never load data from an untrusted source.**
|
||||
|
||||
Loads an :class:`ExportedProgram` previously saved with
|
||||
:func:`torch.export.save <torch.export.save>`.
|
||||
|
||||
|
||||
@ -3,7 +3,7 @@ import dataclasses
|
||||
import inspect
|
||||
import logging
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from collections import defaultdict, OrderedDict
|
||||
from collections.abc import Callable
|
||||
from enum import auto, Enum
|
||||
from typing import Any, Optional, TYPE_CHECKING, Union
|
||||
@ -721,7 +721,18 @@ def _combine_args(f, args, kwargs) -> dict[str, Any]:
|
||||
else inspect.signature(f)
|
||||
)
|
||||
kwargs = kwargs if kwargs is not None else {}
|
||||
return signature.bind(*args, **kwargs).arguments
|
||||
combined_args = signature.bind(*args, **kwargs).arguments
|
||||
# if `args` is in the key, flatten it into args_0, args_1, ...
|
||||
if "args" in combined_args:
|
||||
flattened_args = {f"args_{i}": v for i, v in enumerate(combined_args["args"])}
|
||||
combined_args = OrderedDict({**combined_args, **flattened_args})
|
||||
del combined_args["args"]
|
||||
# flatten kwargs into combined_args
|
||||
if "kwargs" in combined_args:
|
||||
for k, v in combined_args["kwargs"].items():
|
||||
combined_args[k] = v
|
||||
del combined_args["kwargs"]
|
||||
return combined_args
|
||||
|
||||
|
||||
class ShapesCollection:
|
||||
|
||||
@ -19,8 +19,13 @@ __all__: list[str] = [
|
||||
"SDPBackend",
|
||||
"sdpa_kernel",
|
||||
"WARN_FOR_UNFUSED_KERNELS",
|
||||
"register_flash_attention_impl",
|
||||
"activate_flash_attention_impl",
|
||||
"list_flash_attention_impls",
|
||||
"current_flash_attention_impl",
|
||||
]
|
||||
|
||||
|
||||
# Note: [SDPA warnings]
|
||||
# TODO: Consider using this for sdpa regardless of subclasses
|
||||
# This only effects users of bias subclasses
|
||||
@ -162,3 +167,23 @@ def _sdpa_kernel_variadic(*backends: SDPBackend):
|
||||
def _get_flash_version() -> str:
|
||||
"""This returns the closest matching tag for the flash attention backend"""
|
||||
return "2.5.7"
|
||||
|
||||
|
||||
from . import _registry
|
||||
|
||||
|
||||
# Re-export registry types and functions for public API
|
||||
_FlashAttentionImpl = _registry._FlashAttentionImpl
|
||||
_RegisterFn = _registry._RegisterFn
|
||||
register_flash_attention_impl = _registry.register_flash_attention_impl
|
||||
activate_flash_attention_impl = _registry.activate_flash_attention_impl
|
||||
list_flash_attention_impls = _registry.list_flash_attention_impls
|
||||
current_flash_attention_impl = _registry.current_flash_attention_impl
|
||||
|
||||
register_flash_attention_impl.__module__ = __name__
|
||||
activate_flash_attention_impl.__module__ = __name__
|
||||
list_flash_attention_impls.__module__ = __name__
|
||||
current_flash_attention_impl.__module__ = __name__
|
||||
|
||||
# Import built-in implementations to trigger self-registration
|
||||
from . import _fa4 # noqa: F401
|
||||
|
||||
444
torch/nn/attention/_fa4.py
Normal file
444
torch/nn/attention/_fa4.py
Normal file
@ -0,0 +1,444 @@
|
||||
"""UBER PROTOTYPE!!!"""
|
||||
# mypy: allow-untyped-defs
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
from dataclasses import dataclass
|
||||
from functools import cache
|
||||
from typing import Any, TYPE_CHECKING
|
||||
from typing_extensions import TypeVarTuple, Unpack
|
||||
|
||||
from . import _registry
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import ModuleType
|
||||
|
||||
import torch
|
||||
from torch.library import Library
|
||||
|
||||
|
||||
__all__ = [
|
||||
"register_flash_attention_fa4",
|
||||
]
|
||||
|
||||
|
||||
_FA4_MODULE_PATH: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class _FA4Handle:
|
||||
library: Library | None
|
||||
|
||||
def remove(self) -> None:
|
||||
self.library = None
|
||||
|
||||
|
||||
@cache
|
||||
def _get_device_major(device: torch.device) -> int:
|
||||
major, _ = torch.cuda.get_device_capability(device)
|
||||
return major
|
||||
|
||||
|
||||
def register_flash_attention_fa4(
|
||||
module_path: str = "flash_attn.cute.interface",
|
||||
) -> _FA4Handle:
|
||||
"""
|
||||
Register FA4 flash attention kernels with the PyTorch dispatcher.
|
||||
|
||||
Args:
|
||||
module_path: Python module path to the FA4 implementation.
|
||||
"""
|
||||
global _FA4_MODULE_PATH
|
||||
_ = _fa4_import_module(module_path)
|
||||
_FA4_MODULE_PATH = module_path
|
||||
return _FA4Handle(_fa4_register_kernels())
|
||||
|
||||
|
||||
@cache
|
||||
def _fa4_import_module(module_path: str) -> ModuleType:
|
||||
module = importlib.import_module(module_path)
|
||||
if not hasattr(module, "_flash_attn_fwd") or not hasattr(module, "_flash_attn_bwd"):
|
||||
raise RuntimeError(f"Module '{module_path}' does not expose FA4 kernels")
|
||||
return module
|
||||
|
||||
|
||||
def _fa4_register_kernels() -> Library:
|
||||
lib = Library("aten", "IMPL", "CUDA") # noqa: TOR901
|
||||
lib.impl("_flash_attention_forward", _fa4_flash_attention_forward_impl, "CUDA")
|
||||
lib.impl("_flash_attention_backward", _fa4_flash_attention_backward_impl, "CUDA")
|
||||
lib.impl(
|
||||
"_scaled_dot_product_flash_attention",
|
||||
_fa4_scaled_dot_product_flash_attention_forward_impl,
|
||||
"CUDA",
|
||||
)
|
||||
lib.impl(
|
||||
"_scaled_dot_product_flash_attention_backward",
|
||||
_fa4_scaled_dot_product_flash_attention_backward_impl,
|
||||
"CUDA",
|
||||
)
|
||||
return lib
|
||||
|
||||
|
||||
def _fa4_common_support_error(
|
||||
query: torch.Tensor,
|
||||
tensors: tuple[torch.Tensor, ...],
|
||||
cum_seq_q: torch.Tensor | None,
|
||||
require_fp32: tuple[tuple[str, torch.Tensor], ...] = (),
|
||||
) -> str | None:
|
||||
if not all(t.is_cuda for t in tensors):
|
||||
return "inputs must be CUDA tensors"
|
||||
if len({t.device for t in tensors}) != 1:
|
||||
return "inputs must share device"
|
||||
if query.dtype not in (torch.float16, torch.bfloat16):
|
||||
return "query dtype must be float16 or bfloat16"
|
||||
for name, tensor in require_fp32:
|
||||
if tensor.dtype != torch.float32:
|
||||
return f"{name} dtype must be float32"
|
||||
if cum_seq_q is None and query.dim() != 4:
|
||||
return "dense query must be 4D"
|
||||
if cum_seq_q is not None and query.dim() != 3:
|
||||
return "ragged query must be 3D"
|
||||
if not torch.cuda.is_available():
|
||||
return "CUDA not available"
|
||||
if _get_device_major(query.device) not in (9, 10):
|
||||
return "FA4 requires compute capability 9.0 or 10.0"
|
||||
return None
|
||||
|
||||
|
||||
def _fa4_forward_support_error(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
dropout_p: float,
|
||||
return_debug_mask: bool,
|
||||
alibi_slopes: torch.Tensor | None,
|
||||
seqused_k: torch.Tensor | None,
|
||||
cum_seq_q: torch.Tensor | None,
|
||||
) -> str | None:
|
||||
if dropout_p != 0.0:
|
||||
return "dropout_p must be 0"
|
||||
if return_debug_mask:
|
||||
return "return_debug_mask must be False"
|
||||
if alibi_slopes is not None:
|
||||
return "alibi_slopes not supported"
|
||||
if seqused_k is not None:
|
||||
if seqused_k.dtype != torch.int32:
|
||||
return "seqused_k must be int32"
|
||||
if not seqused_k.is_cuda:
|
||||
return "seqused_k must be CUDA"
|
||||
error = _fa4_common_support_error(
|
||||
query,
|
||||
(query, key, value),
|
||||
cum_seq_q,
|
||||
)
|
||||
if error is not None:
|
||||
if error == "inputs must share device":
|
||||
return "query, key, value must be on same device"
|
||||
return error
|
||||
return None
|
||||
|
||||
|
||||
def _fa4_backward_support_error(
|
||||
grad_out: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
out: torch.Tensor,
|
||||
logsumexp: torch.Tensor,
|
||||
dropout_p: float,
|
||||
cum_seq_q: torch.Tensor | None,
|
||||
window_size_left: int | None,
|
||||
window_size_right: int | None,
|
||||
) -> str | None:
|
||||
if dropout_p != 0.0:
|
||||
return "dropout_p must be 0"
|
||||
if window_size_left is not None or window_size_right is not None:
|
||||
return "windowed attention not supported"
|
||||
error = _fa4_common_support_error(
|
||||
query,
|
||||
(grad_out, query, key, value, out, logsumexp),
|
||||
cum_seq_q,
|
||||
require_fp32=(("logsumexp", logsumexp),),
|
||||
)
|
||||
if error is not None:
|
||||
return error
|
||||
return None
|
||||
|
||||
|
||||
Ts = TypeVarTuple("Ts")
|
||||
|
||||
|
||||
def _transpose_dense(*tensors: Unpack[Ts]) -> tuple[Unpack[Ts]]:
|
||||
return tuple(t.transpose(1, 2) for t in tensors) # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def _fa4_run_forward(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
cu_seq_q: torch.Tensor | None,
|
||||
cu_seq_k: torch.Tensor | None,
|
||||
scale: float | None,
|
||||
is_causal: bool,
|
||||
window_size_left: int | None,
|
||||
window_size_right: int | None,
|
||||
seqused_k: torch.Tensor | None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if _FA4_MODULE_PATH is None:
|
||||
raise RuntimeError("FA4 not registered")
|
||||
module = _fa4_import_module(_FA4_MODULE_PATH)
|
||||
kwargs: dict[str, Any] = {
|
||||
"softmax_scale": scale,
|
||||
"causal": is_causal,
|
||||
"window_size_left": window_size_left,
|
||||
"window_size_right": window_size_right,
|
||||
"return_lse": True,
|
||||
"cu_seqlens_q": cu_seq_q,
|
||||
"cu_seqlens_k": cu_seq_k,
|
||||
"seqused_k": seqused_k.contiguous() if seqused_k is not None else None,
|
||||
}
|
||||
out, lse = module._flash_attn_fwd(query, key, value, **kwargs)
|
||||
return out, lse.contiguous()
|
||||
|
||||
|
||||
def _fa4_run_backward(
|
||||
grad_out: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
out: torch.Tensor,
|
||||
logsumexp: torch.Tensor,
|
||||
cu_seq_q: torch.Tensor | None,
|
||||
cu_seq_k: torch.Tensor | None,
|
||||
scale: float | None,
|
||||
is_causal: bool,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
if _FA4_MODULE_PATH is None:
|
||||
raise RuntimeError("FA4 not registered")
|
||||
module = _fa4_import_module(_FA4_MODULE_PATH)
|
||||
dq, dk, dv = module._flash_attn_bwd(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
out,
|
||||
grad_out,
|
||||
logsumexp.contiguous(),
|
||||
softmax_scale=scale,
|
||||
causal=is_causal,
|
||||
cu_seqlens_q=cu_seq_q,
|
||||
cu_seqlens_k=cu_seq_k,
|
||||
)
|
||||
return dq, dk, dv
|
||||
|
||||
|
||||
def _fa4_flash_attention_forward_impl(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
cum_seq_q: torch.Tensor | None,
|
||||
cum_seq_k: torch.Tensor | None,
|
||||
max_q: int,
|
||||
max_k: int,
|
||||
dropout_p: float,
|
||||
is_causal: bool,
|
||||
return_debug_mask: bool,
|
||||
*,
|
||||
scale: float | None = None,
|
||||
window_size_left: int | None = None,
|
||||
window_size_right: int | None = None,
|
||||
seqused_k: torch.Tensor | None = None,
|
||||
alibi_slopes: torch.Tensor | None = None,
|
||||
):
|
||||
error = _fa4_forward_support_error(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
dropout_p,
|
||||
return_debug_mask,
|
||||
alibi_slopes,
|
||||
seqused_k,
|
||||
cum_seq_q,
|
||||
)
|
||||
if error is not None:
|
||||
raise RuntimeError(f"FA4 flash_attention forward unsupported: {error}")
|
||||
out, lse = _fa4_run_forward(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
cum_seq_q,
|
||||
cum_seq_k,
|
||||
scale,
|
||||
is_causal,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
seqused_k,
|
||||
)
|
||||
rng_state = torch.zeros((2,), dtype=torch.uint64, device=query.device)
|
||||
philox_offset = torch.zeros((), dtype=torch.uint64, device=query.device)
|
||||
debug_mask = torch.empty(0, dtype=query.dtype, device=query.device)
|
||||
return out, lse, rng_state, philox_offset, debug_mask
|
||||
|
||||
|
||||
def _fa4_flash_attention_backward_impl(
|
||||
grad_out: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
out: torch.Tensor,
|
||||
logsumexp: torch.Tensor,
|
||||
cum_seq_q: torch.Tensor | None,
|
||||
cum_seq_k: torch.Tensor | None,
|
||||
max_q: int,
|
||||
max_k: int,
|
||||
dropout_p: float,
|
||||
is_causal: bool,
|
||||
rng_state: torch.Tensor,
|
||||
unused: torch.Tensor,
|
||||
*,
|
||||
scale: float | None = None,
|
||||
window_size_left: int | None = None,
|
||||
window_size_right: int | None = None,
|
||||
):
|
||||
error = _fa4_backward_support_error(
|
||||
grad_out,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
out,
|
||||
logsumexp,
|
||||
dropout_p,
|
||||
cum_seq_q,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
)
|
||||
if error is not None:
|
||||
raise RuntimeError(f"FA4 flash_attention backward unsupported: {error}")
|
||||
dq, dk, dv = _fa4_run_backward(
|
||||
grad_out,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
out,
|
||||
logsumexp,
|
||||
cum_seq_q,
|
||||
cum_seq_k,
|
||||
scale,
|
||||
is_causal,
|
||||
)
|
||||
return dq, dk, dv
|
||||
|
||||
|
||||
def _fa4_scaled_dot_product_flash_attention_forward_impl(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
dropout_p: float = 0.0,
|
||||
is_causal: bool = False,
|
||||
return_debug_mask: bool = False,
|
||||
*,
|
||||
scale: float | None = None,
|
||||
):
|
||||
error = _fa4_forward_support_error(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
dropout_p,
|
||||
return_debug_mask,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
if error is not None:
|
||||
raise RuntimeError(f"FA4 SDPA forward unsupported: {error}")
|
||||
q, k, v = _transpose_dense(query, key, value)
|
||||
|
||||
max_q_flash = q.size(1)
|
||||
max_k_flash = k.size(1)
|
||||
out, lse, rng_state, philox_offset, debug_mask = _fa4_flash_attention_forward_impl(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
None,
|
||||
None,
|
||||
max_q_flash,
|
||||
max_k_flash,
|
||||
dropout_p,
|
||||
is_causal,
|
||||
return_debug_mask,
|
||||
scale=scale,
|
||||
)
|
||||
(out,) = _transpose_dense(out)
|
||||
max_q = query.size(2)
|
||||
max_k = key.size(2)
|
||||
return (
|
||||
out,
|
||||
lse,
|
||||
None,
|
||||
None,
|
||||
max_q,
|
||||
max_k,
|
||||
rng_state,
|
||||
philox_offset,
|
||||
debug_mask,
|
||||
)
|
||||
|
||||
|
||||
def _fa4_scaled_dot_product_flash_attention_backward_impl(
|
||||
grad_out: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
out: torch.Tensor,
|
||||
logsumexp: torch.Tensor,
|
||||
cum_seq_q: torch.Tensor | None,
|
||||
cum_seq_k: torch.Tensor | None,
|
||||
max_q: int,
|
||||
max_k: int,
|
||||
dropout_p: float,
|
||||
is_causal: bool,
|
||||
philox_seed: torch.Tensor,
|
||||
philox_offset: torch.Tensor,
|
||||
*,
|
||||
scale: float | None = None,
|
||||
):
|
||||
error = _fa4_backward_support_error(
|
||||
grad_out,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
out,
|
||||
logsumexp,
|
||||
dropout_p,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
if error is not None:
|
||||
raise RuntimeError(f"FA4 SDPA backward unsupported: {error}")
|
||||
q, k, v, o, go = _transpose_dense(query, key, value, out, grad_out)
|
||||
max_q = query.size(2)
|
||||
max_k = key.size(2)
|
||||
dq, dk, dv = _fa4_flash_attention_backward_impl(
|
||||
go,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
o,
|
||||
logsumexp,
|
||||
None,
|
||||
None,
|
||||
max_q,
|
||||
max_k,
|
||||
dropout_p,
|
||||
is_causal,
|
||||
philox_seed,
|
||||
philox_offset,
|
||||
scale=scale,
|
||||
)
|
||||
dq, dk, dv = _transpose_dense(dq, dk, dv)
|
||||
return dq, dk, dv
|
||||
|
||||
|
||||
_registry.register_flash_attention_impl("FA4", register_fn=register_flash_attention_fa4)
|
||||
108
torch/nn/attention/_registry.py
Normal file
108
torch/nn/attention/_registry.py
Normal file
@ -0,0 +1,108 @@
|
||||
# mypy: allow-untyped-defs
|
||||
"""Registry for flash attention implementations.
|
||||
|
||||
This module contains the registration system for flash attention implementations.
|
||||
It has no torch dependencies to avoid circular imports during initialization.
|
||||
"""
|
||||
|
||||
from typing import Callable, Literal, Protocol
|
||||
|
||||
|
||||
class FlashAttentionHandle(Protocol):
|
||||
def remove(self) -> None: ...
|
||||
|
||||
|
||||
_RegisterFn = Callable[..., FlashAttentionHandle | None]
|
||||
_FlashAttentionImpl = Literal["FA4"]
|
||||
|
||||
_FLASH_ATTENTION_IMPLS: dict[str, _RegisterFn] = {}
|
||||
|
||||
_FLASH_ATTENTION_ACTIVE: str | None = None
|
||||
_FLASH_ATTENTION_HANDLES: dict[str, FlashAttentionHandle] = {}
|
||||
|
||||
|
||||
def register_flash_attention_impl(
|
||||
impl: str | _FlashAttentionImpl,
|
||||
*,
|
||||
register_fn: _RegisterFn,
|
||||
) -> None:
|
||||
"""
|
||||
Register the callable that activates a flash attention impl.
|
||||
|
||||
.. note::
|
||||
This function is intended for SDPA backend providers to register their
|
||||
implementations. End users should use :func:`activate_flash_attention_impl`
|
||||
to activate a registered implementation.
|
||||
|
||||
Args:
|
||||
impl: Implementation identifier (e.g., ``"FA4"``).
|
||||
register_fn: Callable that performs the actual dispatcher registration.
|
||||
This function will be invoked by :func:`activate_flash_attention_impl`
|
||||
and should register custom kernels with the PyTorch dispatcher.
|
||||
It may optionally return a handle implementing
|
||||
:class:`FlashAttentionHandle` to keep any necessary state alive.
|
||||
|
||||
Example:
|
||||
>>> def my_impl_register(module_path: str = "my_flash_impl"):
|
||||
... # Register custom kernels with torch dispatcher
|
||||
... pass # doctest: +SKIP
|
||||
>>> register_flash_attention_impl(
|
||||
... "MyImpl", register_fn=my_impl_register
|
||||
... ) # doctest: +SKIP
|
||||
"""
|
||||
_FLASH_ATTENTION_IMPLS[impl] = register_fn
|
||||
|
||||
|
||||
def activate_flash_attention_impl(
|
||||
impl: str | _FlashAttentionImpl,
|
||||
) -> None:
|
||||
"""
|
||||
Activate into the dispatcher a previously registered flash attention impl.
|
||||
|
||||
.. note::
|
||||
Backend providers should NOT automatically activate their implementation
|
||||
on import. Users should explicitly opt-in by calling this function or via
|
||||
environment variables to ensure multiple provider libraries can coexist.
|
||||
|
||||
Args:
|
||||
impl: Implementation identifier to activate. See
|
||||
:func:`~torch.nn.attention.list_flash_attention_impls` for available
|
||||
implementations.
|
||||
If the backend's :func:`register_flash_attention_impl` callable
|
||||
returns a :class:`FlashAttentionHandle`, the registry keeps that
|
||||
handle alive for the lifetime of the process (until explicit
|
||||
uninstall support exists).
|
||||
|
||||
Example:
|
||||
>>> activate_flash_attention_impl("FA4") # doctest: +SKIP
|
||||
"""
|
||||
global _FLASH_ATTENTION_ACTIVE
|
||||
register_fn = _FLASH_ATTENTION_IMPLS.get(impl)
|
||||
if register_fn is None:
|
||||
raise ValueError(
|
||||
f"Unknown flash attention impl '{impl}'. "
|
||||
f"Available implementations: {list_flash_attention_impls()}"
|
||||
)
|
||||
# TODO: The only way to actually register a new impl is to unregister the current impl
|
||||
# reinstall the default impl and then register the new impl
|
||||
if _FLASH_ATTENTION_ACTIVE == impl:
|
||||
return
|
||||
|
||||
handle = register_fn()
|
||||
if handle is not None:
|
||||
_FLASH_ATTENTION_HANDLES[impl] = handle
|
||||
_FLASH_ATTENTION_ACTIVE = impl
|
||||
|
||||
|
||||
def list_flash_attention_impls() -> list[str]:
|
||||
"""Return the names of all available flash attention implementations."""
|
||||
return sorted(_FLASH_ATTENTION_IMPLS.keys())
|
||||
|
||||
|
||||
def current_flash_attention_impl() -> str | None:
|
||||
"""
|
||||
Return the currently activated flash attention impl name, if any.
|
||||
|
||||
``None`` indicates that no custom impl has been activated.
|
||||
"""
|
||||
return _FLASH_ATTENTION_ACTIVE
|
||||
Reference in New Issue
Block a user