Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-12 14:54:55 +08:00)

Compare commits: 1 commit (1d97899283), comparing ciflow/tru... with zhxchen17/...
@@ -96,6 +96,7 @@ function pip_build_and_install() {
    python3 -m pip wheel \
      --no-build-isolation \
      --no-deps \
      --no-use-pep517 \
      -w "${wheel_dir}" \
      "${build_target}"
  fi
.github/actionlint.yaml (vendored): 2 lines changed
@@ -63,7 +63,7 @@ self-hosted-runner:
  - linux.rocm.gpu.gfx942.1
  - linux.rocm.gpu.gfx942.2
  - linux.rocm.gpu.gfx942.4
  - linux.rocm.gfx942.docker-cache
  - rocm-docker
  # Org wise AWS `mac2.metal` runners (2020 Mac mini hardware powered by Apple silicon M1 processors)
  - macos-m1-stable
  - macos-m1-14
.github/workflows/docker-cache-mi300.yml (vendored, new file: 55 lines)
@@ -0,0 +1,55 @@
name: docker-cache-mi300

on:
  # run every 6 hours
  schedule:
    - cron: 0 0,6,12,18 * * *
  workflow_dispatch:

concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name }}
  cancel-in-progress: true

permissions:
  id-token: write
  contents: read

jobs:
  docker-cache:
    if: github.repository_owner == 'pytorch'
    runs-on: rocm-docker
    steps:
      - name: Checkout PyTorch
        uses: pytorch/pytorch/.github/actions/checkout-pytorch@main
        with:
          no-sudo: true

      - name: configure aws credentials
        id: aws_creds
        uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0
        with:
          role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
          aws-region: us-east-1
          role-duration-seconds: 18000

      - name: Login to Amazon ECR
        id: login-ecr
        continue-on-error: false
        uses: aws-actions/amazon-ecr-login@062b18b96a7aff071d4dc91bc00c4c1a7945b076 # v2.0.1

      - name: Calculate docker image
        id: calculate-docker-image
        uses: pytorch/test-infra/.github/actions/calculate-docker-image@main
        with:
          docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
          push: false

      - name: Pull docker image
        uses: pytorch/test-infra/.github/actions/pull-docker-image@main
        with:
          docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }}

      - name: Tar and upload to S3 bucket
        run: |
          sudo docker save -o ~/docker-data/pytorch/pytorch_docker_image.tar ${{ steps.calculate-docker-image.outputs.docker-image }}
          sudo rclone copy -P --s3-upload-concurrency 64 --s3-chunk-size 200M --s3-upload-cutoff 300M ~/docker-data/pytorch/pytorch_docker_image.tar oci:pytorchbucket0002/pytorch_docker_image --progress
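The upload step above only produces the cache; the consumer side is not part of this diff. As a rough sketch of how a runner could restore that cache later (the rclone remote, bucket, and file names are copied from the upload command; everything else is illustrative and assumed, not taken from this change):

#!/usr/bin/env bash
# Hypothetical restore of the image cached by docker-cache-mi300.
set -euo pipefail

cache_dir="${HOME}/docker-data/pytorch"   # same directory the workflow writes to
mkdir -p "${cache_dir}"

# Download the tarball that the workflow uploaded (same rclone remote and bucket).
rclone copy -P oci:pytorchbucket0002/pytorch_docker_image "${cache_dir}"

# Load the image into the local Docker daemon so later jobs skip the registry pull.
docker load -i "${cache_dir}/pytorch_docker_image.tar"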
.github/workflows/docker-cache-rocm.yml (vendored, deleted file: 108 lines)
@@ -1,108 +0,0 @@
name: docker-cache-rocm

on:
  workflow_run:
    workflows: [docker-builds]
    # TODO: Uncomment before merging
    #branches: [main, release]
    types:
      - completed
  workflow_dispatch:

concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name }}
  cancel-in-progress: true

permissions:
  id-token: write
  contents: read
  actions: read

jobs:
  download-docker-builds-artifacts:
    if: github.repository_owner == 'pytorch'
    name: download-docker-builds-artifacts
    runs-on: ubuntu-latest
    outputs:
      pytorch-linux-jammy-rocm-n-py3: ${{ steps.process-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3 }}
      pytorch-linux-noble-rocm-n-py3: ${{ steps.process-artifacts.outputs.pytorch-linux-noble-rocm-n-py3 }}
      pytorch-linux-jammy-rocm-n-py3-benchmarks: ${{ steps.process-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3-benchmarks }}
    steps:
      - name: Download artifacts
        uses: actions/download-artifact@v4.1.7
        with:
          run-id: ${{ github.event.workflow_run.id }}
          path: ./docker-builds-artifacts
          merge-multiple: true
          github-token: ${{ secrets.GITHUB_TOKEN }}

      - name: Process artifacts
        id: process-artifacts
        run: |
          ls -R ./docker-builds-artifacts
          cat ./docker-builds-artifacts/*txt >> "${GITHUB_OUTPUT}"
          cat "${GITHUB_OUTPUT}"

  docker-cache:
    if: github.repository_owner == 'pytorch'
    needs: download-docker-builds-artifacts
    strategy:
      fail-fast: false
      matrix:
        runner: [linux.rocm.gfx942.docker-cache]
        docker-image: [
          "${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3 }}",
          "${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-noble-rocm-n-py3 }}",
          "${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3-benchmarks }}"
        ]
    runs-on: "${{ matrix.runner }}"
    steps:
      - name: debug
        run: |
          JSON_STRINGIFIED="${{ toJSON(needs.download-docker-builds-artifacts.outputs) }}"
          echo "Outputs of download-docker-builds-artifacts job: ${JSON_STRINGIFIED}"

      - name: configure aws credentials
        id: aws_creds
        uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0
        with:
          role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
          aws-region: us-east-1
          role-duration-seconds: 18000

      - name: Login to Amazon ECR
        id: login-ecr
        continue-on-error: false
        uses: aws-actions/amazon-ecr-login@062b18b96a7aff071d4dc91bc00c4c1a7945b076 # v2.0.1

      - name: Generate ghrc.io tag
        id: ghcr-io-tag
        run: |
          ecr_image="${{ matrix.docker-image }}"
          ghcr_image="ghcr.io/pytorch/ci-image:${ecr_image##*:}"
          echo "ghcr_image=${ghcr_image}" >> "$GITHUB_OUTPUT"

      - name: Pull docker image
        uses: pytorch/test-infra/.github/actions/pull-docker-image@main
        with:
          docker-image: ${{ steps.ghcr-io-tag.outputs.ghcr_image }}

      - name: Save as tarball
        run: |
          docker_image_tag=${{ matrix.docker-image }}
          docker_image_tag="${docker_image_tag#*:}" # Remove everything before and including first ":"
          docker_image_tag="${docker_image_tag%-*}" # Remove everything after and including last "-"
          ref_name=${{ github.event.workflow_run.head_branch }}
          if [[ $ref_name =~ "release/" ]]; then
            ref_suffix="release"
          elif [[ $ref_name == "main" ]]; then
            ref_suffix="main"
          else
            # TODO: Remove below
            ref_suffix="main"
            # echo "Unexpected branch in ref_name: ${ref_name}" && exit 1
          fi
          docker tag ${{ steps.ghcr-io-tag.outputs.ghcr_image }} ${{ matrix.docker-image }}
          # mv is atomic operation, so we use intermediate tar.tmp file to prevent read-write contention
          docker save -o ~/pytorch-data/docker/${docker_image_tag}.tar.tmp ${{ matrix.docker-image }}
          mv ~/pytorch-data/docker/${docker_image_tag}.tar.tmp ~/pytorch-data/docker/${docker_image_tag}_${ref_suffix}.tar
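The "Save as tarball" step in the deleted workflow relies on two small shell idioms: parameter expansion to cut the image reference down to a tag, and writing to a .tar.tmp file followed by an atomic mv so a concurrent reader never sees a half-written tarball. A minimal standalone sketch of both, using an illustrative image name rather than a real ECR reference:

#!/usr/bin/env bash
set -euo pipefail

# Illustrative image reference; the workflow gets this from its matrix instead.
image="example.registry/pytorch/ci-image:pytorch-linux-jammy-rocm-n-py3-abcdef0"

tag="${image#*:}"   # strip up to and including the first ":"  -> pytorch-linux-jammy-rocm-n-py3-abcdef0
tag="${tag%-*}"     # strip from the last "-" onwards          -> pytorch-linux-jammy-rocm-n-py3

out_dir="${HOME}/pytorch-data/docker"
mkdir -p "${out_dir}"

# Write to a temporary name first; mv within one filesystem is atomic, so the
# final file name only appears once the tarball is complete.
docker save -o "${out_dir}/${tag}.tar.tmp" "${image}"          # image must already exist locally
mv "${out_dir}/${tag}.tar.tmp" "${out_dir}/${tag}_main.tar"    # "_main" mirrors ref_suffix="main"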
@@ -142,7 +142,6 @@ Tensor _pack_padded_sequence_backward_symint(const Tensor& grad, c10::SymIntArra
std::tuple<Tensor, Tensor> _pad_packed_sequence(const Tensor& data, const Tensor& _batch_sizes, bool batch_first, const Scalar& padding_value, int64_t total_length) {
  auto batch_sizes_t = _batch_sizes.contiguous();
  checkLongTensor(batch_sizes_t);
  TORCH_CHECK(batch_sizes_t.numel() > 0, "batch_sizes can not be empty");

  int64_t * batch_sizes = batch_sizes_t.data_ptr<int64_t>();
  int64_t max_batch_size = batch_sizes[0];
@@ -669,12 +669,9 @@ std::optional<c10::ScalarType> out_dtype) {
  // _scaled_mm_allowed_device is used here within _grouped_mm_cuda which seems incorrect since scale is not used.
  // the _grouped_mm_fallback should be safe for any ROCm GPU since it's just calling typical mm/bmm
  bool use_fast_path = false;
  // On non CK system(w/ ROCm), make sure use_fast_path is false
#if defined(USE_ROCM_CK_GEMM)
  if (at::detail::getCUDAHooks().isGPUArch({"gfx942", "gfx950"})) {
    use_fast_path = true;
  }
#endif //USE_ROCM_CK_GEMM
#endif
  const auto out_dtype_ = _resolve_grouped_mm_out_dtype(mat_a, mat_b, out_dtype);
  Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
@@ -683,11 +680,7 @@ std::optional<c10::ScalarType> out_dtype) {
#ifndef USE_ROCM
    at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out);
#else
#if defined(USE_ROCM_CK_GEMM)
    at::hip::detail::group_gemm_ck(mat_a, mat_b, offs, bias, out);
#else
    TORCH_WARN("ROCm: Group Gemm through CK not selected.");
#endif //USE_ROCM_CK_GEMM
#endif
  } else {
    _grouped_mm_fallback(mat_a, mat_b, offs, bias, out_dtype, out);
@@ -1426,9 +1426,6 @@ static at::Tensor _fp8_convolution_onednn_ref(
  w_scales_new_shape[0] = -1;
  auto dqw = weight.to(at::kFloat) * weight_scales.reshape(w_scales_new_shape);
  auto output_padding = std::vector<int64_t>(kSpatialDim, 0);
  if (bias.has_value()){
    bias = bias.value().to(at::kFloat);
  }
  auto y_f32 = at::convolution(
      dqx, dqw, bias, stride.vec(), padding.vec(), dilation.vec(), /* transposed */false, output_padding, groups
  );
@@ -47,7 +47,6 @@
#include <c10/macros/Macros.h>
#include <thrust/copy.h>
#include <thrust/device_ptr.h>
#include <thrust/distance.h>
#include <thrust/for_each.h>
#include <thrust/functional.h>
#include <thrust/gather.h>
@@ -50,7 +50,7 @@ nfnet_l0,pass,7

repvgg_a2,pass,7
repvgg_a2,fail_accuracy,7

@@ -14,10 +14,6 @@ Utils

    sdpa_kernel
    SDPBackend
    register_flash_attention_impl
    activate_flash_attention_impl
    list_flash_attention_impls
    current_flash_attention_impl

Submodules
----------
@@ -10,7 +10,7 @@ tp2_dir="$top_dir/third_party"
pip install ninja

# Install onnx
pip install -e "$tp2_dir/onnx"
pip install --no-use-pep517 -e "$tp2_dir/onnx"

# Install caffe2 and pytorch
pip install -r "$top_dir/caffe2/requirements.txt"
@ -140,11 +140,6 @@ static void initDeviceStreamState(DeviceIndex device_index) {
|
||||
static void initOpenRegStreamsOnce() {
|
||||
c10::call_once(init_flag, initGlobalStreamState);
|
||||
|
||||
for (const auto i : c10::irange(num_devices)) {
|
||||
c10::call_once(
|
||||
device_flags[i], initDeviceStreamState, static_cast<DeviceIndex>(i));
|
||||
}
|
||||
|
||||
if (current_streams) {
|
||||
return;
|
||||
}
|
||||
@ -207,6 +202,8 @@ OpenRegStream getStreamFromPool(const int priority, DeviceIndex device_index) {
|
||||
if (device_index == -1) {
|
||||
device_index = current_device();
|
||||
}
|
||||
c10::call_once(
|
||||
device_flags[device_index], initDeviceStreamState, device_index);
|
||||
auto pri_idx =
|
||||
std::clamp(priority, 0, max_compile_time_stream_priorities - 1);
|
||||
const auto idx = get_idx(priority_counters[device_index][pri_idx]);
|
||||
|
||||
@ -180,47 +180,6 @@ class TestTrackerFullyShard1DTrainingCore(FSDPTest):
|
||||
del model
|
||||
del optim
|
||||
|
||||
def _test_tracker_multihandler_hook(self):
|
||||
"""Should run without KeyError."""
|
||||
|
||||
class TestModule(nn.Module):
|
||||
def __init__(self, dim: int):
|
||||
super().__init__()
|
||||
self.norm1 = nn.RMSNorm(dim)
|
||||
self.output1 = nn.Linear(dim, dim)
|
||||
self.norm2 = nn.RMSNorm(dim)
|
||||
self.output2 = nn.Linear(dim, dim)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.norm1(x)
|
||||
x = self.output1(x)
|
||||
x = self.norm2(x)
|
||||
x = self.output2(x)
|
||||
return x
|
||||
|
||||
gc.collect()
|
||||
torch.manual_seed(42)
|
||||
dev = torch.device(torch.accelerator.current_device_index())
|
||||
|
||||
with torch.device(dev):
|
||||
model = TestModule(128)
|
||||
|
||||
mesh = init_device_mesh(dev.type, (self.world_size,))
|
||||
fully_shard([model.norm1, model.output1], mesh=mesh)
|
||||
fully_shard([model.norm2, model.output2], mesh=mesh)
|
||||
fully_shard(model, mesh=mesh)
|
||||
|
||||
fmt = FSDPMemTracker(model)
|
||||
|
||||
with fmt:
|
||||
inp = torch.randn(16, 128, device=dev)
|
||||
y = model(inp)
|
||||
loss = y.sum()
|
||||
loss.backward()
|
||||
|
||||
del inp
|
||||
del model
|
||||
|
||||
|
||||
class TestTrackerFullyShard1DTrainingCompose(FSDPTest):
|
||||
@property
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
# Owner(s): ["oncall: distributed"]
|
||||
|
||||
import contextlib
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@ -372,7 +371,6 @@ class DTensorExportTest(TestCase):
|
||||
|
||||
# aot_export_joint_with_descriptors on strict-exported exported_program.module()
|
||||
# is producing a joint graph with backward region missing
|
||||
@unittest.expectedFailure
|
||||
def test_strict_export_parallelize_module_with_dtensor_input(self):
|
||||
self._run_test(strict_export_and_aot_export_joint_with_descriptors)
|
||||
|
||||
|
||||
@ -15,7 +15,7 @@ import torch._functorch.config
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
from functorch.compile import min_cut_rematerialization_partition
|
||||
from functorch.compile import default_partition, min_cut_rematerialization_partition
|
||||
from torch._dynamo.backends.common import aot_autograd
|
||||
from torch._dynamo.testing import (
|
||||
AotEagerAndRecordGraphs,
|
||||
@ -24,7 +24,7 @@ from torch._dynamo.testing import (
|
||||
)
|
||||
from torch._higher_order_ops.wrap import tag_activation_checkpoint
|
||||
from torch.testing._internal.common_device_type import instantiate_device_type_tests
|
||||
from torch.testing._internal.common_utils import IS_WINDOWS, skipIfHpu
|
||||
from torch.testing._internal.common_utils import IS_WINDOWS, parametrize, skipIfHpu
|
||||
from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON
|
||||
from torch.testing._internal.triton_utils import requires_cuda_and_triton
|
||||
from torch.testing._internal.two_tensor import TwoTensor
|
||||
@ -281,7 +281,14 @@ class ActivationCheckpointingViaTagsTests(
|
||||
|
||||
run(export_compiler)
|
||||
|
||||
def test_tags_function(self, device):
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_tags_function(self, device, partition_fn):
|
||||
def gn(x, y):
|
||||
return torch.sigmoid(torch.matmul(x, y))
|
||||
|
||||
@ -297,11 +304,22 @@ class ActivationCheckpointingViaTagsTests(
|
||||
bw_compiler = functools.partial(
|
||||
count_ops, freq=3, op=torch.ops.aten.mm.default
|
||||
) # mm recomputed in the bwd
|
||||
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=partition_fn,
|
||||
)
|
||||
self._validate(fn, backend, x, y)
|
||||
|
||||
@requires_cuda_and_triton
|
||||
def test_tags_function_via_global_checkpoint(self, device):
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_tags_function_via_global_checkpoint(self, device, partition_fn):
|
||||
def gn(x, y):
|
||||
return torch.sigmoid(torch.matmul(x, y))
|
||||
|
||||
@ -316,17 +334,28 @@ class ActivationCheckpointingViaTagsTests(
|
||||
bw_compiler = functools.partial(
|
||||
count_ops, freq=3, op=torch.ops.aten.mm.default
|
||||
) # mm recomputed in the bwd
|
||||
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=partition_fn,
|
||||
)
|
||||
self._validate(fn, backend, x, y)
|
||||
|
||||
@requires_cuda_and_triton
|
||||
def test_tags_function_with_kwargs(self, device):
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_tags_function_with_kwargs(self, device, partition_fn):
|
||||
def gn(x, y):
|
||||
return torch.sigmoid(torch.matmul(x, y))
|
||||
|
||||
def fn(x, y):
|
||||
return torch.utils.checkpoint.checkpoint(
|
||||
gn, torch.sin(x), y, use_reentrant=True, preserve_rng_state=False
|
||||
gn, torch.sin(x), y, use_reentrant=False
|
||||
)
|
||||
|
||||
x = torch.randn(4, 4, device=device, requires_grad=True)
|
||||
@ -336,11 +365,22 @@ class ActivationCheckpointingViaTagsTests(
|
||||
bw_compiler = functools.partial(
|
||||
count_ops, freq=3, op=torch.ops.aten.mm.default
|
||||
) # mm recomputed in the bwd
|
||||
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=partition_fn,
|
||||
)
|
||||
self._validate(fn, backend, x, y)
|
||||
|
||||
@requires_cuda_and_triton
|
||||
def test_tags_sequential_layers(self, device):
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_tags_sequential_layers(self, device, partition_fn):
|
||||
def gn(x):
|
||||
x = x.cos()
|
||||
for _ in range(3):
|
||||
@ -361,11 +401,22 @@ class ActivationCheckpointingViaTagsTests(
|
||||
freqs=[2, 18],
|
||||
ops=[torch.ops.aten.cos.default, torch.ops.aten.mm.default],
|
||||
) # mm recomputed in the bwd
|
||||
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=partition_fn,
|
||||
)
|
||||
self._validate(fn, backend, x)
|
||||
|
||||
@requires_cuda_and_triton
|
||||
def test_tags_multiple_checkpoints(self, device):
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_tags_multiple_checkpoints(self, device, partition_fn):
|
||||
def gn(x, y):
|
||||
return torch.sigmoid(torch.matmul(x, y))
|
||||
|
||||
@ -383,11 +434,22 @@ class ActivationCheckpointingViaTagsTests(
|
||||
bw_compiler = functools.partial(
|
||||
count_ops, freq=6, op=torch.ops.aten.mm.default
|
||||
) # mm recomputed in the bwd
|
||||
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=partition_fn,
|
||||
)
|
||||
self._validate(fn, backend, x, y)
|
||||
|
||||
@requires_cuda_and_triton
|
||||
def test_tags_module(self, device):
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_tags_module(self, device, partition_fn):
|
||||
class MockModule(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
@ -411,11 +473,22 @@ class ActivationCheckpointingViaTagsTests(
|
||||
bw_compiler = functools.partial(
|
||||
count_ops, freq=1, op=torch.ops.aten.sigmoid.default
|
||||
)
|
||||
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=partition_fn,
|
||||
)
|
||||
self._validate(fn, backend, x)
|
||||
|
||||
@requires_cuda_and_triton
|
||||
def test_tags_decomps(self, device):
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_tags_decomps(self, device, partition_fn):
|
||||
# Ensures that tags are passed on through decompositions as well
|
||||
class MockModule(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
@ -443,6 +516,7 @@ class ActivationCheckpointingViaTagsTests(
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=partition_fn,
|
||||
decompositions=lambda: import_module(
|
||||
"torch._inductor.compile_fx"
|
||||
).select_decomp_table(),
|
||||
@ -702,7 +776,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
|
||||
@requires_cuda_and_triton
|
||||
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
|
||||
def test_compile_selective_checkpoint_must_recompute(self, device):
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_compile_selective_checkpoint_must_recompute(self, device, partition_fn):
|
||||
def context_fn_must_recompute_mm():
|
||||
must_recompute_list = [
|
||||
torch.ops.aten.mm.default,
|
||||
@ -723,9 +804,9 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
),
|
||||
)
|
||||
|
||||
def _test(context_fn, bw_compiler):
|
||||
def _test(context_fn, bw_compiler, partition_fn):
|
||||
def gn(x):
|
||||
return torch.sigmoid(torch.matmul(x, x))
|
||||
return torch.cos(torch.sin(torch.matmul(x, x) @ x))
|
||||
|
||||
def fn(x):
|
||||
return torch.utils.checkpoint.checkpoint(
|
||||
@ -739,14 +820,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
|
||||
fw_compiler = functools.partial(
|
||||
count_ops,
|
||||
freq=1,
|
||||
freq=2,
|
||||
op=torch.ops.aten.mm.default,
|
||||
)
|
||||
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=min_cut_rematerialization_partition,
|
||||
partition_fn=partition_fn,
|
||||
)
|
||||
self._validate(fn, backend, x)
|
||||
|
||||
@ -754,17 +835,19 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
context_fn=context_fn_must_recompute_mm,
|
||||
bw_compiler=functools.partial(
|
||||
count_ops,
|
||||
freq=3, # 1 matmul recompute and 2 bwd mm ops per fwd matmul, so 1 + 2 * 1 = 3)
|
||||
freq=6, # 1 matmul recompute and 2 bwd mm ops per fwd matmul, so 2 + 2 * 2 = 6)
|
||||
op=torch.ops.aten.mm.default,
|
||||
),
|
||||
partition_fn=partition_fn,
|
||||
)
|
||||
_test(
|
||||
context_fn=context_fn_no_recompute_mm,
|
||||
bw_compiler=functools.partial(
|
||||
count_ops,
|
||||
freq=2, # 2 bwd mm ops per fwd matmul
|
||||
freq=4, # 2 bwd mm ops per fwd matmul
|
||||
op=torch.ops.aten.mm.default,
|
||||
),
|
||||
partition_fn=partition_fn,
|
||||
)
|
||||
|
||||
def test_sac_with_partial_context_fn(self):
|
||||
@ -801,7 +884,16 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
|
||||
@requires_cuda_and_triton
|
||||
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
|
||||
def test_compile_selective_checkpoint_must_not_recompute_gemm(self, device):
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_compile_selective_checkpoint_must_not_recompute_gemm(
|
||||
self, device, partition_fn
|
||||
):
|
||||
def selective_checkpointing_context_fn():
|
||||
no_recompute_list = [
|
||||
torch.ops.aten.mm.default,
|
||||
@ -841,15 +933,22 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=min_cut_rematerialization_partition,
|
||||
partition_fn=partition_fn,
|
||||
)
|
||||
self._validate(fn, backend, x, y)
|
||||
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
|
||||
|
||||
@requires_cuda_and_triton
|
||||
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_compile_selective_checkpoint_must_not_recompute_gemm_no_functionalization(
|
||||
self, device
|
||||
self, device, partition_fn
|
||||
):
|
||||
def selective_checkpointing_context_fn():
|
||||
no_recompute_list = [
|
||||
@ -889,7 +988,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=min_cut_rematerialization_partition,
|
||||
partition_fn=partition_fn,
|
||||
disable_functionalization=True,
|
||||
)
|
||||
self._validate(fn, backend, x, y)
|
||||
@ -897,7 +996,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
|
||||
@requires_cuda_and_triton
|
||||
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
|
||||
def test_compile_selective_checkpoint_triton_kernel(self, device):
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_compile_selective_checkpoint_triton_kernel(self, device, partition_fn):
|
||||
# Copy of the above test, but make sure that having a triton kernel in the
|
||||
# region does not error.
|
||||
def add_one(x):
|
||||
@ -957,14 +1063,21 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=min_cut_rematerialization_partition,
|
||||
partition_fn=partition_fn,
|
||||
)
|
||||
self._validate(fn, backend, x, y)
|
||||
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
|
||||
|
||||
@requires_cuda_and_triton
|
||||
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
|
||||
def test_compile_selective_checkpoint_tensor_subclass(self, device):
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_compile_selective_checkpoint_tensor_subclass(self, device, partition_fn):
|
||||
def selective_checkpointing_context_fn():
|
||||
no_recompute_list = [
|
||||
torch.ops.aten.mm.default,
|
||||
@ -1007,14 +1120,21 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=min_cut_rematerialization_partition,
|
||||
partition_fn=partition_fn,
|
||||
)
|
||||
self._validate(fn, backend, x, y)
|
||||
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
|
||||
|
||||
@requires_cuda_and_triton
|
||||
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
|
||||
def test_compile_selective_checkpoint_custom_rule(self, device):
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_compile_selective_checkpoint_custom_rule(self, device, partition_fn):
|
||||
def _get_custom_policy(meta):
|
||||
no_recompute_list = [
|
||||
torch.ops.aten.mm.default,
|
||||
@ -1072,14 +1192,21 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=min_cut_rematerialization_partition,
|
||||
partition_fn=partition_fn,
|
||||
)
|
||||
self._validate(fn, backend, x, y)
|
||||
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
|
||||
|
||||
@requires_cuda_and_triton
|
||||
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
|
||||
def test_compile_selective_checkpoint_partial_ctx_fn(self, device):
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_compile_selective_checkpoint_partial_ctx_fn(self, device, partition_fn):
|
||||
def selective_checkpointing_context_fn(no_recompute_list):
|
||||
return create_selective_checkpoint_contexts(
|
||||
_get_custom_policy(no_recompute_list=no_recompute_list)
|
||||
@ -1118,14 +1245,21 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=min_cut_rematerialization_partition,
|
||||
partition_fn=partition_fn,
|
||||
)
|
||||
self._validate(fn, backend, x, y)
|
||||
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
|
||||
|
||||
@requires_cuda_and_triton
|
||||
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
|
||||
def test_compile_selective_checkpoint_outplace_op(self, device):
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_compile_selective_checkpoint_outplace_op(self, device, partition_fn):
|
||||
def selective_checkpointing_context_fn():
|
||||
no_recompute_list = [
|
||||
torch.ops.aten.mm.default,
|
||||
@ -1163,14 +1297,21 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=min_cut_rematerialization_partition,
|
||||
partition_fn=partition_fn,
|
||||
)
|
||||
self._validate(fn, backend, x, y)
|
||||
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
|
||||
|
||||
@requires_cuda_and_triton
|
||||
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
|
||||
def test_compile_selective_checkpoint_list_ops(self, device):
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_compile_selective_checkpoint_list_ops(self, device, partition_fn):
|
||||
def selective_checkpointing_context_fn():
|
||||
# recompute everything
|
||||
no_recompute_list = []
|
||||
@ -1206,7 +1347,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=min_cut_rematerialization_partition,
|
||||
partition_fn=partition_fn,
|
||||
)
|
||||
self._validate(fn, backend, x, y)
|
||||
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
|
||||
@ -1217,7 +1358,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
"requires TorchDispatchMode + torch.compile work to complete"
|
||||
)
|
||||
@requires_cuda_and_triton
|
||||
def test_compile_selective_checkpoint_inplace_op(self, device):
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_compile_selective_checkpoint_inplace_op(self, device, partition_fn):
|
||||
def selective_checkpointing_context_fn():
|
||||
no_recompute_list = [
|
||||
torch.ops.aten.mm.default,
|
||||
@ -1257,7 +1405,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=min_cut_rematerialization_partition,
|
||||
partition_fn=partition_fn,
|
||||
)
|
||||
self._validate(fn, backend, x, y)
|
||||
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
|
||||
@ -1265,7 +1413,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
@requires_cuda_and_triton
|
||||
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
|
||||
@torch._inductor.config.patch(fallback_random=True)
|
||||
def test_compile_selective_checkpoint_random_op(self, device):
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_compile_selective_checkpoint_random_op(self, device, partition_fn):
|
||||
for preserve_rng_state in [True, False]:
|
||||
|
||||
def selective_checkpointing_context_fn():
|
||||
@ -1312,7 +1467,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=min_cut_rematerialization_partition,
|
||||
partition_fn=partition_fn,
|
||||
)
|
||||
|
||||
# NOTE: when `preserve_rng_state` is False, gradient will mismatch between torch.compile and eager,
|
||||
@ -1324,7 +1479,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
|
||||
@requires_cuda_and_triton
|
||||
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
|
||||
def test_compile_selective_checkpoint_invalid_context(self):
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_compile_selective_checkpoint_invalid_context(self, partition_fn):
|
||||
def gn(x, y):
|
||||
return torch.sigmoid(torch.matmul(x, y)) * y
|
||||
|
||||
@ -1353,7 +1515,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=min_cut_rematerialization_partition,
|
||||
partition_fn=partition_fn,
|
||||
)
|
||||
with self.assertRaisesRegex(
|
||||
Exception, "must generate a tuple of two `TorchDispatchMode`s"
|
||||
@ -1362,7 +1524,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
|
||||
@requires_cuda_and_triton
|
||||
@torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)
|
||||
def test_compile_selective_checkpoint_parametrization(self):
|
||||
@parametrize(
|
||||
"partition_fn",
|
||||
[
|
||||
min_cut_rematerialization_partition,
|
||||
default_partition,
|
||||
],
|
||||
)
|
||||
def test_compile_selective_checkpoint_parametrization(self, partition_fn):
|
||||
def sac_policy():
|
||||
def _recomp_policy():
|
||||
def _custom_policy(ctx, func, *args, **kwargs):
|
||||
@ -1425,7 +1594,9 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
bw_compiler = functools.partial(
|
||||
count_ops,
|
||||
freqs=[
|
||||
2, # 1 from mul recompute, 1 from mul backward
|
||||
# 1 from mul recompute, 1 from mul backward
|
||||
# w/o CSE, we have one extra mul
|
||||
3 if partition_fn is default_partition else 2,
|
||||
1,
|
||||
],
|
||||
ops=[torch.ops.aten.mul.Tensor, torch.ops.aten.sigmoid.default],
|
||||
@ -1434,7 +1605,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
backend = aot_autograd(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=bw_compiler,
|
||||
partition_fn=min_cut_rematerialization_partition,
|
||||
partition_fn=partition_fn,
|
||||
)
|
||||
|
||||
model = MLPModule()
|
||||
|
||||
@ -2363,34 +2363,6 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
|
||||
self.assertTrue(same(output, expected))
|
||||
assert cnt.frame_count == 1
|
||||
|
||||
@unittest.skipIf(sys.version_info < (3, 13), "math.fma introduced in python 3.13")
|
||||
def test_math_fma(self):
|
||||
def fma_func(a, b, c):
|
||||
return math.fma(a, b, c)
|
||||
|
||||
# Test with scalar constants (constant folding path)
|
||||
cnt = torch._dynamo.testing.CompileCounter()
|
||||
cfma_scalars = torch._dynamo.optimize_assert(cnt)(fma_func)
|
||||
|
||||
assert cnt.frame_count == 0
|
||||
expected = fma_func(2.0, 3.0, 4.0)
|
||||
output = cfma_scalars(2.0, 3.0, 4.0)
|
||||
self.assertEqual(output, expected)
|
||||
assert cnt.frame_count == 0
|
||||
|
||||
# Test with tensors (Inductor path)
|
||||
cnt2 = torch._dynamo.testing.CompileCounter()
|
||||
cfma_tensors = torch._dynamo.optimize_assert(cnt2)(fma_func)
|
||||
|
||||
assert cnt2.frame_count == 0
|
||||
x = torch.tensor(2.0)
|
||||
y = torch.tensor(3.0)
|
||||
z = torch.tensor(4.0)
|
||||
expected_tensors = x * y + z
|
||||
output_tensors = cfma_tensors(x, y, z)
|
||||
torch.testing.assert_close(output_tensors, expected_tensors)
|
||||
assert cnt2.frame_count == 1
|
||||
|
||||
@make_test
|
||||
def test_numpy_meshgrid(x, y):
|
||||
r1, r2 = np.meshgrid(x.numpy(), y.numpy())
|
||||
|
||||
@ -335,59 +335,6 @@ class <lambda>(torch.nn.Module):
|
||||
""",
|
||||
)
|
||||
|
||||
@requires_cuda
|
||||
@requires_multigpu()
|
||||
def test_new_event_api(self) -> None:
|
||||
from torch._dynamo.graph_bytecode_inputs import get_external_object_by_index
|
||||
from torch._dynamo.variables.streams import new_event
|
||||
|
||||
def event_generation_backend(gm, *args, **kwargs): # type: ignore[no-untyped-def]
|
||||
e0_ind = new_event()
|
||||
with torch.Stream(device="cuda:1"):
|
||||
get_external_object_by_index(e0_ind).record()
|
||||
e1_ind = new_event()
|
||||
self.assertNotEqual(e0_ind, e1_ind)
|
||||
self.assertNotEqual(
|
||||
get_external_object_by_index(e0_ind),
|
||||
get_external_object_by_index(e1_ind),
|
||||
)
|
||||
with gm.graph.inserting_after(next(iter(gm.graph.nodes))):
|
||||
gm.graph.call_function(
|
||||
get_external_object_by_index, args=(1,), kwargs={}
|
||||
)
|
||||
return gm
|
||||
|
||||
@torch.compile(backend=event_generation_backend)
|
||||
def fn(x):
|
||||
return x + 1
|
||||
|
||||
fn(torch.ones(2, 2, device="cuda:0"))
|
||||
|
||||
@requires_cuda
|
||||
def test_new_stream_api(self) -> None:
|
||||
from torch._dynamo.graph_bytecode_inputs import get_external_object_by_index
|
||||
from torch._dynamo.variables.streams import new_stream
|
||||
|
||||
def stream_generation_backend(gm, *args, **kwargs): # type: ignore[no-untyped-def]
|
||||
s0_ind = new_stream()
|
||||
s1_ind = new_stream()
|
||||
self.assertNotEqual(s0_ind, s1_ind)
|
||||
self.assertNotEqual(
|
||||
get_external_object_by_index(s0_ind),
|
||||
get_external_object_by_index(s1_ind),
|
||||
)
|
||||
with gm.graph.inserting_after(next(iter(gm.graph.nodes))):
|
||||
gm.graph.call_function(
|
||||
get_external_object_by_index, args=(1,), kwargs={}
|
||||
)
|
||||
return gm
|
||||
|
||||
@torch.compile(backend=stream_generation_backend)
|
||||
def fn(x):
|
||||
return x + 1
|
||||
|
||||
fn(torch.ones(2, 2, device="cuda:0"))
|
||||
|
||||
@requires_cuda
|
||||
def test_stream_with_mutation(self):
|
||||
def fn(x, y):
|
||||
@ -576,23 +523,6 @@ class <lambda>(torch.nn.Module):
|
||||
torch.accelerator.set_stream(original_stream)
|
||||
reset_user_object_tracking()
|
||||
|
||||
@requires_cuda
|
||||
def test_run_opcheck_wait_record_stream(self):
|
||||
from torch._dynamo.variables.streams import wait_stream
|
||||
from torch.library import opcheck
|
||||
|
||||
s0 = torch.Stream()
|
||||
s1 = torch.Stream()
|
||||
s2 = torch.Stream()
|
||||
store_user_object_weakrefs(s0, s1, s2)
|
||||
|
||||
sample_inputs = [
|
||||
(0, 1),
|
||||
(2, 0),
|
||||
]
|
||||
for args in sample_inputs:
|
||||
opcheck(wait_stream, args)
|
||||
|
||||
@requires_cuda
|
||||
def test_inductor_lowering(self):
|
||||
with patch("torch._inductor.config.implicit_fallbacks", False):
|
||||
|
||||
@ -331,12 +331,7 @@ class TestDynamismExpression(TestCase):
|
||||
return torch.ops.aten.slice.Tensor(*args)
|
||||
|
||||
inp = (torch.rand((10, 3, 224, 224)), 0, 0, 9223372036854775807)
|
||||
dynamic_shapes = (
|
||||
{0: Dim("dim")},
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
dynamic_shapes = (({0: Dim("dim")}, None, None, None),)
|
||||
torch.export.export(
|
||||
Slice(),
|
||||
inp,
|
||||
@ -590,6 +585,7 @@ class TestExport(TestCase):
|
||||
inp = ([torch.ones(1, 3)], torch.ones(1, 3))
|
||||
self._test_export_same_as_eager(f, inp)
|
||||
|
||||
@testing.expectedFailureStrictV2
|
||||
@skipIfCrossRef
|
||||
def test_custom_tag_metadata_re_export(self):
|
||||
class Foo(torch.nn.Module):
|
||||
@ -1026,6 +1022,7 @@ graph():
|
||||
dynamic_shapes = {"x": (dim0_x, dim1_x)}
|
||||
export(Foo(), inputs, dynamic_shapes=dynamic_shapes)
|
||||
|
||||
@testing.expectedFailureStrictV2
|
||||
def test_no_tensor_computation(self):
|
||||
class Module(torch.nn.Module):
|
||||
def forward(self, x, y):
|
||||
@ -1361,6 +1358,7 @@ def forward(self, primals, tangents):
|
||||
# instead of the scripted function, so we get x.sin()
|
||||
self.assertEqual(res, x.sin())
|
||||
|
||||
@testing.expectedFailureStrictV2
|
||||
def test_no_tensor_computation_2(self):
|
||||
class Module(torch.nn.Module):
|
||||
def forward(self, x, y):
|
||||
@ -1379,6 +1377,7 @@ graph():
|
||||
return (x,)""",
|
||||
)
|
||||
|
||||
@testing.expectedFailureStrictV2
|
||||
def test_no_tensor_computation_3(self):
|
||||
class Module(torch.nn.Module):
|
||||
def forward(self, x, y):
|
||||
@ -1397,6 +1396,7 @@ graph():
|
||||
return (5,)""",
|
||||
)
|
||||
|
||||
@testing.expectedFailureStrictV2
|
||||
def test_no_tensor_computation_4(self):
|
||||
class Module(torch.nn.Module):
|
||||
def forward(self, x, y):
|
||||
@ -1939,6 +1939,7 @@ graph():
|
||||
for vr_upper in vr_upper_bounds:
|
||||
self.assertEqual(vr_upper, 1)
|
||||
|
||||
@testing.expectedFailureStrictV2
|
||||
def test_detect_leak_strict(self):
|
||||
class Foo(torch.nn.Module):
|
||||
def __init__(self):
|
||||
@ -2687,6 +2688,7 @@ class GraphModule(torch.nn.Module):
|
||||
gm = export(m, (torch.rand(64, 64),))
|
||||
torch.export.unflatten(gm)
|
||||
|
||||
@testing.expectedFailureStrictV2
|
||||
def test_unflatten_closure(self):
|
||||
class Dummy(torch.nn.Module):
|
||||
def forward(self, fn, x):
|
||||
@ -4192,6 +4194,7 @@ def forward(self, p_linear_weight, p_linear_bias, x):
|
||||
if str(sym) in ["u0", "s0"]:
|
||||
self.assertEqual(vr.lower, 1)
|
||||
|
||||
@testing.expectedFailureStrictV2
|
||||
def test_duplicate_modules_with_non_persistent_buffers(self):
|
||||
class FooWithBuf(torch.nn.Module):
|
||||
def __init__(self):
|
||||
@ -4835,6 +4838,7 @@ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, b_
|
||||
table.materialize()
|
||||
self.assertFalse(torch.ops.mylib.foo123.default in table)
|
||||
|
||||
@testing.expectedFailureStrictV2
|
||||
def test_if_post_autograd_op_preserved(self):
|
||||
class Foo(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
@ -5538,11 +5542,21 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
||||
|
||||
w = Wrapped()
|
||||
|
||||
compiled = export(w, args, dynamic_shapes=({0: batch}, {0: batch}))
|
||||
expected = w(*args)
|
||||
mod = compiled.module()
|
||||
got = mod(*args)
|
||||
self.assertTrue(torch.allclose(expected, got))
|
||||
if is_retracebility_test(self._testMethodName):
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.UserError,
|
||||
"Detected mismatch between the structure of `inputs` and `dynamic_shapes`"
|
||||
": `inputs` has 2 elements, but `dynamic_shapes` has 1 elements",
|
||||
):
|
||||
export(w, args, dynamic_shapes={"args": ({0: batch}, {0: batch})})
|
||||
else:
|
||||
compiled = export(
|
||||
w, args, dynamic_shapes={"args": ({0: batch}, {0: batch})}
|
||||
)
|
||||
expected = w(*args)
|
||||
mod = compiled.module()
|
||||
got = mod(*args)
|
||||
self.assertTrue(torch.allclose(expected, got))
|
||||
|
||||
def test_dynamic_shapes_builder_basic(self):
|
||||
class M(torch.nn.Module):
|
||||
@ -7223,6 +7237,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
||||
@testing.expectedFailureSerDer # we don't save placeholder metadata
|
||||
@testing.expectedFailureCppSerDes # we don't save placeholder metadata
|
||||
@testing.expectedFailureSerDerNonStrict
|
||||
@testing.expectedFailureStrictV2
|
||||
def test_linear_conv(self):
|
||||
strict = True
|
||||
|
||||
@ -8821,6 +8836,7 @@ def forward(self, x):
|
||||
)
|
||||
)
|
||||
|
||||
@testing.expectedFailureStrictV2
|
||||
def test_automatic_constrain_size(self):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, x, y):
|
||||
@ -8932,6 +8948,7 @@ def forward(self, x):
|
||||
):
|
||||
ep.graph_module.while_loop_body_graph_0(torch.tensor([5]), torch.zeros(1))
|
||||
|
||||
@testing.expectedFailureStrictV2
|
||||
def test_constrain_decomp(self) -> None:
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
@ -9570,6 +9587,7 @@ def forward(self, b_a_buffer, x):
|
||||
self.assertTrue(torch.allclose(ep.module()(xs), module_out))
|
||||
|
||||
@requires_cuda_and_triton
|
||||
@testing.expectedFailureStrictV2
|
||||
def test_export_associative_scan_lifted_buffers(self):
|
||||
if "cpp_runtime_nonstrict" in self.id():
|
||||
self.skipTest("TODO Unexpected success in OSS but not in fbcode.")
|
||||
@ -9660,6 +9678,7 @@ def forward(self, b_a_buffer, x):
|
||||
len([node for node in gm.graph.nodes if node.op == "placeholder"]), 2
|
||||
)
|
||||
|
||||
@testing.expectedFailureStrictV2
|
||||
def test_no_check_is_size_error(self):
|
||||
class Module(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
@ -9813,6 +9832,7 @@ def forward(self, b_a_buffer, x):
|
||||
self.assertEqual(len(ep.graph_signature.input_specs), 4)
|
||||
self.assertTrue(torch.allclose(ep.module()(*inp), transform.module()(*inp)))
|
||||
|
||||
@testing.expectedFailureStrictV2
|
||||
def test_tensor_attribute_zero_args(self):
|
||||
class Foo(torch.nn.Module):
|
||||
def __init__(self, value):
|
||||
@ -9826,6 +9846,7 @@ def forward(self, b_a_buffer, x):
|
||||
ep = export(m, ())
|
||||
self.assertEqual(ep.graph_signature.lifted_tensor_constants, ["x"])
|
||||
|
||||
@testing.expectedFailureStrictV2
|
||||
def test_preserve_shape_dynamism_for_unused_inputs(self):
|
||||
torch.export.register_dataclass(
|
||||
Inp3,
|
||||
@ -9995,6 +10016,7 @@ def forward(self, p_lin_weight, p_lin_bias, x):
|
||||
)
|
||||
|
||||
@unittest.skipIf(IS_FBCODE, "We can't customize decomp in fbcode")
|
||||
@testing.expectedFailureStrictV2
|
||||
def test_export_decomp_torture_case_2(self):
|
||||
class MyLinear(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
@ -10130,6 +10152,7 @@ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, c_
|
||||
# expected 4, but got 7
|
||||
ep_v2.module()(*test_inp)
|
||||
|
||||
@testing.expectedFailureStrictV2
|
||||
def test_constant_output(self):
|
||||
class ModuleConstant(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
@ -10214,6 +10237,7 @@ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, c_
|
||||
# expected >= 3, but got 2
|
||||
ep.module()(*test_inp)
|
||||
|
||||
@testing.expectedFailureStrictV2
|
||||
def test_nested_module(self):
|
||||
class M1(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
@ -10251,6 +10275,7 @@ graph():
|
||||
unflattened = unflatten(ep)
|
||||
self.assertTrue(torch.allclose(unflattened(*inps), M2()(*inps)))
|
||||
|
||||
@testing.expectedFailureStrictV2
|
||||
def test_nested_module_with_init_buffer(self):
|
||||
class M1(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
@ -10378,6 +10403,7 @@ graph():
|
||||
ep = export(m, sample_inputs)
|
||||
self.assertEqual(ep.module()(*sample_inputs), m(*sample_inputs))
|
||||
|
||||
@testing.expectedFailureStrictV2
|
||||
def test_lazy_module_kwargs(self):
|
||||
class LazyModule(torch.nn.modules.lazy.LazyModuleMixin, torch.nn.Module):
|
||||
def initialize_parameters(self, *args, **kwargs):
|
||||
@ -12251,6 +12277,7 @@ graph():
|
||||
ep.module()(x)
|
||||
|
||||
@testing.expectedFailureCppRuntime
|
||||
@testing.expectedFailureStrictV2
|
||||
def test_symint_input_basic(self):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, x, y):
|
||||
@ -12970,6 +12997,7 @@ def forward(self, c_submod_params, x):
|
||||
ufm = torch.export.unflatten(ep)
|
||||
self.assertTrue(torch.allclose(ufm(*inp), epm(*inp)))
|
||||
|
||||
@testing.expectedFailureStrictV2
|
||||
def test_unflatten_multiple_graphs_shared_submodule(self):
|
||||
class N(torch.nn.Module):
|
||||
def forward(self, x, b):
|
||||
@ -14021,6 +14049,7 @@ def forward(self, x):
|
||||
return (foo_functional,)""",
|
||||
)
|
||||
|
||||
@testing.expectedFailureStrictV2
|
||||
def test_placeholder_naming_order(self):
|
||||
# See https://github.com/pytorch/pytorch/issues/143732
|
||||
|
||||
@ -14072,6 +14101,7 @@ def forward(self, x):
|
||||
).run_decompositions()
|
||||
ep.module()(torch.ones(4, 4), **kwargs)
|
||||
|
||||
@testing.expectedFailureStrictV2
|
||||
def test_placeholder_naming_order_variadic(self):
|
||||
class Mod(torch.nn.Module):
|
||||
def forward(self, a, b, c, **kwargs):
|
||||
@ -14096,6 +14126,7 @@ def forward(self, x):
|
||||
):
|
||||
export(Foo(), (torch.randn(4, 4),), strict=False)
|
||||
|
||||
@testing.expectedFailureStrictV2
|
||||
def test_placeholder_naming_collisions(self):
|
||||
# test collisions between nested user inputs
|
||||
class Foo(torch.nn.Module):
|
||||
@ -14168,6 +14199,7 @@ def forward(self, x):
|
||||
self.assertEqual(expected_names_and_ops, real_names_and_ops)
|
||||
|
||||
@skipIfCrossRef # Dynamo changes the order of ops under Torch function modes
|
||||
@testing.expectedFailureStrictV2
|
||||
def test_placeholder_naming_collisions_hoo_subgraphs(self):
|
||||
# test collisions between user inputs, top-level nodes, and HOO subgraph nodes
|
||||
class Foo(torch.nn.Module):
|
||||
@ -14245,6 +14277,7 @@ def forward(self, x):
|
||||
]
|
||||
self.assertEqual(expected_getattr_names, real_getattr_names)
|
||||
|
||||
@testing.expectedFailureStrictV2
|
||||
def test_constant_input_naming(self):
|
||||
class Foo(torch.nn.Module):
|
||||
def forward(self, x, y, div="floor"):
|
||||
@ -14936,6 +14969,7 @@ graph():
|
||||
]
|
||||
self.assertEqual(len(repeat_nodes), 0)
|
||||
|
||||
@testing.expectedFailureStrictV2
|
||||
def test_checks_to_constrain_range(self):
|
||||
class Foo(torch.nn.Module):
|
||||
def forward(self, x, y):
|
||||
@ -15270,6 +15304,7 @@ graph():
|
||||
Block(torch.randn(4, 4), torch.randn(4, 4))
|
||||
)
|
||||
|
||||
@testing.expectedFailureStrictV2
|
||||
def test_enum_str(self):
|
||||
class TensorDim(str, enum.Enum):
|
||||
DDP = "ddp"
|
||||
@ -15431,6 +15466,7 @@ def forward(self, x):
|
||||
return (getitem_3, cos_1)""",
|
||||
)
|
||||
|
||||
@testing.expectedFailureStrictV2
|
||||
def test_run_decompositions_keep_metadata(self):
|
||||
"""Make sure the metadata is kept after exported program run_decompositions."""
|
||||
|
||||
@ -15460,6 +15496,7 @@ def forward(self, x):
|
||||
for node in decomposed_program.graph.nodes:
|
||||
self.assertEqual(node.meta["custom"]["my_field"], "dummy")
|
||||
|
||||
@testing.expectedFailureStrictV2
|
||||
def test_run_decompositions_keep_tensor_constant_metadata(self):
|
||||
"""Make sure the metadata of tensor constants are kept after run_decompositions."""
|
||||
|
||||
@ -16091,6 +16128,7 @@ def forward(self, x):
|
||||
|
||||
@testing.expectedFailureSerDer # T195866111
|
||||
@testing.expectedFailureSerDerNonStrict
|
||||
@testing.expectedFailureStrictV2
|
||||
def test_hints_wrapper(self):
|
||||
strict = True
|
||||
|
||||
@ -16665,6 +16703,7 @@ def forward(self, args_0):
|
||||
return (abs_1,)""",
|
||||
)
|
||||
|
||||
@testing.expectedFailureStrictV2
|
||||
def test_sdpa_gqa(self):
|
||||
from torch.nn.attention import sdpa_kernel, SDPBackend
|
||||
|
||||
@ -17499,105 +17538,6 @@ def forward(self, x):
|
||||
exported_param_names = [name for name, _ in gm.named_parameters()]
|
||||
self.assertEqual(original_param_names, exported_param_names)
|
||||
|
||||
def test_export_compiled_model_with_nested_dynamic_shapes(self):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, data_batch):
|
||||
return data_batch["a1"] + data_batch["a2"]
|
||||
|
||||
m = M()
|
||||
compiled_m = torch.compile(m)
|
||||
example_args = (
|
||||
{
|
||||
"a1": torch.ones(3, 3),
|
||||
"a2": torch.ones(3, 3),
|
||||
},
|
||||
)
|
||||
dynamic_shapes = (
|
||||
{
|
||||
"a1": {0: Dim.DYNAMIC},
|
||||
"a2": {0: Dim.DYNAMIC},
|
||||
},
|
||||
)
|
||||
ep = export(
|
||||
compiled_m, example_args, dynamic_shapes=dynamic_shapes, strict=True
|
||||
)
|
||||
gm = ep.module()
|
||||
self.assertEqual(gm(*example_args), compiled_m(*example_args))
|
||||
|
||||
def test_export_model_with_nested_dynamic_shapes(self):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, data_batch):
|
||||
return data_batch["a1"] + data_batch["a2"]
|
||||
|
||||
m = M()
|
||||
example_args = (
|
||||
{
|
||||
"a1": torch.ones(3, 3),
|
||||
"a2": torch.ones(3, 3),
|
||||
},
|
||||
)
|
||||
B = torch.export.Dim("batch", min=1, max=65536)
|
||||
dynamic_shapes = (
|
||||
{
|
||||
"a1": {0: B},
|
"a2": {0: B},
},
)
ep = export(m, example_args, dynamic_shapes=dynamic_shapes, strict=True)
gm = ep.module()
self.assertEqual(gm(*example_args), m(*example_args))

def test_export_compiled_model_with_kwargs_dynamic_shapes(self):
class M(torch.nn.Module):
def forward(self, a1, a2):
return a1 + a2

m = M()
compiled_m = torch.compile(m)
example_args = ()
example_kwargs = {
"a1": torch.ones(3, 3),
"a2": torch.ones(3, 3),
}
dynamic_shapes = {
"a1": {0: Dim.DYNAMIC},
"a2": {0: Dim.DYNAMIC},
}
ep = export(
compiled_m,
example_args,
kwargs=example_kwargs,
dynamic_shapes=dynamic_shapes,
strict=True,
)
gm = ep.module()
self.assertEqual(gm(**example_kwargs), compiled_m(**example_kwargs))

def test_export_model_with_kwargs_dynamic_shapes(self):
class M(torch.nn.Module):
def forward(self, a1, a2):
return a1 + a2

m = M()
example_args = ()
example_kwargs = {
"a1": torch.ones(3, 3),
"a2": torch.ones(3, 3),
}
dynamic_shapes = {
"a1": {0: Dim.DYNAMIC},
"a2": {0: Dim.DYNAMIC},
}
ep = export(
m,
example_args,
kwargs=example_kwargs,
dynamic_shapes=dynamic_shapes,
strict=True,
)
gm = ep.module()
self.assertEqual(gm(**example_kwargs), m(**example_kwargs))


@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
class TestExportCustomClass(TorchTestCase):

@@ -15,7 +15,7 @@ test_classes = {}

def mocked_strict_export_v2(*args, **kwargs):
# If user already specified strict, don't make it strict
with config.patch(use_new_tracer_experimental=True):
with config.patch(use_legacy_dynamo_graph_capture=False):
if "strict" in kwargs:
return export(*args, **kwargs)
return export(*args, **kwargs, strict=True)

@@ -2640,7 +2640,7 @@ def forward(self, primals_1, primals_2):
return grad_output * x, grad_output * x

def f(a, b):
return FwBwMutation.apply(a, b)
return FwBwMutation.apply(a, b).sin_().clone()

inps = [
torch.ones(3, 3, requires_grad=True),
@@ -2689,17 +2689,22 @@ def forward(self, primals_1, primals_2):
add = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None
_foreach_mul__1 = torch.ops.aten._foreach_mul_.ScalarList([add], [3]); _foreach_mul__1 = None
mul = torch.ops.aten.mul.Tensor(add, primals_1); primals_1 = None
return (mul, add)""",
clone = torch.ops.aten.clone.default(mul)
sin_ = torch.ops.aten.sin_.default(mul); mul = None
clone_1 = torch.ops.aten.clone.default(sin_); sin_ = None
return (clone_1, add, clone)""",
)

# important bit: there is 1 mutation in the bw
self.assertExpectedInline(
bw_graph[0].code.strip(),
"""\
def forward(self, add, tangents_1):
def forward(self, add, clone, tangents_1):
cos = torch.ops.aten.cos.default(clone); clone = None
mul_1 = torch.ops.aten.mul.Tensor(tangents_1, cos); tangents_1 = cos = None
_foreach_mul__2 = torch.ops.aten._foreach_mul_.ScalarList([add], [4]); _foreach_mul__2 = None
mul_1 = torch.ops.aten.mul.Tensor(tangents_1, add); tangents_1 = add = None
return (mul_1, None)""",
mul_2 = torch.ops.aten.mul.Tensor(mul_1, add); mul_1 = add = None
return (mul_2, None)""",
)

def test_fw_bw_mutation_no_functionalization2(self):

@@ -927,8 +927,8 @@ class GraphModule(torch.nn.Module):
op="call_function", target=torch.ops.aten.mm.default
)
self.assertEqual(len(mm_nodes), 4)
self.assertNotIn("partitioner_tag", mm_nodes[0].meta)
self.assertNotIn("partitioner_tag", mm_nodes[1].meta)
self.assertEqual(mm_nodes[0].meta["partitioner_tag"], "is_forward")
self.assertEqual(mm_nodes[1].meta["partitioner_tag"], "is_forward")
self.assertEqual(mm_nodes[2].meta["partitioner_tag"], "is_backward")
self.assertEqual(mm_nodes[3].meta["partitioner_tag"], "is_backward")
self.assertEqual(mm_nodes[0].meta["custom"]["inside_local_map"], 0)
@@ -4101,53 +4101,6 @@ if HAS_CUDA_AND_TRITON:
compiled_out = compiled_foo(x)
self.assertEqual(eager_out, compiled_out)

# Use autotune_at_compile_time=True to test standalone_compile
@parametrize("autotune_at_compile_time", [True, False])
@config.patch("graph_partition", True)
def test_graph_partition_kernel_reuse(self, autotune_at_compile_time):
def foo(x):
# partition 1
x1 = x @ x
y1 = x1 + 1
z_cpu = y1.cpu() + 1
# partition 2
# partition 2 should reuse the fused triton kernel generated
# in partition 1
x2 = z_cpu.to("cuda") @ z_cpu.to("cuda")
y2 = x2 + 1
return y1, y2

with config.patch(
"triton.autotune_at_compile_time", autotune_at_compile_time
):
compiled_foo = torch.compile(foo)
x = torch.randn((20, 20), device="cuda")
eager_out = foo(x)
compiled_out, code = run_and_get_code(compiled_foo, x)
self.assertEqual(eager_out, compiled_out)

if autotune_at_compile_time:
# auto-tuning block should only appear once. We generate auto-tuning code
# for all the kernels no matter if they are defined in the main graph or
# subgraph, to avoid the overhead of executing multiple auto-tuning code blocks.
FileCheck().check_count(
"Compile-time auto-tuning block", 1, exactly=True
).run(code[0])
# triton_poi_fused_add_ should appear twice, first in the auto-tuning block,
# and then in the main code block
FileCheck().check_count(
"def triton_poi_fused_add_", 2, exactly=True
).run(code[0])
# cpu kernel definition should only appear once, not in the auto-tuning block
FileCheck().check_count(
"cpp_fused__to_copy_add_1 = ", 1, exactly=True
).run(code[0])
else:
# triton_poi_fused_add_ should appear once, because of kernel reuse
FileCheck().check_count(
"def triton_poi_fused_add_", 1, exactly=True
).run(code[0])

def test_meta_tensor(self):
def foobar(x, y):
return x * 2, y * 3
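For context, the removed test exercised two inductor options by name; a minimal sketch of turning them on directly (config names are taken from the hunk above, the wrapper function is illustrative only and assumes a CUDA device):

    import torch
    from torch._inductor import config as inductor_config

    def foo(x):
        # a tiny graph with a matmul and a pointwise add, enough to produce a fused Triton kernel
        return (x @ x) + 1

    # graph_partition splits the compiled graph around non-GPU ops;
    # triton.autotune_at_compile_time moves Triton autotuning into a single compile-time block.
    with inductor_config.patch("graph_partition", True), inductor_config.patch(
        "triton.autotune_at_compile_time", True
    ):
        compiled = torch.compile(foo)
        out = compiled(torch.randn(20, 20, device="cuda"))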
@@ -4,9 +4,8 @@ from functools import partial
from unittest import skipIf

import torch
from torch._inductor import config
from torch._inductor.ir import Pointwise
from torch._inductor.lowering import make_fallback, make_pointwise, register_lowering
from torch._inductor.lowering import make_pointwise, register_lowering
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._inductor.virtualized import ops
from torch.testing._internal.common_utils import skipIfRocm, skipIfXpu
@@ -238,17 +237,6 @@ class TestCustomLowering(InductorTestCase):
out2 = fn_opt(a, b)
self.assertEqual(out1, out2)

@config.patch(joint_graph_constant_folding=False)
def test_constant_creation(self):
class M(torch.nn.Module):
def forward(self, x):
return x + torch.tensor(1)

make_fallback(torch.ops.aten.lift_fresh_copy.default)
self.assertTrue(
torch.allclose(torch.compile(M())(torch.ones(3)), torch.ones(3) + 1)
)


if __name__ == "__main__":
from torch._inductor.test_case import run_tests

@@ -492,36 +492,6 @@ class PackedSequenceTest(TestCase):
torch.randn([0, 1, 10]), torch.randn([11, 14, 14, 2]), True
)

def test_empty_packed_sequence(self):
"""
Regression test for https://github.com/pytorch/pytorch/issues/149622
Tests that pad_packed_sequence and unpack_sequence handle empty tensors
without segmentation fault (CVE-2025-2998, CVE-2025-2999)
"""
# Test case 1: pad_packed_sequence with empty tensors
# Previously caused segmentation fault
empty_data = torch.randn(0, 5)
empty_batch_sizes = torch.tensor([], dtype=torch.int64)
empty_packed = rnn_utils.PackedSequence(
empty_data, empty_batch_sizes, None, None
)

# Should not crash - either return empty result or raise informative error
with self.assertRaises(RuntimeError):
rnn_utils.pad_packed_sequence(empty_packed, batch_first=True)

# Test case 2: unpack_sequence with empty tensors
# Previously caused segmentation fault
empty_data = torch.tensor([])
empty_batch_sizes = torch.tensor([], dtype=torch.int64)
packed = rnn_utils.PackedSequence(
data=empty_data, batch_sizes=empty_batch_sizes
)

# Should not crash - either return empty list or raise informative error
with self.assertRaises(RuntimeError):
rnn_utils.unpack_sequence(packed)


if __name__ == "__main__":
run_tests()

@@ -1001,10 +1001,24 @@ def get_traced_fn(mod: Any) -> tuple[FunctionType, Optional[object]]:
import inspect

if isinstance(mod, torch.nn.Module):
if len(mod._forward_pre_hooks) == 0 and len(mod._forward_hooks) == 0:
# Mirrored from NNModuleVariable.call_function:
# https://github.com/pytorch/pytorch/blob/main/torch/_dynamo/variables/nn_module.py#L1035
if (
len(mod._forward_pre_hooks) == 0
and len(mod._forward_hooks) == 0
and len(torch.nn.modules.module._global_forward_pre_hooks) == 0
and len(torch.nn.modules.module._global_forward_hooks) == 0
and len(mod._backward_pre_hooks) == 0
and len(mod._backward_hooks) == 0
and len(torch.nn.modules.module._global_backward_pre_hooks) == 0
and len(torch.nn.modules.module._global_backward_hooks) == 0
):
mod = mod.forward
elif isinstance(mod, torch.fx.GraphModule):
mod = mod._call_impl
else:
mod = mod.__call__

if hasattr(mod, "__self__"):
# pyrefly: ignore [missing-attribute]
return mod.__func__, mod.__self__
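The widened hook checks above mirror eager semantics: module and global forward/backward hooks only fire when a module is invoked through Module.__call__, so tracing mod.forward directly is only safe when no hooks are registered. A small standalone illustration (not part of this patch):

    import torch

    m = torch.nn.Linear(2, 2)
    m.register_forward_hook(lambda mod, inp, out: print("forward hook fired"))

    x = torch.randn(1, 2)
    m(x)          # goes through Module.__call__, so the hook fires
    m.forward(x)  # bypasses __call__, so the hook is silently skipped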
@@ -637,7 +637,7 @@ def dynamo_graph_capture_for_export(
pyt.in_shuffle_graph,
pyt.out_shuffle_graph,
tree_leaf_names,
pyt.root,
graph_module if isinstance(pyt.root, torch.nn.Module) else pyt.root,
) # type: ignore[attr-defined]
normalize_graph_module(graph_module)
if pyt.root is not None:
@@ -648,6 +648,10 @@ def dynamo_graph_capture_for_export(
graph_module._non_persistent_buffers_set = (
pyt.root._non_persistent_buffers_set.copy()
)
annotations = torch.nn.Module.__dict__.get("__annotations__", None)
for name, value in pyt.root.__dict__.items():
if annotations and name not in annotations:
graph_module.__dict__[name] = value
graph_module._in_spec = pyt.in_spec
graph_module._out_spec = pyt.out_spec
assert not hasattr(graph_module, "_in_shuffle_graph")

@@ -2320,8 +2320,6 @@ if sys.version_info >= (3, 11):
torch_c_binding_in_graph_functions["math.exp2"] = TorchInGraphFunctionVariable
torch_c_binding_in_graph_functions["math.cbrt"] = TorchInGraphFunctionVariable

if sys.version_info >= (3, 13):
torch_c_binding_in_graph_functions["math.fma"] = TorchInGraphFunctionVariable

# In graph functions (including constant folding) that are not C bindings
# NOTE: [Cacheability of in-graph torch functions]

@@ -10,10 +10,7 @@ from torch.fx import has_side_effect, Proxy
from .. import graph_break_hints
from ..bytecode_transformation import create_call_function
from ..exc import TYPE_CHECKING, unimplemented
from ..graph_bytecode_inputs import (
get_external_object_by_index,
register_graph_created_object,
)
from ..graph_bytecode_inputs import get_external_object_by_index
from .base import VariableTracker
from .constant import ConstantVariable
from .ctx_manager import FxTracebackAnnotateVariable
@@ -31,26 +28,6 @@ from torch._library.custom_ops import custom_op
Tensor = torch.Tensor


def new_event(*args: Any, **kwargs: Any) -> int:
event = torch.Event(*args, **kwargs)
return register_graph_created_object(
event,
EventVariable.make_construct_in_graph_event_fn(
TupleVariable([]), ConstDictVariable({})
),
)


def new_stream(*args: tuple[Any], **kwargs: Any) -> int:
stream = torch.Stream(*args, **kwargs) # type: ignore[no-matching-overload,call-overload]
return register_graph_created_object(
stream,
StreamVariable.make_construct_in_graph_stream_fn(
TupleVariable([]), ConstDictVariable({})
),
)


def _get_stream_by_index(index: int) -> torch.Stream:
stream = get_external_object_by_index(index)
assert isinstance(stream, torch.Stream), (
@@ -138,24 +115,6 @@ def _(
has_side_effect(torch.ops.streams.wait_event.default)


@custom_op("streams::wait_stream", mutates_args=())
def wait_stream(waiting_stream_index: int, waited_on_stream_index: int) -> None:
waiting = _get_stream_by_index(waiting_stream_index)
waited_on = _get_stream_by_index(waited_on_stream_index)
waiting.wait_stream(waited_on)


@wait_stream.register_fake
def _(
event_index: int,
stream_index: int,
) -> None:
pass


has_side_effect(torch.ops.streams.wait_stream.default)


class SymbolicStreamState:
"""Track the currently entered stream if any"""
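These custom ops wrap the eager stream-synchronization calls so the tracer can refer to streams by registered index. Roughly, the eager behavior they forward to looks like this (a sketch, assuming a CUDA device for the generic torch.Stream/torch.Event objects):

    import torch

    s1 = torch.Stream(device="cuda")
    s2 = torch.Stream(device="cuda")
    e = torch.Event(device="cuda")

    e.record(s1)
    s2.wait_event(e)    # the behavior streams::wait_event mirrors
    s2.wait_stream(s1)  # the behavior streams::wait_stream mirrors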
@@ -603,21 +603,6 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
VariableTracker.build(tx, polyfills.radians), args, kwargs
)

if hasattr(math, "fma"): # Python 3.13+

@register(math.fma)
def handle_fma(self, tx: "InstructionTranslator", *args, **kwargs):
if len(args) != 3 or kwargs:
return None

if all(isinstance(arg, variables.TensorVariable) for arg in args):
x, y, z = args
addcmul_fn = TorchInGraphFunctionVariable(torch.addcmul)
return addcmul_fn.call_function(tx, [z, x, y], {})

# Use math.fma if constants
return None

@register(torch.is_inference_mode_enabled)
def handle_is_inference_mode_enabled(self, tx: "InstructionTranslator"):
unimplemented(
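For tensor arguments this handler rewrites math.fma(x, y, z) into torch.addcmul(z, x, y), which computes z + x * y elementwise; a quick illustrative check of that equivalence:

    import math
    import torch

    x, y, z = torch.randn(3), torch.randn(3), torch.randn(3)
    assert torch.allclose(torch.addcmul(z, x, y), z + x * y)
    # scalar math.fma exists only on Python 3.13+
    assert math.isclose(math.fma(2.0, 3.0, 4.0), 2.0 * 3.0 + 4.0)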
@@ -33,6 +33,9 @@ error_on_lifted_constant_tensors = True
# being ready to handle auto_functionalized_v2.
enable_auto_functionalized_v2_for_export = not is_fbcode()

use_legacy_dynamo_graph_capture = True


if TYPE_CHECKING:
from torch.utils._config_typing import * # noqa: F401, F403
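This flag gates which Dynamo graph-capture path strict export takes. A hedged sketch of opting a single export call into the new path (the config field names come from this diff; the patch call is the standard config-module API, and the module M here is illustrative):

    import torch
    import torch._export.config as export_config

    class M(torch.nn.Module):
        def forward(self, x):
            return x + 1

    with export_config.patch(use_legacy_dynamo_graph_capture=False):
        ep = torch.export.export(M(), (torch.randn(2),), strict=True)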
@ -27,6 +27,7 @@ from torch._guards import detect_fake_mode
|
||||
from torch._prims_common import CUDARngStateHelper
|
||||
from torch.fx.experimental.proxy_tensor import (
|
||||
_proxy_tensor_disable_update_tensor_tracker,
|
||||
get_proxy_mode,
|
||||
maybe_disable_thunkify,
|
||||
maybe_enable_thunkify,
|
||||
)
|
||||
@ -295,6 +296,10 @@ def create_joint(
|
||||
(outs, tangent_mask), (outs_descs, _) = call_and_expect_output_descs(
|
||||
fn, primals
|
||||
)
|
||||
mode = get_proxy_mode()
|
||||
assert mode is not None
|
||||
for node in mode.tracer.graph.nodes:
|
||||
node.meta["partitioner_tag"] = "is_forward"
|
||||
|
||||
# TODO: I think this hook can also be eliminated now
|
||||
if joint_fn_handle and joint_fn_handle.post_forward:
|
||||
|
||||
@ -51,6 +51,7 @@ from ._activation_checkpointing.knapsack import (
|
||||
)
|
||||
from ._activation_checkpointing.knapsack_evaluator import KnapsackEvaluator
|
||||
from ._aot_autograd.descriptors import AOTOutput, SavedForBackwardsAOTOutput
|
||||
from ._aot_autograd.functional_utils import assert_functional_graph
|
||||
from ._aot_autograd.logging_utils import get_aot_graph_name
|
||||
from ._aot_autograd.utils import get_cuda_generator_meta_val, is_with_effects
|
||||
from .compile_utils import fx_graph_cse, get_aten_target, raise_getitems
|
||||
@ -297,6 +298,10 @@ def _has_tag_is_backward(node: fx.Node) -> bool:
|
||||
return node.meta.get("partitioner_tag", None) == "is_backward"
|
||||
|
||||
|
||||
def _has_tag_is_forward(node: fx.Node) -> bool:
|
||||
return node.meta.get("partitioner_tag", None) == "is_forward"
|
||||
|
||||
|
||||
def _has_tag_must_be_in_forward(node: fx.Node) -> bool:
|
||||
return node.meta.get("partitioner_tag", None) == "must_be_in_forward"
|
||||
|
||||
@ -1021,105 +1026,95 @@ def default_partition(
|
||||
Returns:
|
||||
Returns the generated forward and backward Fx graph modules.
|
||||
"""
|
||||
if has_recomputable_ops(joint_module):
|
||||
return min_cut_rematerialization_partition(
|
||||
joint_module,
|
||||
_joint_inputs,
|
||||
num_fwd_outputs=num_fwd_outputs,
|
||||
static_lifetime_input_indices=static_lifetime_input_indices,
|
||||
)
|
||||
primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
|
||||
fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes))
|
||||
inputs = primal_inputs + fwd_seed_offset_inputs
|
||||
fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = (
|
||||
_extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
|
||||
)
|
||||
forward_only_graph = _extract_graph_with_inputs_outputs(
|
||||
joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward"
|
||||
)
|
||||
# Respect the original placement of ops rather than rely on dataflow.
|
||||
forward_nodes = []
|
||||
last_node = None
|
||||
for node in joint_module.graph.nodes:
|
||||
if _has_tag_is_forward(node) or _is_primal(node) or _is_fwd_seed_offset(node):
|
||||
last_node = node
|
||||
assert last_node is not None
|
||||
for node in joint_module.graph.nodes:
|
||||
if not _is_tangent(node):
|
||||
forward_nodes.append(node)
|
||||
if node is last_node:
|
||||
break
|
||||
forward_node_names = OrderedSet(
|
||||
node.name for node in forward_only_graph.nodes if node.op != "output"
|
||||
node.name for node in forward_nodes if node.op != "output"
|
||||
)
|
||||
order = {node: idx for idx, node in enumerate(joint_module.graph.nodes)}
|
||||
graph_has_recomputable_ops = has_recomputable_ops(joint_module)
|
||||
graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module)
|
||||
if graph_has_recomputable_ops:
|
||||
assert_functional_graph(joint_module.graph)
|
||||
joint_module = cleanup_recompute_tags(joint_module, is_default_partition=True)
|
||||
|
||||
if not config.unsafe_allow_optimization_of_collectives:
|
||||
force_save_collectives(joint_module)
|
||||
|
||||
force_save_bw_mutation_src(joint_module)
|
||||
|
||||
if static_lifetime_input_indices is None:
|
||||
static_lifetime_input_indices = []
|
||||
node_info = classify_nodes(
|
||||
joint_module, static_lifetime_input_indices, num_fwd_outputs
|
||||
)
|
||||
|
||||
saved_values = []
|
||||
saved_sym_nodes = []
|
||||
|
||||
def is_mutated_later_in_fw(node):
|
||||
if _has_tag_is_backward(node):
|
||||
return False
|
||||
tensor_arg_aliases = [
|
||||
x
|
||||
for x in node.args
|
||||
if isinstance(x, fx.Node)
|
||||
and "val" in x.meta
|
||||
and isinstance(x.meta["val"], torch.Tensor)
|
||||
]
|
||||
while len(tensor_arg_aliases) > 0:
|
||||
a = tensor_arg_aliases.pop()
|
||||
for u in a.users:
|
||||
if not isinstance(u.target, torch._ops.OpOverload):
|
||||
continue
|
||||
# If we witness a mutation on our node later, and that mutation is not "must be in backward",
|
||||
# then our node needs to be computed in the forward (otherwise we will compute it on the mutated values)
|
||||
if (
|
||||
# one of the args was mutated
|
||||
u.target._schema.is_mutable
|
||||
# and the mutation happens "later"
|
||||
and order[u] > order[node]
|
||||
# and the mutation happened during the forward
|
||||
and not (_has_tag_is_backward(u) or _has_tag_must_be_in_backward(u))
|
||||
):
|
||||
for idx, alias_info in enumerate(u.target._schema.arguments):
|
||||
if alias_info.is_write and u.args[idx] is a:
|
||||
return True
|
||||
elif u.target.is_view:
|
||||
tensor_arg_aliases.append(u)
|
||||
return False
|
||||
|
||||
for node in joint_module.graph.nodes:
|
||||
if node.name not in forward_node_names:
|
||||
# if a node isn't "required" to be in the forward, but any of its arguments
|
||||
# are later mutated in the forward, then it must have been run in the forward
|
||||
# (if not, and the node's arg was saved for backward, we would have mutated a saved value)
|
||||
# NB: doesn't handle nodes where the input is a list of tensors and one of those tensors is later mutated
|
||||
if is_mutated_later_in_fw(node):
|
||||
saved_values.append(node)
|
||||
continue
|
||||
if is_sym_node(node):
|
||||
# Symints must be kept separate from tensors so that PythonFunction only calls
|
||||
# save_for_backward on tensors and stashes symints in autograd .ctx
|
||||
saved_sym_nodes.append(node)
|
||||
elif (
|
||||
continue
|
||||
if node.meta.get("recompute") == CheckpointPolicy.MUST_SAVE:
|
||||
saved_values.append(node)
|
||||
continue
|
||||
if node.is_impure(impure_random=False) and node.op not in (
|
||||
"placeholder",
|
||||
"output",
|
||||
):
|
||||
# See is_impure in torch/fx/node.py
|
||||
assert not graph_has_recomputable_ops, (
|
||||
"Trying to apply AC on a graph with impure op",
|
||||
node,
|
||||
node.target,
|
||||
)
|
||||
saved_values.append(node)
|
||||
continue
|
||||
backward_usages = [n for n in node.users if n.name not in forward_node_names]
|
||||
if "tensor_meta" in node.meta and all(is_sym_node(n) for n in backward_usages):
|
||||
# If we have a tensor in the forward, where only its sizes/strides are needed in the backward,
|
||||
# and not the actual tensor data,
|
||||
# then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor.
|
||||
#
|
||||
# Note that saving the tensor could also cause compilation problems:
|
||||
# If the user mutated an input in the forward and uses its sizes/strides in the backward,
|
||||
# then we would be obligated to clone the input before saving it to appease autograd.
|
||||
# (This is how we originally found this bug).
|
||||
saved_sym_nodes.extend(backward_usages)
|
||||
continue
|
||||
if (
|
||||
"tensor_meta" not in node.meta
|
||||
and node.op == "call_function"
|
||||
and not isinstance(node.meta.get("val"), torch._subclasses.FakeTensor)
|
||||
):
|
||||
# Since we can't save tuple of tensor values, we need to flatten out what we're saving
|
||||
users = node.users
|
||||
assert all(user.target is operator.getitem for user in users)
|
||||
saved_values.extend(users)
|
||||
else:
|
||||
backward_usages = [
|
||||
n for n in node.users if n.name not in forward_node_names
|
||||
]
|
||||
if "tensor_meta" in node.meta and all(
|
||||
is_sym_node(n) for n in backward_usages
|
||||
):
|
||||
# If we have a tensor in the forward, where only its sizes/strides are needed in the backward,
|
||||
# and not the actual tensor data,
|
||||
# then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor.
|
||||
#
|
||||
# Note that saving the tensor could also cause compilation problems:
|
||||
# If the user mutated an input in the forward and uses its sizes/strides in the backward,
|
||||
# then we would be obligated to clone the input before saving it to appease autograd.
|
||||
# (This is how we originally found this bug).
|
||||
saved_sym_nodes.extend(backward_usages)
|
||||
else:
|
||||
saved_values.append(node)
|
||||
assert all(user.target == operator.getitem for user in node.users)
|
||||
continue
|
||||
if not must_recompute(node):
|
||||
saved_values.append(node)
|
||||
|
||||
saved_values = list(dict.fromkeys(saved_values).keys())
|
||||
saved_sym_nodes = list(dict.fromkeys(saved_sym_nodes).keys())
|
||||
|
||||
return _extract_fwd_bwd_modules(
|
||||
if config._sync_decision_cross_ranks:
|
||||
saved_values = _sync_decision_cross_ranks(joint_module.graph, saved_values)
|
||||
|
||||
if static_lifetime_input_nodes is None:
|
||||
static_lifetime_input_nodes = node_info.static_lifetime_input_nodes
|
||||
fw_module, bw_module = _extract_fwd_bwd_modules(
|
||||
joint_module,
|
||||
saved_values,
|
||||
saved_sym_nodes=saved_sym_nodes,
|
||||
@ -1127,6 +1122,24 @@ def default_partition(
|
||||
static_lifetime_input_nodes=static_lifetime_input_nodes,
|
||||
)
|
||||
|
||||
if graph_has_recomputable_ops:
|
||||
if graph_has_recomputable_rng_ops:
|
||||
fw_module, bw_module = functionalize_rng_ops(
|
||||
joint_module, fw_module, bw_module, len(saved_sym_nodes)
|
||||
)
|
||||
bw_module = reordering_to_mimic_autograd_engine(bw_module)
|
||||
|
||||
# raise all getitem ops to as early as possible
|
||||
# this is helpful for memory, especially in the case of aot_eager backend
|
||||
fw_module = raise_getitems(fw_module)
|
||||
bw_module = raise_getitems(bw_module)
|
||||
|
||||
fw_module = thread_graphsafe_rng_from_hops(fw_module, is_backward=False)
|
||||
if len(node_info.required_bw_nodes) > 0:
|
||||
bw_module = thread_graphsafe_rng_from_hops(bw_module, is_backward=True)
|
||||
|
||||
return fw_module, bw_module
|
||||
|
||||
|
||||
INT_INF = int(1e6)
|
||||
|
||||
@ -1621,7 +1634,9 @@ def force_save_bw_mutation_src(joint_module: fx.GraphModule) -> None:
|
||||
break
|
||||
|
||||
|
||||
def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule:
|
||||
def cleanup_recompute_tags(
|
||||
joint_module: fx.GraphModule, *, is_default_partition: bool
|
||||
) -> fx.GraphModule:
|
||||
"""
|
||||
If there are two consecutive checkpointed blocks with no operator in
|
||||
between, we would still want to stash the tensor at the boundary of
|
||||
@ -1658,6 +1673,16 @@ def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule:
|
||||
# Solution: check whether `out` has a backward hook, and if so, intentionally save `out`
|
||||
# in forward graph outputs. With this, we can break the above circular dependency.
|
||||
node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
|
||||
elif (
|
||||
"ac_graph_id" not in node.meta
|
||||
and any(must_recompute(user) for user in node.users)
|
||||
and is_default_partition
|
||||
):
|
||||
# This node is not part of the AC region and a user is marked as recompute.
|
||||
# This means it's an input to the AC region and we should save it.
|
||||
# For ease of landing, gate this to default partitioner only, but we should think
|
||||
# about flipping the switch in general as well.
|
||||
node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
|
||||
return joint_module
|
||||
|
||||
|
||||
@ -2765,6 +2790,59 @@ def thread_graphsafe_rng_from_hops(module, is_backward):
|
||||
return module
|
||||
|
||||
|
||||
def classify_nodes(joint_module, static_lifetime_input_indices, num_fwd_outputs):
|
||||
name_to_node = get_name_to_node(joint_module.graph)
|
||||
required_bw_nodes: OrderedSet[fx.Node] = OrderedSet()
|
||||
for node in joint_module.graph.nodes:
|
||||
if node.op == "placeholder" and "tangents" in node.target:
|
||||
required_bw_nodes.add(node)
|
||||
elif _must_be_in_backward(node):
|
||||
required_bw_nodes.add(node)
|
||||
|
||||
if node in required_bw_nodes:
|
||||
required_bw_nodes.update(node.users)
|
||||
|
||||
primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
|
||||
fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes))
|
||||
inputs = primal_inputs + fwd_seed_offset_inputs
|
||||
fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = (
|
||||
_extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
|
||||
)
|
||||
required_bw_nodes.update(
|
||||
o for o in bwd_outputs if o is not None and o.op != "output"
|
||||
)
|
||||
forward_only_graph = _extract_graph_with_inputs_outputs(
|
||||
joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward"
|
||||
)
|
||||
required_fw_nodes: OrderedSet[fx.Node] = OrderedSet(
|
||||
name_to_node[node.name]
|
||||
for node in forward_only_graph.nodes
|
||||
if node.op != "output"
|
||||
)
|
||||
unclaimed_nodes: OrderedSet[fx.Node] = OrderedSet(
|
||||
node
|
||||
for node in joint_module.graph.nodes
|
||||
if node not in required_fw_nodes and node not in required_bw_nodes
|
||||
)
|
||||
static_lifetime_input_nodes = OrderedSet(
|
||||
p for i, p in enumerate(primal_inputs) if i in static_lifetime_input_indices
|
||||
)
|
||||
fw_cnt = 0
|
||||
fw_order = {}
|
||||
for node in joint_module.graph.nodes:
|
||||
if node in required_fw_nodes:
|
||||
fw_order[node] = fw_cnt
|
||||
fw_cnt += 1
|
||||
return NodeInfo(
|
||||
inputs,
|
||||
required_fw_nodes,
|
||||
required_bw_nodes,
|
||||
unclaimed_nodes,
|
||||
fw_order,
|
||||
static_lifetime_input_nodes,
|
||||
)
|
||||
|
||||
|
||||
def min_cut_rematerialization_partition(
|
||||
joint_module: fx.GraphModule,
|
||||
_joint_inputs,
|
||||
@ -2813,68 +2891,16 @@ def min_cut_rematerialization_partition(
|
||||
graph_has_recomputable_ops = has_recomputable_ops(joint_module)
|
||||
graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module)
|
||||
if graph_has_recomputable_ops:
|
||||
joint_module = cleanup_recompute_tags(joint_module)
|
||||
joint_module = cleanup_recompute_tags(joint_module, is_default_partition=False)
|
||||
if not config.unsafe_allow_optimization_of_collectives:
|
||||
force_save_collectives(joint_module)
|
||||
force_save_bw_mutation_src(joint_module)
|
||||
|
||||
def classify_nodes(joint_module, static_lifetime_input_indices):
|
||||
name_to_node = get_name_to_node(joint_module.graph)
|
||||
required_bw_nodes: OrderedSet[fx.Node] = OrderedSet()
|
||||
for node in joint_module.graph.nodes:
|
||||
if node.op == "placeholder" and "tangents" in node.target:
|
||||
required_bw_nodes.add(node)
|
||||
elif _must_be_in_backward(node):
|
||||
required_bw_nodes.add(node)
|
||||
|
||||
if node in required_bw_nodes:
|
||||
required_bw_nodes.update(node.users)
|
||||
|
||||
primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
|
||||
fwd_seed_offset_inputs = list(
|
||||
filter(_is_fwd_seed_offset, joint_module.graph.nodes)
|
||||
)
|
||||
inputs = primal_inputs + fwd_seed_offset_inputs
|
||||
fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = (
|
||||
_extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
|
||||
)
|
||||
required_bw_nodes.update(
|
||||
o for o in bwd_outputs if o is not None and o.op != "output"
|
||||
)
|
||||
forward_only_graph = _extract_graph_with_inputs_outputs(
|
||||
joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward"
|
||||
)
|
||||
required_fw_nodes: OrderedSet[fx.Node] = OrderedSet(
|
||||
name_to_node[node.name]
|
||||
for node in forward_only_graph.nodes
|
||||
if node.op != "output"
|
||||
)
|
||||
unclaimed_nodes: OrderedSet[fx.Node] = OrderedSet(
|
||||
node
|
||||
for node in joint_module.graph.nodes
|
||||
if node not in required_fw_nodes and node not in required_bw_nodes
|
||||
)
|
||||
static_lifetime_input_nodes = OrderedSet(
|
||||
p for i, p in enumerate(primal_inputs) if i in static_lifetime_input_indices
|
||||
)
|
||||
fw_cnt = 0
|
||||
fw_order = {}
|
||||
for node in joint_module.graph.nodes:
|
||||
if node in required_fw_nodes:
|
||||
fw_order[node] = fw_cnt
|
||||
fw_cnt += 1
|
||||
return NodeInfo(
|
||||
inputs,
|
||||
required_fw_nodes,
|
||||
required_bw_nodes,
|
||||
unclaimed_nodes,
|
||||
fw_order,
|
||||
static_lifetime_input_nodes,
|
||||
)
|
||||
|
||||
if static_lifetime_input_indices is None:
|
||||
static_lifetime_input_indices = []
|
||||
node_info = classify_nodes(joint_module, static_lifetime_input_indices)
|
||||
node_info = classify_nodes(
|
||||
joint_module, static_lifetime_input_indices, num_fwd_outputs
|
||||
)
|
||||
|
||||
# networkx blows up on graphs with no required backward nodes
|
||||
# Since there's nothing to partition anyway, and the default partitioner can "handle"
|
||||
|
||||
@ -627,7 +627,7 @@ class ComboKernel(Kernel):
|
||||
if heuristics == "foreach":
|
||||
heuristics_line = f"""
|
||||
@triton_heuristics.foreach(
|
||||
num_warps={self.num_warps},
|
||||
filename=__file__,
|
||||
triton_meta={triton_meta!r},
|
||||
inductor_meta={inductor_meta!r},
|
||||
)
|
||||
|
||||
@ -2259,7 +2259,7 @@ class PythonWrapperCodegen(CodeGen):
|
||||
gpu: bool = True,
|
||||
cpp_definition: Optional[str] = None,
|
||||
):
|
||||
if config.triton.autotune_at_compile_time and gpu:
|
||||
if config.triton.autotune_at_compile_time:
|
||||
body = self._format_kernel_definition(
|
||||
kernel_name, kernel_body, metadata=metadata
|
||||
)
|
||||
@ -3745,13 +3745,6 @@ class SubgraphPythonWrapperCodegen(PythonWrapperCodegen):
|
||||
|
||||
super().__init__()
|
||||
|
||||
root = self.get_root_graph()
|
||||
# Only generate auto-tuning block in the main graph
|
||||
self.kernel_autotune_defs = root.kernel_autotune_defs
|
||||
self.kernel_autotune_calls = root.kernel_autotune_calls
|
||||
# Only store kernel src to name mapping in the main graph
|
||||
self.src_to_kernel = root.src_to_kernel
|
||||
|
||||
def set_launcher_fn_name(self) -> None:
|
||||
# This sets up the name of the function containing the launcher code of
|
||||
# the subgraph.
|
||||
@ -3844,16 +3837,3 @@ class SubgraphPythonWrapperCodegen(PythonWrapperCodegen):
|
||||
# V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")
|
||||
# )
|
||||
self.parent_wrapper.write_get_raw_stream_header_once()
|
||||
|
||||
@cache_on_self
|
||||
def get_root_graph(self) -> PythonWrapperCodegen:
|
||||
root: PythonWrapperCodegen | SubgraphPythonWrapperCodegen = self
|
||||
while isinstance(root, SubgraphPythonWrapperCodegen):
|
||||
root = root.parent_wrapper
|
||||
|
||||
assert isinstance(root, PythonWrapperCodegen)
|
||||
return root
|
||||
|
||||
def generate_and_run_autotune_block(self):
|
||||
# Only execute auto-tuning block in the main graph
|
||||
pass
|
||||
|
||||
@ -64,7 +64,6 @@ from torch.fx.experimental.symbolic_shapes import (
|
||||
)
|
||||
from torch.fx.node import Node
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
from torch.utils._python_dispatch import _disable_current_modes
|
||||
from torch.utils._sympy.functions import CleanDiv, FloorDiv, Mod, ModularIndexing
|
||||
from torch.utils._sympy.symbol import SymT
|
||||
|
||||
@ -6136,12 +6135,9 @@ class ExternKernel(InputsKernel):
|
||||
if isinstance(x, (Expr, sympy.logic.boolalg.Boolean, int)):
|
||||
return ShapeAsConstantBuffer(expr=x)
|
||||
if isinstance(x, Constant):
|
||||
# We need to unset fake mode, or else the torch.tensor() call will
|
||||
# turn into a FakeTensor
|
||||
with _disable_current_modes():
|
||||
return V.graph.add_tensor_constant(
|
||||
torch.tensor(x.value, dtype=x.get_dtype(), device=x.get_device())
|
||||
)
|
||||
return V.graph.add_tensor_constant(
|
||||
torch.tensor(x.value, dtype=x.get_dtype(), device=x.get_device())
|
||||
)
|
||||
if isinstance(x, ConstantBuffer):
|
||||
return x
|
||||
if isinstance(x, TensorBox):
|
||||
|
||||
@ -7099,19 +7099,13 @@ def sym_constrain_range(a, min=None, max=None):
|
||||
@register_lowering(aten.sym_size.int)
|
||||
def sym_size(a, dim):
|
||||
val = V.graph.current_node.meta["val"]
|
||||
if isinstance(val, torch.SymInt):
|
||||
return val.node.expr
|
||||
else:
|
||||
return int(val)
|
||||
return val.node.expr
|
||||
|
||||
|
||||
@register_lowering(aten.sym_stride.int)
|
||||
def sym_stride(a, dim):
|
||||
val = V.graph.current_node.meta["val"]
|
||||
if isinstance(val, torch.SymInt):
|
||||
return val.node.expr
|
||||
else:
|
||||
return int(val)
|
||||
return val.node.expr
|
||||
|
||||
|
||||
@register_lowering(aten.sym_numel)
|
||||
|
||||
@ -3607,13 +3607,24 @@ def user_autotune(
|
||||
)
|
||||
|
||||
|
||||
def foreach(triton_meta, num_warps, filename=None, inductor_meta=None):
|
||||
def foreach(triton_meta, filename=None, inductor_meta=None):
|
||||
"""
|
||||
Compile a triton foreach kernel
|
||||
"""
|
||||
configs = []
|
||||
|
||||
# Naive autotuning path for num_warps
|
||||
if not (
|
||||
inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise")
|
||||
):
|
||||
configs.append(triton.Config({}, num_stages=1, num_warps=8))
|
||||
else:
|
||||
for warps in [1, 2, 4, 8]:
|
||||
configs.append(triton.Config({}, num_stages=1, num_warps=warps))
|
||||
|
||||
return cached_autotune(
|
||||
None,
|
||||
[triton.Config({}, num_stages=1, num_warps=num_warps)],
|
||||
configs,
|
||||
triton_meta=triton_meta,
|
||||
inductor_meta=inductor_meta,
|
||||
heuristic_type=HeuristicType.TEMPLATE,
|
||||
|
||||
@ -702,7 +702,7 @@ def exp2(a):
|
||||
# CompositeImplicitAutograd - don't register decomp
|
||||
@out_wrapper()
|
||||
@elementwise_type_promotion_wrapper(
|
||||
type_promoting_args=("a",),
|
||||
type_promoting_args=("a,"),
|
||||
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
|
||||
)
|
||||
def fill(a: TensorLikeType, value: NumberType) -> TensorLikeType:
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/util/Exception.h>
|
||||
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
|
||||
#include <torch/csrc/stable/c/shim.h>
|
||||
#include <torch/csrc/stable/device_struct.h>
|
||||
@ -119,7 +120,7 @@ struct FromImpl<ScalarType> {
|
||||
case ScalarType::UInt64:
|
||||
return from(aoti_torch_dtype_uint64());
|
||||
default:
|
||||
STD_TORCH_CHECK(
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"Not yet supported ScalarType, please file an issue describing your use case.");
|
||||
}
|
||||
@ -150,7 +151,7 @@ struct FromImpl<DeviceType> {
|
||||
case DeviceType::PrivateUse1:
|
||||
return from(aoti_torch_device_type_privateuse1());
|
||||
default:
|
||||
STD_TORCH_CHECK(
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"Not yet supported DeviceType, please file an issue describing your use case.");
|
||||
}
|
||||
@ -378,7 +379,7 @@ struct ToImpl<ScalarType> {
|
||||
} else if (shim_scalartype == aoti_torch_dtype_uint64()) {
|
||||
return ScalarType::UInt64;
|
||||
} else {
|
||||
STD_TORCH_CHECK(
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"Not yet supported ScalarType ",
|
||||
std::to_string(shim_scalartype),
|
||||
@ -408,7 +409,7 @@ struct ToImpl<DeviceType> {
|
||||
} else if (shim_devicetype == aoti_torch_device_type_privateuse1()) {
|
||||
return DeviceType::PrivateUse1;
|
||||
} else {
|
||||
STD_TORCH_CHECK(
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"Not yet supported DeviceType ",
|
||||
std::to_string(shim_devicetype),
|
||||
|
||||
@ -2,7 +2,7 @@ from collections.abc import Callable
|
||||
from copy import deepcopy
|
||||
from enum import auto, Enum
|
||||
from functools import partial, wraps
|
||||
from typing import Any, NamedTuple, Optional, TYPE_CHECKING, TypeVar, Union
|
||||
from typing import Any, NamedTuple, Optional, TypeVar, Union
|
||||
from typing_extensions import ParamSpec, TypeVarTuple, Unpack
|
||||
|
||||
import torch
|
||||
@ -17,9 +17,6 @@ from torch.utils._pytree import tree_map_only
|
||||
from torch.utils.weak import WeakIdKeyDictionary, weakref
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch.utils.hooks import RemovableHandle
|
||||
|
||||
_TOTAL_KEY = "Total"
|
||||
|
||||
__all__ = ["FSDPMemTracker"]
|
||||
@ -368,28 +365,14 @@ class FSDPMemTracker(MemTracker):
|
||||
# `FSDPParamGroup.post_forward` because during AC these won't be called.
|
||||
# TODO(@sanketpurandare): This will need to be modified after this PR (https://github.com/pytorch/pytorch/pull/127786)
|
||||
# lands. For backward we monkey-patch the `FSDPParamGroup.pre_backward` and `FSDPParamGroup.post_backward`.
|
||||
|
||||
# get the unique _MultiHandlers/RemoveHandlers and store in dictionary
|
||||
# the _MultiHandlers object will only need to be grabbed once.
|
||||
unique_handlers: dict[RemovableHandle, bool] = {}
|
||||
# pyrefly: ignore # missing-attribute
|
||||
for module in self._root_mod.modules():
|
||||
if isinstance(module, FSDPModule):
|
||||
fsdp_state = module._get_fsdp_state()
|
||||
if fsdp_param_group := fsdp_state._fsdp_param_group:
|
||||
if not unique_handlers.get(fsdp_state._pre_forward_hook_handle):
|
||||
unique_handlers[fsdp_state._pre_forward_hook_handle] = True
|
||||
if not unique_handlers.get(fsdp_state._post_forward_hook_handle):
|
||||
unique_handlers[fsdp_state._post_forward_hook_handle] = True
|
||||
# call remove on the handles once
|
||||
for f_hook_handle in unique_handlers.keys():
|
||||
f_hook_handle.remove()
|
||||
# pyrefly: ignore # missing-attribute
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
for module in self._root_mod.modules():
|
||||
if isinstance(module, FSDPModule):
|
||||
fsdp_state = module._get_fsdp_state()
|
||||
if fsdp_param_group := fsdp_state._fsdp_param_group:
|
||||
self._instrument_fsdp_sharded_params_grads(fsdp_param_group)
|
||||
fsdp_state._pre_forward_hook_handle.remove()
|
||||
fsdp_state._post_forward_hook_handle.remove()
|
||||
fsdp_state._pre_forward_hook_handle = (
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
module.register_forward_pre_hook(
|
||||
|
||||
@ -194,10 +194,6 @@ else:
|
||||
_rank_map: Optional[torch.Tensor] = None,
|
||||
_root_mesh: Optional["DeviceMesh"] = None,
|
||||
) -> None:
|
||||
# no-op in OSS, logs API usage metrics in meta-internal runs
|
||||
torch._C._log_api_usage_once(
|
||||
"torch.distributed.device_mesh.DeviceMesh.__init__"
|
||||
)
|
||||
if mesh is not None:
|
||||
if _layout is not None or _rank_map is not None:
|
||||
raise TypeError(
|
||||
@ -259,13 +255,14 @@ else:
|
||||
)
|
||||
|
||||
# private field to pre-generate DeviceMesh's hash
|
||||
self._flatten_rank_map = tuple(self._rank_map.tolist())
|
||||
self._flatten_mesh_list = tuple(self.mesh.flatten().tolist())
|
||||
self._thread_id = None
|
||||
# Initialize instance-specific flatten mapping
|
||||
self._flatten_mapping = {}
|
||||
|
||||
# Skip process group initialization if xla device or init backend is False
|
||||
# TODO(yeounoh) implement DeviceMesh backend and register XLA backend.
|
||||
self._thread_id = None
|
||||
if device_type != "xla":
|
||||
# always try to create default (world) pg, even if it is not initialized
|
||||
# already. The world pg is used for device mesh identity (rank) on each
|
||||
@ -296,6 +293,11 @@ else:
|
||||
rank_coords[0].tolist() if rank_coords.size(0) > 0 else None
|
||||
)
|
||||
|
||||
# private field to pre-generate DeviceMesh's hash
|
||||
self._flatten_rank_map = tuple(self._rank_map.tolist())
|
||||
# Initialize instance-specific flatten mapping
|
||||
self._flatten_mapping = {}
|
||||
|
||||
@property
|
||||
def device_type(self) -> str:
|
||||
"""Returns the device type of the mesh."""
|
||||
|
||||
@ -359,10 +359,6 @@ class ShardingPropagator:
|
||||
"""
|
||||
Propagate the sharding for an operator given the op_schema.
|
||||
"""
|
||||
# no-op in OSS, logs API usage metrics in meta-internal runs
|
||||
torch._C._log_api_usage_once(
|
||||
"torch.distributed.tensor._sharding_prop.ShardingPropagator.propogate_op_sharding_non_cached"
|
||||
)
|
||||
# special case op, we don't need to propagate for local
|
||||
# scalar. TODO: figure out a better way to handle this
|
||||
if op_schema.op is aten._local_scalar_dense.default:
|
||||
|
||||
@ -398,9 +398,6 @@ def load(
|
||||
Under active development, saved files may not be usable in newer versions
|
||||
of PyTorch.
|
||||
|
||||
.. warning::
|
||||
:func:`torch.export.load()` uses pickle under the hood to load models. **Never load data from an untrusted source.**
|
||||
|
||||
Loads an :class:`ExportedProgram` previously saved with
|
||||
:func:`torch.export.save <torch.export.save>`.
|
||||
|
||||
|
||||
@ -12,6 +12,7 @@ from collections.abc import Callable
|
||||
from contextlib import contextmanager, ExitStack, nullcontext
|
||||
from itertools import chain
|
||||
from typing import Any, Optional, TYPE_CHECKING, TypeAlias, Union
|
||||
from unittest import mock
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -274,6 +275,24 @@ def _extract_fake_inputs(gm, args, kwargs):
|
||||
else:
|
||||
fake_vals.append(node.meta.get("example_value"))
|
||||
|
||||
if in_shuffle_graph := getattr(gm, "_in_shuffle_graph", None):
|
||||
flat_args = pytree.tree_leaves((args, kwargs))
|
||||
node_map = {
|
||||
node: i
|
||||
for i, node in enumerate(
|
||||
next(iter(reversed(in_shuffle_graph.graph.nodes))).args[0]
|
||||
)
|
||||
if node.op == "placeholder"
|
||||
}
|
||||
new_fake_inps: list[Any] = []
|
||||
for i, node in enumerate(
|
||||
in_shuffle_graph.graph.find_nodes(op="placeholder")[1:]
|
||||
):
|
||||
if node in node_map:
|
||||
new_fake_inps.append(fake_inps[node_map[node]])
|
||||
else:
|
||||
new_fake_inps.append(flat_args[i])
|
||||
fake_inps = new_fake_inps
|
||||
# We get both because now we might have a combination of symint and tensor
|
||||
# inputs, and we want to check that the shape env is consistent between
|
||||
# both. Unfortunately we can't see what fake mode is attached to the shape
|
||||
@ -798,6 +817,16 @@ def _export_to_torch_ir(
|
||||
prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
|
||||
)
|
||||
|
||||
def use_legacy_dynamo_graph_capture() -> bool:
|
||||
return bool(
|
||||
constraints # dynamic shape
|
||||
or dynamic_shapes # dynamic shape
|
||||
or isinstance(f, torch.fx.GraphModule) # retracing
|
||||
or preserve_module_call_signature # unflatten
|
||||
or torch._functorch.config.fake_tensor_propagate_real_tensors # draft
|
||||
or torch._export.config.use_legacy_dynamo_graph_capture
|
||||
)
|
||||
|
||||
with torch._dynamo.config.patch(dataclasses.asdict(dynamo_cfg)):
|
||||
try:
|
||||
module_call_specs: dict[str, dict[str, pytree.TreeSpec]] = (
|
||||
@ -812,11 +841,20 @@ def _export_to_torch_ir(
|
||||
if torch._export.config.use_new_tracer_experimental:
|
||||
from torch._dynamo.functional_export import (
|
||||
_dynamo_graph_capture_for_export,
|
||||
dynamo_graph_capture_for_export,
|
||||
)
|
||||
|
||||
gm_torch_level = _dynamo_graph_capture_for_export(
|
||||
f, constraints=constraints, dynamic_shapes=dynamic_shapes
|
||||
)(*args, **kwargs)
|
||||
if use_legacy_dynamo_graph_capture():
|
||||
dynamo_graph_capture = _dynamo_graph_capture_for_export(
|
||||
f, constraints=constraints, dynamic_shapes=dynamic_shapes
|
||||
)
|
||||
else:
|
||||
dynamo_graph_capture = dynamo_graph_capture_for_export(f)
|
||||
# We can't serialize entire fake mode yet, so this is to make sure
|
||||
# things like copy.deepcopy(ep.graph_module) not crash.
|
||||
# see test_export.py::test_custom_tag_metadata_re_export
|
||||
# Once we delete the old strict export, we can use
|
||||
gm_torch_level = dynamo_graph_capture(*args, **kwargs)
|
||||
# We can't serialize entire fake mode yet, so this is to make sure
|
||||
# things like copy.deepcopy(ep.graph_module) not crash.
|
||||
# see test_export.py::test_custom_tag_metadata_re_export
|
||||
@ -1568,7 +1606,11 @@ def _strict_export(
|
||||
}
|
||||
|
||||
tx = TracingContext(dynamo_fake_mode)
|
||||
with dynamo_fake_mode, tracing(tx):
|
||||
with (
|
||||
dynamo_fake_mode,
|
||||
tracing(tx),
|
||||
mock.patch.object(dynamo_fake_mode, "allow_non_fake_inputs", True),
|
||||
):
|
||||
aten_export_artifact = _to_aten_func(
|
||||
gm_torch_level,
|
||||
# NOTE: graph module expects only positional args
|
||||
|
||||
@ -3,7 +3,7 @@ import dataclasses
|
||||
import inspect
|
||||
import logging
|
||||
import sys
|
||||
from collections import defaultdict, OrderedDict
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from enum import auto, Enum
|
||||
from typing import Any, Optional, TYPE_CHECKING, Union
|
||||
@ -721,18 +721,7 @@ def _combine_args(f, args, kwargs) -> dict[str, Any]:
|
||||
else inspect.signature(f)
|
||||
)
|
||||
kwargs = kwargs if kwargs is not None else {}
|
||||
combined_args = signature.bind(*args, **kwargs).arguments
|
||||
# if `args` is in the key, flatten it into args_0, args_1, ...
|
||||
if "args" in combined_args:
|
||||
flattened_args = {f"args_{i}": v for i, v in enumerate(combined_args["args"])}
|
||||
combined_args = OrderedDict({**combined_args, **flattened_args})
|
||||
del combined_args["args"]
|
||||
# flatten kwargs into combined_args
|
||||
if "kwargs" in combined_args:
|
||||
for k, v in combined_args["kwargs"].items():
|
||||
combined_args[k] = v
|
||||
del combined_args["kwargs"]
|
||||
return combined_args
|
||||
return signature.bind(*args, **kwargs).arguments
|
||||
|
||||
|
||||
class ShapesCollection:
|
||||
|
||||
@ -1709,8 +1709,11 @@ def _convert_guards_to_code(graph_module):
|
||||
py_printer = torch.fx.experimental.symbolic_shapes.ShapeGuardPythonPrinter(
|
||||
shape_env.var_to_sources, lambda s: s.name(), shape_env.var_to_sources
|
||||
)
|
||||
return [
|
||||
ret = [
|
||||
py_printer.doprint(guard.expr)
|
||||
for guard in shape_env.guards
|
||||
if guard.expr.free_symbols.issubset(local_vars)
|
||||
]
|
||||
# TODO Figure out how to resolve guards containing weight sizes.
|
||||
# This is not a big deal as _guards_code is mostly empty today.
|
||||
return [guard for guard in ret if "L['self']" not in guard]
|
||||
|
||||
@ -19,13 +19,8 @@ __all__: list[str] = [
|
||||
"SDPBackend",
|
||||
"sdpa_kernel",
|
||||
"WARN_FOR_UNFUSED_KERNELS",
|
||||
"register_flash_attention_impl",
|
||||
"activate_flash_attention_impl",
|
||||
"list_flash_attention_impls",
|
||||
"current_flash_attention_impl",
|
||||
]
|
||||
|
||||
|
||||
# Note: [SDPA warnings]
|
||||
# TODO: Consider using this for sdpa regardless of subclasses
|
||||
# This only effects users of bias subclasses
|
||||
@ -167,23 +162,3 @@ def _sdpa_kernel_variadic(*backends: SDPBackend):
|
||||
def _get_flash_version() -> str:
|
||||
"""This returns the closest matching tag for the flash attention backend"""
|
||||
return "2.5.7"
|
||||
|
||||
|
||||
from . import _registry
|
||||
|
||||
|
||||
# Re-export registry types and functions for public API
|
||||
_FlashAttentionImpl = _registry._FlashAttentionImpl
|
||||
_RegisterFn = _registry._RegisterFn
|
||||
register_flash_attention_impl = _registry.register_flash_attention_impl
|
||||
activate_flash_attention_impl = _registry.activate_flash_attention_impl
|
||||
list_flash_attention_impls = _registry.list_flash_attention_impls
|
||||
current_flash_attention_impl = _registry.current_flash_attention_impl
|
||||
|
||||
register_flash_attention_impl.__module__ = __name__
|
||||
activate_flash_attention_impl.__module__ = __name__
|
||||
list_flash_attention_impls.__module__ = __name__
|
||||
current_flash_attention_impl.__module__ = __name__
|
||||
|
||||
# Import built-in implementations to trigger self-registration
|
||||
from . import _fa4 # noqa: F401
|
||||
|
||||
@ -1,444 +0,0 @@
|
||||
"""UBER PROTOTYPE!!!"""
|
||||
# mypy: allow-untyped-defs
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
from dataclasses import dataclass
|
||||
from functools import cache
|
||||
from typing import Any, TYPE_CHECKING
|
||||
from typing_extensions import TypeVarTuple, Unpack
|
||||
|
||||
from . import _registry
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import ModuleType
|
||||
|
||||
import torch
|
||||
from torch.library import Library
|
||||
|
||||
|
||||
__all__ = [
|
||||
"register_flash_attention_fa4",
|
||||
]
|
||||
|
||||
|
||||
_FA4_MODULE_PATH: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class _FA4Handle:
|
||||
library: Library | None
|
||||
|
||||
def remove(self) -> None:
|
||||
self.library = None
|
||||
|
||||
|
||||
@cache
|
||||
def _get_device_major(device: torch.device) -> int:
|
||||
major, _ = torch.cuda.get_device_capability(device)
|
||||
return major
|
||||
|
||||
|
||||
def register_flash_attention_fa4(
|
||||
module_path: str = "flash_attn.cute.interface",
|
||||
) -> _FA4Handle:
|
||||
"""
|
||||
Register FA4 flash attention kernels with the PyTorch dispatcher.
|
||||
|
||||
Args:
|
||||
module_path: Python module path to the FA4 implementation.
|
||||
"""
|
||||
global _FA4_MODULE_PATH
|
||||
_ = _fa4_import_module(module_path)
|
||||
_FA4_MODULE_PATH = module_path
|
||||
return _FA4Handle(_fa4_register_kernels())
|
||||
|
||||
|
||||
@cache
|
||||
def _fa4_import_module(module_path: str) -> ModuleType:
|
||||
module = importlib.import_module(module_path)
|
||||
if not hasattr(module, "_flash_attn_fwd") or not hasattr(module, "_flash_attn_bwd"):
|
||||
raise RuntimeError(f"Module '{module_path}' does not expose FA4 kernels")
|
||||
return module
|
||||
|
||||
|
||||
def _fa4_register_kernels() -> Library:
|
||||
lib = Library("aten", "IMPL", "CUDA") # noqa: TOR901
|
||||
lib.impl("_flash_attention_forward", _fa4_flash_attention_forward_impl, "CUDA")
|
||||
lib.impl("_flash_attention_backward", _fa4_flash_attention_backward_impl, "CUDA")
|
||||
lib.impl(
|
||||
"_scaled_dot_product_flash_attention",
|
||||
_fa4_scaled_dot_product_flash_attention_forward_impl,
|
||||
"CUDA",
|
||||
)
|
||||
lib.impl(
|
||||
"_scaled_dot_product_flash_attention_backward",
|
||||
_fa4_scaled_dot_product_flash_attention_backward_impl,
|
||||
"CUDA",
|
||||
)
|
||||
return lib
|
||||
|
||||
|
||||
def _fa4_common_support_error(
|
||||
query: torch.Tensor,
|
||||
tensors: tuple[torch.Tensor, ...],
|
||||
cum_seq_q: torch.Tensor | None,
|
||||
require_fp32: tuple[tuple[str, torch.Tensor], ...] = (),
|
||||
) -> str | None:
|
||||
if not all(t.is_cuda for t in tensors):
|
||||
return "inputs must be CUDA tensors"
|
||||
if len({t.device for t in tensors}) != 1:
|
||||
return "inputs must share device"
|
||||
if query.dtype not in (torch.float16, torch.bfloat16):
|
||||
return "query dtype must be float16 or bfloat16"
|
||||
for name, tensor in require_fp32:
|
||||
if tensor.dtype != torch.float32:
|
||||
return f"{name} dtype must be float32"
|
||||
if cum_seq_q is None and query.dim() != 4:
|
||||
return "dense query must be 4D"
|
||||
if cum_seq_q is not None and query.dim() != 3:
|
||||
return "ragged query must be 3D"
|
||||
if not torch.cuda.is_available():
|
||||
return "CUDA not available"
|
||||
if _get_device_major(query.device) not in (9, 10):
|
||||
return "FA4 requires compute capability 9.0 or 10.0"
|
||||
return None
|
||||
|
||||
|
||||
def _fa4_forward_support_error(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
dropout_p: float,
|
||||
return_debug_mask: bool,
|
||||
alibi_slopes: torch.Tensor | None,
|
||||
seqused_k: torch.Tensor | None,
|
||||
cum_seq_q: torch.Tensor | None,
|
||||
) -> str | None:
|
||||
if dropout_p != 0.0:
|
||||
return "dropout_p must be 0"
|
||||
if return_debug_mask:
|
||||
return "return_debug_mask must be False"
|
||||
if alibi_slopes is not None:
|
||||
return "alibi_slopes not supported"
|
||||
if seqused_k is not None:
|
||||
if seqused_k.dtype != torch.int32:
|
||||
return "seqused_k must be int32"
|
||||
if not seqused_k.is_cuda:
|
||||
return "seqused_k must be CUDA"
|
||||
error = _fa4_common_support_error(
|
||||
query,
|
||||
(query, key, value),
|
||||
cum_seq_q,
|
||||
)
|
||||
if error is not None:
|
||||
if error == "inputs must share device":
|
||||
return "query, key, value must be on same device"
|
||||
return error
|
||||
return None
|
||||
|
||||
|
||||
def _fa4_backward_support_error(
|
||||
grad_out: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
out: torch.Tensor,
|
||||
logsumexp: torch.Tensor,
|
||||
dropout_p: float,
|
||||
cum_seq_q: torch.Tensor | None,
|
||||
window_size_left: int | None,
|
||||
window_size_right: int | None,
|
||||
) -> str | None:
|
||||
if dropout_p != 0.0:
|
||||
return "dropout_p must be 0"
|
||||
if window_size_left is not None or window_size_right is not None:
|
||||
return "windowed attention not supported"
|
||||
error = _fa4_common_support_error(
|
||||
query,
|
||||
(grad_out, query, key, value, out, logsumexp),
|
||||
cum_seq_q,
|
||||
require_fp32=(("logsumexp", logsumexp),),
|
||||
)
|
||||
if error is not None:
|
||||
return error
|
||||
return None
|
||||
|
||||
|
||||
Ts = TypeVarTuple("Ts")
|
||||
|
||||
|
||||
def _transpose_dense(*tensors: Unpack[Ts]) -> tuple[Unpack[Ts]]:
|
||||
return tuple(t.transpose(1, 2) for t in tensors) # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def _fa4_run_forward(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
cu_seq_q: torch.Tensor | None,
|
||||
cu_seq_k: torch.Tensor | None,
|
||||
scale: float | None,
|
||||
is_causal: bool,
|
||||
window_size_left: int | None,
|
||||
window_size_right: int | None,
|
||||
seqused_k: torch.Tensor | None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if _FA4_MODULE_PATH is None:
|
||||
raise RuntimeError("FA4 not registered")
|
||||
module = _fa4_import_module(_FA4_MODULE_PATH)
|
||||
kwargs: dict[str, Any] = {
|
||||
"softmax_scale": scale,
|
||||
"causal": is_causal,
|
||||
"window_size_left": window_size_left,
|
||||
"window_size_right": window_size_right,
|
||||
"return_lse": True,
|
||||
"cu_seqlens_q": cu_seq_q,
|
||||
"cu_seqlens_k": cu_seq_k,
|
||||
"seqused_k": seqused_k.contiguous() if seqused_k is not None else None,
|
||||
}
|
||||
out, lse = module._flash_attn_fwd(query, key, value, **kwargs)
|
||||
return out, lse.contiguous()
|
||||
|
||||
|
||||
def _fa4_run_backward(
|
||||
grad_out: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
out: torch.Tensor,
|
||||
logsumexp: torch.Tensor,
|
||||
cu_seq_q: torch.Tensor | None,
|
||||
cu_seq_k: torch.Tensor | None,
|
||||
scale: float | None,
|
||||
is_causal: bool,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
if _FA4_MODULE_PATH is None:
|
||||
raise RuntimeError("FA4 not registered")
|
||||
module = _fa4_import_module(_FA4_MODULE_PATH)
|
||||
dq, dk, dv = module._flash_attn_bwd(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
out,
|
||||
grad_out,
|
||||
logsumexp.contiguous(),
|
||||
softmax_scale=scale,
|
||||
causal=is_causal,
|
||||
cu_seqlens_q=cu_seq_q,
|
||||
cu_seqlens_k=cu_seq_k,
|
||||
)
|
||||
return dq, dk, dv
|
||||
|
||||
|
||||
def _fa4_flash_attention_forward_impl(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
cum_seq_q: torch.Tensor | None,
|
||||
cum_seq_k: torch.Tensor | None,
|
||||
max_q: int,
|
||||
max_k: int,
|
||||
dropout_p: float,
|
||||
is_causal: bool,
|
||||
return_debug_mask: bool,
|
||||
*,
|
||||
scale: float | None = None,
|
||||
window_size_left: int | None = None,
|
||||
window_size_right: int | None = None,
|
||||
seqused_k: torch.Tensor | None = None,
|
||||
alibi_slopes: torch.Tensor | None = None,
|
||||
):
|
||||
error = _fa4_forward_support_error(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
dropout_p,
|
||||
return_debug_mask,
|
||||
alibi_slopes,
|
||||
seqused_k,
|
||||
cum_seq_q,
|
||||
)
|
||||
if error is not None:
|
||||
raise RuntimeError(f"FA4 flash_attention forward unsupported: {error}")
|
||||
out, lse = _fa4_run_forward(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
cum_seq_q,
|
||||
cum_seq_k,
|
||||
scale,
|
||||
is_causal,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
seqused_k,
|
||||
)
|
||||
rng_state = torch.zeros((2,), dtype=torch.uint64, device=query.device)
|
||||
philox_offset = torch.zeros((), dtype=torch.uint64, device=query.device)
|
||||
debug_mask = torch.empty(0, dtype=query.dtype, device=query.device)
|
||||
return out, lse, rng_state, philox_offset, debug_mask
|
||||
|
||||
|
||||
def _fa4_flash_attention_backward_impl(
    grad_out: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    out: torch.Tensor,
    logsumexp: torch.Tensor,
    cum_seq_q: torch.Tensor | None,
    cum_seq_k: torch.Tensor | None,
    max_q: int,
    max_k: int,
    dropout_p: float,
    is_causal: bool,
    rng_state: torch.Tensor,
    unused: torch.Tensor,
    *,
    scale: float | None = None,
    window_size_left: int | None = None,
    window_size_right: int | None = None,
):
    error = _fa4_backward_support_error(
        grad_out,
        query,
        key,
        value,
        out,
        logsumexp,
        dropout_p,
        cum_seq_q,
        window_size_left,
        window_size_right,
    )
    if error is not None:
        raise RuntimeError(f"FA4 flash_attention backward unsupported: {error}")
    dq, dk, dv = _fa4_run_backward(
        grad_out,
        query,
        key,
        value,
        out,
        logsumexp,
        cum_seq_q,
        cum_seq_k,
        scale,
        is_causal,
    )
    return dq, dk, dv


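# NOTE (added commentary): SDPA-facing wrapper. SDPA tensors are laid out
# (batch, heads, seq, head_dim) while the flash kernels consume
# (batch, seq, heads, head_dim), so `_transpose_dense` swaps dims 1 and 2 on
# the way in and the output is transposed back before returning; max_q/max_k
# are reported from the original SDPA layout (query.size(2) / key.size(2)).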
def _fa4_scaled_dot_product_flash_attention_forward_impl(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    return_debug_mask: bool = False,
    *,
    scale: float | None = None,
):
    error = _fa4_forward_support_error(
        query,
        key,
        value,
        dropout_p,
        return_debug_mask,
        None,
        None,
        None,
    )
    if error is not None:
        raise RuntimeError(f"FA4 SDPA forward unsupported: {error}")
    q, k, v = _transpose_dense(query, key, value)

    max_q_flash = q.size(1)
    max_k_flash = k.size(1)
    out, lse, rng_state, philox_offset, debug_mask = _fa4_flash_attention_forward_impl(
        q,
        k,
        v,
        None,
        None,
        max_q_flash,
        max_k_flash,
        dropout_p,
        is_causal,
        return_debug_mask,
        scale=scale,
    )
    (out,) = _transpose_dense(out)
    max_q = query.size(2)
    max_k = key.size(2)
    return (
        out,
        lse,
        None,
        None,
        max_q,
        max_k,
        rng_state,
        philox_offset,
        debug_mask,
    )


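# NOTE (added commentary): backward of the SDPA wrapper; grad_out/out and the
# returned dq/dk/dv go through the same dim-1/dim-2 transposition.
# philox_seed and philox_offset are threaded through only for signature
# compatibility; the FA4 backward path above does not read them (dropout_p is
# among the arguments checked by `_fa4_backward_support_error`).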
def _fa4_scaled_dot_product_flash_attention_backward_impl(
    grad_out: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    out: torch.Tensor,
    logsumexp: torch.Tensor,
    cum_seq_q: torch.Tensor | None,
    cum_seq_k: torch.Tensor | None,
    max_q: int,
    max_k: int,
    dropout_p: float,
    is_causal: bool,
    philox_seed: torch.Tensor,
    philox_offset: torch.Tensor,
    *,
    scale: float | None = None,
):
    error = _fa4_backward_support_error(
        grad_out,
        query,
        key,
        value,
        out,
        logsumexp,
        dropout_p,
        None,
        None,
        None,
    )
    if error is not None:
        raise RuntimeError(f"FA4 SDPA backward unsupported: {error}")
    q, k, v, o, go = _transpose_dense(query, key, value, out, grad_out)
    max_q = query.size(2)
    max_k = key.size(2)
    dq, dk, dv = _fa4_flash_attention_backward_impl(
        go,
        q,
        k,
        v,
        o,
        logsumexp,
        None,
        None,
        max_q,
        max_k,
        dropout_p,
        is_causal,
        philox_seed,
        philox_offset,
        scale=scale,
    )
    dq, dk, dv = _transpose_dense(dq, dk, dv)
    return dq, dk, dv


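# Registering "FA4" here only records the activation callable in the
# registry; the kernels are installed into the dispatcher later, when a user
# calls `activate_flash_attention_impl("FA4")` (see the registry module
# below).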
_registry.register_flash_attention_impl("FA4", register_fn=register_flash_attention_fa4)
@ -1,108 +0,0 @@
# mypy: allow-untyped-defs
"""Registry for flash attention implementations.

This module contains the registration system for flash attention implementations.
It has no torch dependencies to avoid circular imports during initialization.
"""

from typing import Callable, Literal, Protocol


class FlashAttentionHandle(Protocol):
    def remove(self) -> None: ...


_RegisterFn = Callable[..., FlashAttentionHandle | None]
_FlashAttentionImpl = Literal["FA4"]

_FLASH_ATTENTION_IMPLS: dict[str, _RegisterFn] = {}

_FLASH_ATTENTION_ACTIVE: str | None = None
_FLASH_ATTENTION_HANDLES: dict[str, FlashAttentionHandle] = {}


def register_flash_attention_impl(
    impl: str | _FlashAttentionImpl,
    *,
    register_fn: _RegisterFn,
) -> None:
    """
    Register the callable that activates a flash attention impl.

    .. note::
        This function is intended for SDPA backend providers to register their
        implementations. End users should use :func:`activate_flash_attention_impl`
        to activate a registered implementation.

    Args:
        impl: Implementation identifier (e.g., ``"FA4"``).
        register_fn: Callable that performs the actual dispatcher registration.
            This function will be invoked by :func:`activate_flash_attention_impl`
            and should register custom kernels with the PyTorch dispatcher.
            It may optionally return a handle implementing
            :class:`FlashAttentionHandle` to keep any necessary state alive.

    Example:
        >>> def my_impl_register(module_path: str = "my_flash_impl"):
        ...     # Register custom kernels with torch dispatcher
        ...     pass  # doctest: +SKIP
        >>> register_flash_attention_impl(
        ...     "MyImpl", register_fn=my_impl_register
        ... )  # doctest: +SKIP
    """
    _FLASH_ATTENTION_IMPLS[impl] = register_fn


def activate_flash_attention_impl(
    impl: str | _FlashAttentionImpl,
) -> None:
    """
    Activate into the dispatcher a previously registered flash attention impl.

    .. note::
        Backend providers should NOT automatically activate their implementation
        on import. Users should explicitly opt-in by calling this function or via
        environment variables to ensure multiple provider libraries can coexist.

    Args:
        impl: Implementation identifier to activate. See
            :func:`~torch.nn.attention.list_flash_attention_impls` for available
            implementations.
            If the backend's :func:`register_flash_attention_impl` callable
            returns a :class:`FlashAttentionHandle`, the registry keeps that
            handle alive for the lifetime of the process (until explicit
            uninstall support exists).

    Example:
        >>> activate_flash_attention_impl("FA4")  # doctest: +SKIP
    """
    global _FLASH_ATTENTION_ACTIVE
    register_fn = _FLASH_ATTENTION_IMPLS.get(impl)
    if register_fn is None:
        raise ValueError(
            f"Unknown flash attention impl '{impl}'. "
            f"Available implementations: {list_flash_attention_impls()}"
        )
    # TODO: The only way to actually register a new impl is to unregister the
    # current impl, reinstall the default impl, and then register the new impl.
    if _FLASH_ATTENTION_ACTIVE == impl:
        return

    handle = register_fn()
    if handle is not None:
        _FLASH_ATTENTION_HANDLES[impl] = handle
    _FLASH_ATTENTION_ACTIVE = impl


def list_flash_attention_impls() -> list[str]:
    """Return the names of all available flash attention implementations."""
    return sorted(_FLASH_ATTENTION_IMPLS.keys())


def current_flash_attention_impl() -> str | None:
    """
    Return the currently activated flash attention impl name, if any.

    ``None`` indicates that no custom impl has been activated.
    """
    return _FLASH_ATTENTION_ACTIVE