Compare commits


1 Commit

Author SHA1 Message Date
1d97899283 [export] stop gap strict export v2 enable and testing.
Summary:
Added a new flag, "use_legacy_dynamo_graph_capture", which defaults to True and is set to False only by the updated test_strict_export_v2.py.

In addition to this flag, we also fall back to the legacy tracer when any of the following features is used:
1. dynamic shapes
2. preserve module call signature
3. retracing
4. draft mode

Test Plan:
test_strict_export_v2.py
2025-11-11 10:25:13 -08:00
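
A minimal sketch of the gating described in the commit message, for illustration only. The names ExportOptions and should_use_legacy_tracer are hypothetical and are not PyTorch APIs; only the flag name and the feature list are taken from the commit message above.

# Hypothetical sketch of the stop-gap gating described in this commit message.
# ExportOptions and should_use_legacy_tracer are illustrative names, not
# PyTorch APIs; only the flag name and the feature list come from the commit.
from dataclasses import dataclass

@dataclass
class ExportOptions:
    use_legacy_dynamo_graph_capture: bool = True  # defaults to True per the commit
    has_dynamic_shapes: bool = False
    preserve_module_call_signature: bool = False
    is_retracing: bool = False
    is_draft_mode: bool = False

def should_use_legacy_tracer(opts: ExportOptions) -> bool:
    # The flag alone forces the legacy tracer; even when it is off, any
    # feature the new tracer does not yet support falls back as well.
    return opts.use_legacy_dynamo_graph_capture or any(
        (
            opts.has_dynamic_shapes,
            opts.preserve_module_call_signature,
            opts.is_retracing,
            opts.is_draft_mode,
        )
    )

if __name__ == "__main__":
    assert should_use_legacy_tracer(ExportOptions())  # default flag value
    assert not should_use_legacy_tracer(ExportOptions(use_legacy_dynamo_graph_capture=False))
    assert should_use_legacy_tracer(
        ExportOptions(use_legacy_dynamo_graph_capture=False, is_retracing=True)
    )
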
55 changed files with 625 additions and 1399 deletions

View File

@ -96,6 +96,7 @@ function pip_build_and_install() {
python3 -m pip wheel \
--no-build-isolation \
--no-deps \
--no-use-pep517 \
-w "${wheel_dir}" \
"${build_target}"
fi

View File

@ -63,7 +63,7 @@ self-hosted-runner:
- linux.rocm.gpu.gfx942.1
- linux.rocm.gpu.gfx942.2
- linux.rocm.gpu.gfx942.4
- linux.rocm.gfx942.docker-cache
- rocm-docker
# Org-wide AWS `mac2.metal` runners (2020 Mac mini hardware powered by Apple silicon M1 processors)
- macos-m1-stable
- macos-m1-14

View File

@ -0,0 +1,55 @@
name: docker-cache-mi300
on:
# run every 6 hours
schedule:
- cron: 0 0,6,12,18 * * *
workflow_dispatch:
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name }}
cancel-in-progress: true
permissions:
id-token: write
contents: read
jobs:
docker-cache:
if: github.repository_owner == 'pytorch'
runs-on: rocm-docker
steps:
- name: Checkout PyTorch
uses: pytorch/pytorch/.github/actions/checkout-pytorch@main
with:
no-sudo: true
- name: configure aws credentials
id: aws_creds
uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0
with:
role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
aws-region: us-east-1
role-duration-seconds: 18000
- name: Login to Amazon ECR
id: login-ecr
continue-on-error: false
uses: aws-actions/amazon-ecr-login@062b18b96a7aff071d4dc91bc00c4c1a7945b076 # v2.0.1
- name: Calculate docker image
id: calculate-docker-image
uses: pytorch/test-infra/.github/actions/calculate-docker-image@main
with:
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
push: false
- name: Pull docker image
uses: pytorch/test-infra/.github/actions/pull-docker-image@main
with:
docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }}
- name: Tar and upload to S3 bucket
run: |
sudo docker save -o ~/docker-data/pytorch/pytorch_docker_image.tar ${{ steps.calculate-docker-image.outputs.docker-image }}
sudo rclone copy -P --s3-upload-concurrency 64 --s3-chunk-size 200M --s3-upload-cutoff 300M ~/docker-data/pytorch/pytorch_docker_image.tar oci:pytorchbucket0002/pytorch_docker_image --progress

View File

@ -1,108 +0,0 @@
name: docker-cache-rocm
on:
workflow_run:
workflows: [docker-builds]
# TODO: Uncomment before merging
#branches: [main, release]
types:
- completed
workflow_dispatch:
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name }}
cancel-in-progress: true
permissions:
id-token: write
contents: read
actions: read
jobs:
download-docker-builds-artifacts:
if: github.repository_owner == 'pytorch'
name: download-docker-builds-artifacts
runs-on: ubuntu-latest
outputs:
pytorch-linux-jammy-rocm-n-py3: ${{ steps.process-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3 }}
pytorch-linux-noble-rocm-n-py3: ${{ steps.process-artifacts.outputs.pytorch-linux-noble-rocm-n-py3 }}
pytorch-linux-jammy-rocm-n-py3-benchmarks: ${{ steps.process-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3-benchmarks }}
steps:
- name: Download artifacts
uses: actions/download-artifact@v4.1.7
with:
run-id: ${{ github.event.workflow_run.id }}
path: ./docker-builds-artifacts
merge-multiple: true
github-token: ${{ secrets.GITHUB_TOKEN }}
- name: Process artifacts
id: process-artifacts
run: |
ls -R ./docker-builds-artifacts
cat ./docker-builds-artifacts/*txt >> "${GITHUB_OUTPUT}"
cat "${GITHUB_OUTPUT}"
docker-cache:
if: github.repository_owner == 'pytorch'
needs: download-docker-builds-artifacts
strategy:
fail-fast: false
matrix:
runner: [linux.rocm.gfx942.docker-cache]
docker-image: [
"${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3 }}",
"${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-noble-rocm-n-py3 }}",
"${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3-benchmarks }}"
]
runs-on: "${{ matrix.runner }}"
steps:
- name: debug
run: |
JSON_STRINGIFIED="${{ toJSON(needs.download-docker-builds-artifacts.outputs) }}"
echo "Outputs of download-docker-builds-artifacts job: ${JSON_STRINGIFIED}"
- name: configure aws credentials
id: aws_creds
uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0
with:
role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
aws-region: us-east-1
role-duration-seconds: 18000
- name: Login to Amazon ECR
id: login-ecr
continue-on-error: false
uses: aws-actions/amazon-ecr-login@062b18b96a7aff071d4dc91bc00c4c1a7945b076 # v2.0.1
- name: Generate ghcr.io tag
id: ghcr-io-tag
run: |
ecr_image="${{ matrix.docker-image }}"
ghcr_image="ghcr.io/pytorch/ci-image:${ecr_image##*:}"
echo "ghcr_image=${ghcr_image}" >> "$GITHUB_OUTPUT"
- name: Pull docker image
uses: pytorch/test-infra/.github/actions/pull-docker-image@main
with:
docker-image: ${{ steps.ghcr-io-tag.outputs.ghcr_image }}
- name: Save as tarball
run: |
docker_image_tag=${{ matrix.docker-image }}
docker_image_tag="${docker_image_tag#*:}" # Remove everything before and including first ":"
docker_image_tag="${docker_image_tag%-*}" # Remove everything after and including last "-"
ref_name=${{ github.event.workflow_run.head_branch }}
if [[ $ref_name =~ "release/" ]]; then
ref_suffix="release"
elif [[ $ref_name == "main" ]]; then
ref_suffix="main"
else
# TODO: Remove below
ref_suffix="main"
# echo "Unexpected branch in ref_name: ${ref_name}" && exit 1
fi
docker tag ${{ steps.ghcr-io-tag.outputs.ghcr_image }} ${{ matrix.docker-image }}
# mv is atomic operation, so we use intermediate tar.tmp file to prevent read-write contention
docker save -o ~/pytorch-data/docker/${docker_image_tag}.tar.tmp ${{ matrix.docker-image }}
mv ~/pytorch-data/docker/${docker_image_tag}.tar.tmp ~/pytorch-data/docker/${docker_image_tag}_${ref_suffix}.tar

View File

@ -142,7 +142,6 @@ Tensor _pack_padded_sequence_backward_symint(const Tensor& grad, c10::SymIntArra
std::tuple<Tensor, Tensor> _pad_packed_sequence(const Tensor& data, const Tensor& _batch_sizes, bool batch_first, const Scalar& padding_value, int64_t total_length) {
auto batch_sizes_t = _batch_sizes.contiguous();
checkLongTensor(batch_sizes_t);
TORCH_CHECK(batch_sizes_t.numel() > 0, "batch_sizes can not be empty");
int64_t * batch_sizes = batch_sizes_t.data_ptr<int64_t>();
int64_t max_batch_size = batch_sizes[0];

View File

@ -669,12 +669,9 @@ std::optional<c10::ScalarType> out_dtype) {
// _scaled_mm_allowed_device is used here within _grouped_mm_cuda which seems incorrect since scale is not used.
// the _grouped_mm_fallback should be safe for any ROCm GPU since it's just calling typical mm/bmm
bool use_fast_path = false;
// On non CK system(w/ ROCm), make sure use_fast_path is false
#if defined(USE_ROCM_CK_GEMM)
if (at::detail::getCUDAHooks().isGPUArch({"gfx942", "gfx950"})) {
use_fast_path = true;
}
#endif //USE_ROCM_CK_GEMM
#endif
const auto out_dtype_ = _resolve_grouped_mm_out_dtype(mat_a, mat_b, out_dtype);
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
@ -683,11 +680,7 @@ std::optional<c10::ScalarType> out_dtype) {
#ifndef USE_ROCM
at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out);
#else
#if defined(USE_ROCM_CK_GEMM)
at::hip::detail::group_gemm_ck(mat_a, mat_b, offs, bias, out);
#else
TORCH_WARN("ROCm: Group Gemm through CK not selected.");
#endif //USE_ROCM_CK_GEMM
#endif
} else {
_grouped_mm_fallback(mat_a, mat_b, offs, bias, out_dtype, out);

View File

@ -1426,9 +1426,6 @@ static at::Tensor _fp8_convolution_onednn_ref(
w_scales_new_shape[0] = -1;
auto dqw = weight.to(at::kFloat) * weight_scales.reshape(w_scales_new_shape);
auto output_padding = std::vector<int64_t>(kSpatialDim, 0);
if (bias.has_value()){
bias = bias.value().to(at::kFloat);
}
auto y_f32 = at::convolution(
dqx, dqw, bias, stride.vec(), padding.vec(), dilation.vec(), /* transposed */false, output_padding, groups
);

View File

@ -47,7 +47,6 @@
#include <c10/macros/Macros.h>
#include <thrust/copy.h>
#include <thrust/device_ptr.h>
#include <thrust/distance.h>
#include <thrust/for_each.h>
#include <thrust/functional.h>
#include <thrust/gather.h>

View File

@ -50,7 +50,7 @@ nfnet_l0,pass,7
repvgg_a2,pass,7
repvgg_a2,fail_accuracy,7

(CSV columns: name, accuracy, graph_breaks)

View File

@ -14,10 +14,6 @@ Utils
sdpa_kernel
SDPBackend
register_flash_attention_impl
activate_flash_attention_impl
list_flash_attention_impls
current_flash_attention_impl
Submodules
----------

View File

@ -10,7 +10,7 @@ tp2_dir="$top_dir/third_party"
pip install ninja
# Install onnx
pip install -e "$tp2_dir/onnx"
pip install --no-use-pep517 -e "$tp2_dir/onnx"
# Install caffe2 and pytorch
pip install -r "$top_dir/caffe2/requirements.txt"

View File

@ -140,11 +140,6 @@ static void initDeviceStreamState(DeviceIndex device_index) {
static void initOpenRegStreamsOnce() {
c10::call_once(init_flag, initGlobalStreamState);
for (const auto i : c10::irange(num_devices)) {
c10::call_once(
device_flags[i], initDeviceStreamState, static_cast<DeviceIndex>(i));
}
if (current_streams) {
return;
}
@ -207,6 +202,8 @@ OpenRegStream getStreamFromPool(const int priority, DeviceIndex device_index) {
if (device_index == -1) {
device_index = current_device();
}
c10::call_once(
device_flags[device_index], initDeviceStreamState, device_index);
auto pri_idx =
std::clamp(priority, 0, max_compile_time_stream_priorities - 1);
const auto idx = get_idx(priority_counters[device_index][pri_idx]);

View File

@ -180,47 +180,6 @@ class TestTrackerFullyShard1DTrainingCore(FSDPTest):
del model
del optim
def _test_tracker_multihandler_hook(self):
"""Should run without KeyError."""
class TestModule(nn.Module):
def __init__(self, dim: int):
super().__init__()
self.norm1 = nn.RMSNorm(dim)
self.output1 = nn.Linear(dim, dim)
self.norm2 = nn.RMSNorm(dim)
self.output2 = nn.Linear(dim, dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.norm1(x)
x = self.output1(x)
x = self.norm2(x)
x = self.output2(x)
return x
gc.collect()
torch.manual_seed(42)
dev = torch.device(torch.accelerator.current_device_index())
with torch.device(dev):
model = TestModule(128)
mesh = init_device_mesh(dev.type, (self.world_size,))
fully_shard([model.norm1, model.output1], mesh=mesh)
fully_shard([model.norm2, model.output2], mesh=mesh)
fully_shard(model, mesh=mesh)
fmt = FSDPMemTracker(model)
with fmt:
inp = torch.randn(16, 128, device=dev)
y = model(inp)
loss = y.sum()
loss.backward()
del inp
del model
class TestTrackerFullyShard1DTrainingCompose(FSDPTest):
@property

View File

@ -1,7 +1,6 @@
# Owner(s): ["oncall: distributed"]
import contextlib
import unittest
import torch
import torch.distributed as dist
@ -372,7 +371,6 @@ class DTensorExportTest(TestCase):
# aot_export_joint_with_descriptors on strict-exported exported_program.module()
# is producing a joint graph with backward region missing
@unittest.expectedFailure
def test_strict_export_parallelize_module_with_dtensor_input(self):
self._run_test(strict_export_and_aot_export_joint_with_descriptors)

View File

@ -15,7 +15,7 @@ import torch._functorch.config
import torch.distributed as dist
import torch.nn as nn
import torch.utils.checkpoint
from functorch.compile import min_cut_rematerialization_partition
from functorch.compile import default_partition, min_cut_rematerialization_partition
from torch._dynamo.backends.common import aot_autograd
from torch._dynamo.testing import (
AotEagerAndRecordGraphs,
@ -24,7 +24,7 @@ from torch._dynamo.testing import (
)
from torch._higher_order_ops.wrap import tag_activation_checkpoint
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import IS_WINDOWS, skipIfHpu
from torch.testing._internal.common_utils import IS_WINDOWS, parametrize, skipIfHpu
from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON
from torch.testing._internal.triton_utils import requires_cuda_and_triton
from torch.testing._internal.two_tensor import TwoTensor
@ -281,7 +281,14 @@ class ActivationCheckpointingViaTagsTests(
run(export_compiler)
def test_tags_function(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_tags_function(self, device, partition_fn):
def gn(x, y):
return torch.sigmoid(torch.matmul(x, y))
@ -297,11 +304,22 @@ class ActivationCheckpointingViaTagsTests(
bw_compiler = functools.partial(
count_ops, freq=3, op=torch.ops.aten.mm.default
) # mm recomputed in the bwd
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
)
self._validate(fn, backend, x, y)
@requires_cuda_and_triton
def test_tags_function_via_global_checkpoint(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_tags_function_via_global_checkpoint(self, device, partition_fn):
def gn(x, y):
return torch.sigmoid(torch.matmul(x, y))
@ -316,17 +334,28 @@ class ActivationCheckpointingViaTagsTests(
bw_compiler = functools.partial(
count_ops, freq=3, op=torch.ops.aten.mm.default
) # mm recomputed in the bwd
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
)
self._validate(fn, backend, x, y)
@requires_cuda_and_triton
def test_tags_function_with_kwargs(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_tags_function_with_kwargs(self, device, partition_fn):
def gn(x, y):
return torch.sigmoid(torch.matmul(x, y))
def fn(x, y):
return torch.utils.checkpoint.checkpoint(
gn, torch.sin(x), y, use_reentrant=True, preserve_rng_state=False
gn, torch.sin(x), y, use_reentrant=False
)
x = torch.randn(4, 4, device=device, requires_grad=True)
@ -336,11 +365,22 @@ class ActivationCheckpointingViaTagsTests(
bw_compiler = functools.partial(
count_ops, freq=3, op=torch.ops.aten.mm.default
) # mm recomputed in the bwd
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
)
self._validate(fn, backend, x, y)
@requires_cuda_and_triton
def test_tags_sequential_layers(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_tags_sequential_layers(self, device, partition_fn):
def gn(x):
x = x.cos()
for _ in range(3):
@ -361,11 +401,22 @@ class ActivationCheckpointingViaTagsTests(
freqs=[2, 18],
ops=[torch.ops.aten.cos.default, torch.ops.aten.mm.default],
) # mm recomputed in the bwd
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
)
self._validate(fn, backend, x)
@requires_cuda_and_triton
def test_tags_multiple_checkpoints(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_tags_multiple_checkpoints(self, device, partition_fn):
def gn(x, y):
return torch.sigmoid(torch.matmul(x, y))
@ -383,11 +434,22 @@ class ActivationCheckpointingViaTagsTests(
bw_compiler = functools.partial(
count_ops, freq=6, op=torch.ops.aten.mm.default
) # mm recomputed in the bwd
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
)
self._validate(fn, backend, x, y)
@requires_cuda_and_triton
def test_tags_module(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_tags_module(self, device, partition_fn):
class MockModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
@ -411,11 +473,22 @@ class ActivationCheckpointingViaTagsTests(
bw_compiler = functools.partial(
count_ops, freq=1, op=torch.ops.aten.sigmoid.default
)
backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
)
self._validate(fn, backend, x)
@requires_cuda_and_triton
def test_tags_decomps(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_tags_decomps(self, device, partition_fn):
# Ensures that tags are passed on through decompositions as well
class MockModule(torch.nn.Module):
def __init__(self) -> None:
@ -443,6 +516,7 @@ class ActivationCheckpointingViaTagsTests(
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=partition_fn,
decompositions=lambda: import_module(
"torch._inductor.compile_fx"
).select_decomp_table(),
@ -702,7 +776,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_must_recompute(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_must_recompute(self, device, partition_fn):
def context_fn_must_recompute_mm():
must_recompute_list = [
torch.ops.aten.mm.default,
@ -723,9 +804,9 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
),
)
def _test(context_fn, bw_compiler):
def _test(context_fn, bw_compiler, partition_fn):
def gn(x):
return torch.sigmoid(torch.matmul(x, x))
return torch.cos(torch.sin(torch.matmul(x, x) @ x))
def fn(x):
return torch.utils.checkpoint.checkpoint(
@ -739,14 +820,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
fw_compiler = functools.partial(
count_ops,
freq=1,
freq=2,
op=torch.ops.aten.mm.default,
)
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
partition_fn=partition_fn,
)
self._validate(fn, backend, x)
@ -754,17 +835,19 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
context_fn=context_fn_must_recompute_mm,
bw_compiler=functools.partial(
count_ops,
freq=3, # 1 matmul recompute and 2 bwd mm ops per fwd matmul, so 1 + 2 * 1 = 3)
freq=6, # 1 matmul recompute and 2 bwd mm ops per fwd matmul, so 2 + 2 * 2 = 6)
op=torch.ops.aten.mm.default,
),
partition_fn=partition_fn,
)
_test(
context_fn=context_fn_no_recompute_mm,
bw_compiler=functools.partial(
count_ops,
freq=2, # 2 bwd mm ops per fwd matmul
freq=4, # 2 bwd mm ops per fwd matmul
op=torch.ops.aten.mm.default,
),
partition_fn=partition_fn,
)
def test_sac_with_partial_context_fn(self):
@ -801,7 +884,16 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_must_not_recompute_gemm(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_must_not_recompute_gemm(
self, device, partition_fn
):
def selective_checkpointing_context_fn():
no_recompute_list = [
torch.ops.aten.mm.default,
@ -841,15 +933,22 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
partition_fn=partition_fn,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_must_not_recompute_gemm_no_functionalization(
self, device
self, device, partition_fn
):
def selective_checkpointing_context_fn():
no_recompute_list = [
@ -889,7 +988,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
partition_fn=partition_fn,
disable_functionalization=True,
)
self._validate(fn, backend, x, y)
@ -897,7 +996,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_triton_kernel(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_triton_kernel(self, device, partition_fn):
# Copy of the above test, but make sure that having a triton kernel in the
# region does not error.
def add_one(x):
@ -957,14 +1063,21 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
partition_fn=partition_fn,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_tensor_subclass(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_tensor_subclass(self, device, partition_fn):
def selective_checkpointing_context_fn():
no_recompute_list = [
torch.ops.aten.mm.default,
@ -1007,14 +1120,21 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
partition_fn=partition_fn,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_custom_rule(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_custom_rule(self, device, partition_fn):
def _get_custom_policy(meta):
no_recompute_list = [
torch.ops.aten.mm.default,
@ -1072,14 +1192,21 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
partition_fn=partition_fn,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_partial_ctx_fn(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_partial_ctx_fn(self, device, partition_fn):
def selective_checkpointing_context_fn(no_recompute_list):
return create_selective_checkpoint_contexts(
_get_custom_policy(no_recompute_list=no_recompute_list)
@ -1118,14 +1245,21 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
partition_fn=partition_fn,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_outplace_op(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_outplace_op(self, device, partition_fn):
def selective_checkpointing_context_fn():
no_recompute_list = [
torch.ops.aten.mm.default,
@ -1163,14 +1297,21 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
partition_fn=partition_fn,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_list_ops(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_list_ops(self, device, partition_fn):
def selective_checkpointing_context_fn():
# recompute everything
no_recompute_list = []
@ -1206,7 +1347,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
partition_fn=partition_fn,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@ -1217,7 +1358,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
"requires TorchDispatchMode + torch.compile work to complete"
)
@requires_cuda_and_triton
def test_compile_selective_checkpoint_inplace_op(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_inplace_op(self, device, partition_fn):
def selective_checkpointing_context_fn():
no_recompute_list = [
torch.ops.aten.mm.default,
@ -1257,7 +1405,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
partition_fn=partition_fn,
)
self._validate(fn, backend, x, y)
self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@ -1265,7 +1413,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
@torch._inductor.config.patch(fallback_random=True)
def test_compile_selective_checkpoint_random_op(self, device):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_random_op(self, device, partition_fn):
for preserve_rng_state in [True, False]:
def selective_checkpointing_context_fn():
@ -1312,7 +1467,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
partition_fn=partition_fn,
)
# NOTE: when `preserve_rng_state` is False, gradient will mismatch between torch.compile and eager,
@ -1324,7 +1479,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
@requires_cuda_and_triton
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
def test_compile_selective_checkpoint_invalid_context(self):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_invalid_context(self, partition_fn):
def gn(x, y):
return torch.sigmoid(torch.matmul(x, y)) * y
@ -1353,7 +1515,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
partition_fn=partition_fn,
)
with self.assertRaisesRegex(
Exception, "must generate a tuple of two `TorchDispatchMode`s"
@ -1362,7 +1524,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
@requires_cuda_and_triton
@torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)
def test_compile_selective_checkpoint_parametrization(self):
@parametrize(
"partition_fn",
[
min_cut_rematerialization_partition,
default_partition,
],
)
def test_compile_selective_checkpoint_parametrization(self, partition_fn):
def sac_policy():
def _recomp_policy():
def _custom_policy(ctx, func, *args, **kwargs):
@ -1425,7 +1594,9 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
bw_compiler = functools.partial(
count_ops,
freqs=[
2, # 1 from mul recompute, 1 from mul backward
# 1 from mul recompute, 1 from mul backward
# w/o CSE, we have one extra mul
3 if partition_fn is default_partition else 2,
1,
],
ops=[torch.ops.aten.mul.Tensor, torch.ops.aten.sigmoid.default],
@ -1434,7 +1605,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
backend = aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
partition_fn=min_cut_rematerialization_partition,
partition_fn=partition_fn,
)
model = MLPModule()

View File

@ -2363,34 +2363,6 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
self.assertTrue(same(output, expected))
assert cnt.frame_count == 1
@unittest.skipIf(sys.version_info < (3, 13), "math.fma introduced in python 3.13")
def test_math_fma(self):
def fma_func(a, b, c):
return math.fma(a, b, c)
# Test with scalar constants (constant folding path)
cnt = torch._dynamo.testing.CompileCounter()
cfma_scalars = torch._dynamo.optimize_assert(cnt)(fma_func)
assert cnt.frame_count == 0
expected = fma_func(2.0, 3.0, 4.0)
output = cfma_scalars(2.0, 3.0, 4.0)
self.assertEqual(output, expected)
assert cnt.frame_count == 0
# Test with tensors (Inductor path)
cnt2 = torch._dynamo.testing.CompileCounter()
cfma_tensors = torch._dynamo.optimize_assert(cnt2)(fma_func)
assert cnt2.frame_count == 0
x = torch.tensor(2.0)
y = torch.tensor(3.0)
z = torch.tensor(4.0)
expected_tensors = x * y + z
output_tensors = cfma_tensors(x, y, z)
torch.testing.assert_close(output_tensors, expected_tensors)
assert cnt2.frame_count == 1
@make_test
def test_numpy_meshgrid(x, y):
r1, r2 = np.meshgrid(x.numpy(), y.numpy())

View File

@ -335,59 +335,6 @@ class <lambda>(torch.nn.Module):
""",
)
@requires_cuda
@requires_multigpu()
def test_new_event_api(self) -> None:
from torch._dynamo.graph_bytecode_inputs import get_external_object_by_index
from torch._dynamo.variables.streams import new_event
def event_generation_backend(gm, *args, **kwargs): # type: ignore[no-untyped-def]
e0_ind = new_event()
with torch.Stream(device="cuda:1"):
get_external_object_by_index(e0_ind).record()
e1_ind = new_event()
self.assertNotEqual(e0_ind, e1_ind)
self.assertNotEqual(
get_external_object_by_index(e0_ind),
get_external_object_by_index(e1_ind),
)
with gm.graph.inserting_after(next(iter(gm.graph.nodes))):
gm.graph.call_function(
get_external_object_by_index, args=(1,), kwargs={}
)
return gm
@torch.compile(backend=event_generation_backend)
def fn(x):
return x + 1
fn(torch.ones(2, 2, device="cuda:0"))
@requires_cuda
def test_new_stream_api(self) -> None:
from torch._dynamo.graph_bytecode_inputs import get_external_object_by_index
from torch._dynamo.variables.streams import new_stream
def stream_generation_backend(gm, *args, **kwargs): # type: ignore[no-untyped-def]
s0_ind = new_stream()
s1_ind = new_stream()
self.assertNotEqual(s0_ind, s1_ind)
self.assertNotEqual(
get_external_object_by_index(s0_ind),
get_external_object_by_index(s1_ind),
)
with gm.graph.inserting_after(next(iter(gm.graph.nodes))):
gm.graph.call_function(
get_external_object_by_index, args=(1,), kwargs={}
)
return gm
@torch.compile(backend=stream_generation_backend)
def fn(x):
return x + 1
fn(torch.ones(2, 2, device="cuda:0"))
@requires_cuda
def test_stream_with_mutation(self):
def fn(x, y):
@ -576,23 +523,6 @@ class <lambda>(torch.nn.Module):
torch.accelerator.set_stream(original_stream)
reset_user_object_tracking()
@requires_cuda
def test_run_opcheck_wait_record_stream(self):
from torch._dynamo.variables.streams import wait_stream
from torch.library import opcheck
s0 = torch.Stream()
s1 = torch.Stream()
s2 = torch.Stream()
store_user_object_weakrefs(s0, s1, s2)
sample_inputs = [
(0, 1),
(2, 0),
]
for args in sample_inputs:
opcheck(wait_stream, args)
@requires_cuda
def test_inductor_lowering(self):
with patch("torch._inductor.config.implicit_fallbacks", False):

View File

@ -331,12 +331,7 @@ class TestDynamismExpression(TestCase):
return torch.ops.aten.slice.Tensor(*args)
inp = (torch.rand((10, 3, 224, 224)), 0, 0, 9223372036854775807)
dynamic_shapes = (
{0: Dim("dim")},
None,
None,
None,
)
dynamic_shapes = (({0: Dim("dim")}, None, None, None),)
torch.export.export(
Slice(),
inp,
@ -590,6 +585,7 @@ class TestExport(TestCase):
inp = ([torch.ones(1, 3)], torch.ones(1, 3))
self._test_export_same_as_eager(f, inp)
@testing.expectedFailureStrictV2
@skipIfCrossRef
def test_custom_tag_metadata_re_export(self):
class Foo(torch.nn.Module):
@ -1026,6 +1022,7 @@ graph():
dynamic_shapes = {"x": (dim0_x, dim1_x)}
export(Foo(), inputs, dynamic_shapes=dynamic_shapes)
@testing.expectedFailureStrictV2
def test_no_tensor_computation(self):
class Module(torch.nn.Module):
def forward(self, x, y):
@ -1361,6 +1358,7 @@ def forward(self, primals, tangents):
# instead of the scripted function, so we get x.sin()
self.assertEqual(res, x.sin())
@testing.expectedFailureStrictV2
def test_no_tensor_computation_2(self):
class Module(torch.nn.Module):
def forward(self, x, y):
@ -1379,6 +1377,7 @@ graph():
return (x,)""",
)
@testing.expectedFailureStrictV2
def test_no_tensor_computation_3(self):
class Module(torch.nn.Module):
def forward(self, x, y):
@ -1397,6 +1396,7 @@ graph():
return (5,)""",
)
@testing.expectedFailureStrictV2
def test_no_tensor_computation_4(self):
class Module(torch.nn.Module):
def forward(self, x, y):
@ -1939,6 +1939,7 @@ graph():
for vr_upper in vr_upper_bounds:
self.assertEqual(vr_upper, 1)
@testing.expectedFailureStrictV2
def test_detect_leak_strict(self):
class Foo(torch.nn.Module):
def __init__(self):
@ -2687,6 +2688,7 @@ class GraphModule(torch.nn.Module):
gm = export(m, (torch.rand(64, 64),))
torch.export.unflatten(gm)
@testing.expectedFailureStrictV2
def test_unflatten_closure(self):
class Dummy(torch.nn.Module):
def forward(self, fn, x):
@ -4192,6 +4194,7 @@ def forward(self, p_linear_weight, p_linear_bias, x):
if str(sym) in ["u0", "s0"]:
self.assertEqual(vr.lower, 1)
@testing.expectedFailureStrictV2
def test_duplicate_modules_with_non_persistent_buffers(self):
class FooWithBuf(torch.nn.Module):
def __init__(self):
@ -4835,6 +4838,7 @@ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, b_
table.materialize()
self.assertFalse(torch.ops.mylib.foo123.default in table)
@testing.expectedFailureStrictV2
def test_if_post_autograd_op_preserved(self):
class Foo(torch.nn.Module):
def forward(self, x):
@ -5538,11 +5542,21 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
w = Wrapped()
compiled = export(w, args, dynamic_shapes=({0: batch}, {0: batch}))
expected = w(*args)
mod = compiled.module()
got = mod(*args)
self.assertTrue(torch.allclose(expected, got))
if is_retracebility_test(self._testMethodName):
with self.assertRaisesRegex(
torch._dynamo.exc.UserError,
"Detected mismatch between the structure of `inputs` and `dynamic_shapes`"
": `inputs` has 2 elements, but `dynamic_shapes` has 1 elements",
):
export(w, args, dynamic_shapes={"args": ({0: batch}, {0: batch})})
else:
compiled = export(
w, args, dynamic_shapes={"args": ({0: batch}, {0: batch})}
)
expected = w(*args)
mod = compiled.module()
got = mod(*args)
self.assertTrue(torch.allclose(expected, got))
def test_dynamic_shapes_builder_basic(self):
class M(torch.nn.Module):
@ -7223,6 +7237,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
@testing.expectedFailureSerDer # we don't save placeholder metadata
@testing.expectedFailureCppSerDes # we don't save placeholder metadata
@testing.expectedFailureSerDerNonStrict
@testing.expectedFailureStrictV2
def test_linear_conv(self):
strict = True
@ -8821,6 +8836,7 @@ def forward(self, x):
)
)
@testing.expectedFailureStrictV2
def test_automatic_constrain_size(self):
class M(torch.nn.Module):
def forward(self, x, y):
@ -8932,6 +8948,7 @@ def forward(self, x):
):
ep.graph_module.while_loop_body_graph_0(torch.tensor([5]), torch.zeros(1))
@testing.expectedFailureStrictV2
def test_constrain_decomp(self) -> None:
class M(torch.nn.Module):
def __init__(self) -> None:
@ -9570,6 +9587,7 @@ def forward(self, b_a_buffer, x):
self.assertTrue(torch.allclose(ep.module()(xs), module_out))
@requires_cuda_and_triton
@testing.expectedFailureStrictV2
def test_export_associative_scan_lifted_buffers(self):
if "cpp_runtime_nonstrict" in self.id():
self.skipTest("TODO Unexpected success in OSS but not in fbcode.")
@ -9660,6 +9678,7 @@ def forward(self, b_a_buffer, x):
len([node for node in gm.graph.nodes if node.op == "placeholder"]), 2
)
@testing.expectedFailureStrictV2
def test_no_check_is_size_error(self):
class Module(torch.nn.Module):
def forward(self, x):
@ -9813,6 +9832,7 @@ def forward(self, b_a_buffer, x):
self.assertEqual(len(ep.graph_signature.input_specs), 4)
self.assertTrue(torch.allclose(ep.module()(*inp), transform.module()(*inp)))
@testing.expectedFailureStrictV2
def test_tensor_attribute_zero_args(self):
class Foo(torch.nn.Module):
def __init__(self, value):
@ -9826,6 +9846,7 @@ def forward(self, b_a_buffer, x):
ep = export(m, ())
self.assertEqual(ep.graph_signature.lifted_tensor_constants, ["x"])
@testing.expectedFailureStrictV2
def test_preserve_shape_dynamism_for_unused_inputs(self):
torch.export.register_dataclass(
Inp3,
@ -9995,6 +10016,7 @@ def forward(self, p_lin_weight, p_lin_bias, x):
)
@unittest.skipIf(IS_FBCODE, "We can't customize decomp in fbcode")
@testing.expectedFailureStrictV2
def test_export_decomp_torture_case_2(self):
class MyLinear(torch.nn.Module):
def __init__(self) -> None:
@ -10130,6 +10152,7 @@ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, c_
# expected 4, but got 7
ep_v2.module()(*test_inp)
@testing.expectedFailureStrictV2
def test_constant_output(self):
class ModuleConstant(torch.nn.Module):
def __init__(self) -> None:
@ -10214,6 +10237,7 @@ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, c_
# expected >= 3, but got 2
ep.module()(*test_inp)
@testing.expectedFailureStrictV2
def test_nested_module(self):
class M1(torch.nn.Module):
def forward(self, x):
@ -10251,6 +10275,7 @@ graph():
unflattened = unflatten(ep)
self.assertTrue(torch.allclose(unflattened(*inps), M2()(*inps)))
@testing.expectedFailureStrictV2
def test_nested_module_with_init_buffer(self):
class M1(torch.nn.Module):
def __init__(self) -> None:
@ -10378,6 +10403,7 @@ graph():
ep = export(m, sample_inputs)
self.assertEqual(ep.module()(*sample_inputs), m(*sample_inputs))
@testing.expectedFailureStrictV2
def test_lazy_module_kwargs(self):
class LazyModule(torch.nn.modules.lazy.LazyModuleMixin, torch.nn.Module):
def initialize_parameters(self, *args, **kwargs):
@ -12251,6 +12277,7 @@ graph():
ep.module()(x)
@testing.expectedFailureCppRuntime
@testing.expectedFailureStrictV2
def test_symint_input_basic(self):
class M(torch.nn.Module):
def forward(self, x, y):
@ -12970,6 +12997,7 @@ def forward(self, c_submod_params, x):
ufm = torch.export.unflatten(ep)
self.assertTrue(torch.allclose(ufm(*inp), epm(*inp)))
@testing.expectedFailureStrictV2
def test_unflatten_multiple_graphs_shared_submodule(self):
class N(torch.nn.Module):
def forward(self, x, b):
@ -14021,6 +14049,7 @@ def forward(self, x):
return (foo_functional,)""",
)
@testing.expectedFailureStrictV2
def test_placeholder_naming_order(self):
# See https://github.com/pytorch/pytorch/issues/143732
@ -14072,6 +14101,7 @@ def forward(self, x):
).run_decompositions()
ep.module()(torch.ones(4, 4), **kwargs)
@testing.expectedFailureStrictV2
def test_placeholder_naming_order_variadic(self):
class Mod(torch.nn.Module):
def forward(self, a, b, c, **kwargs):
@ -14096,6 +14126,7 @@ def forward(self, x):
):
export(Foo(), (torch.randn(4, 4),), strict=False)
@testing.expectedFailureStrictV2
def test_placeholder_naming_collisions(self):
# test collisions between nested user inputs
class Foo(torch.nn.Module):
@ -14168,6 +14199,7 @@ def forward(self, x):
self.assertEqual(expected_names_and_ops, real_names_and_ops)
@skipIfCrossRef # Dynamo changes the order of ops under Torch function modes
@testing.expectedFailureStrictV2
def test_placeholder_naming_collisions_hoo_subgraphs(self):
# test collisions between user inputs, top-level nodes, and HOO subgraph nodes
class Foo(torch.nn.Module):
@ -14245,6 +14277,7 @@ def forward(self, x):
]
self.assertEqual(expected_getattr_names, real_getattr_names)
@testing.expectedFailureStrictV2
def test_constant_input_naming(self):
class Foo(torch.nn.Module):
def forward(self, x, y, div="floor"):
@ -14936,6 +14969,7 @@ graph():
]
self.assertEqual(len(repeat_nodes), 0)
@testing.expectedFailureStrictV2
def test_checks_to_constrain_range(self):
class Foo(torch.nn.Module):
def forward(self, x, y):
@ -15270,6 +15304,7 @@ graph():
Block(torch.randn(4, 4), torch.randn(4, 4))
)
@testing.expectedFailureStrictV2
def test_enum_str(self):
class TensorDim(str, enum.Enum):
DDP = "ddp"
@ -15431,6 +15466,7 @@ def forward(self, x):
return (getitem_3, cos_1)""",
)
@testing.expectedFailureStrictV2
def test_run_decompositions_keep_metadata(self):
"""Make sure the metadata is kept after exported program run_decompositions."""
@ -15460,6 +15496,7 @@ def forward(self, x):
for node in decomposed_program.graph.nodes:
self.assertEqual(node.meta["custom"]["my_field"], "dummy")
@testing.expectedFailureStrictV2
def test_run_decompositions_keep_tensor_constant_metadata(self):
"""Make sure the metadata of tensor constants are kept after run_decompositions."""
@ -16091,6 +16128,7 @@ def forward(self, x):
@testing.expectedFailureSerDer # T195866111
@testing.expectedFailureSerDerNonStrict
@testing.expectedFailureStrictV2
def test_hints_wrapper(self):
strict = True
@ -16665,6 +16703,7 @@ def forward(self, args_0):
return (abs_1,)""",
)
@testing.expectedFailureStrictV2
def test_sdpa_gqa(self):
from torch.nn.attention import sdpa_kernel, SDPBackend
@ -17499,105 +17538,6 @@ def forward(self, x):
exported_param_names = [name for name, _ in gm.named_parameters()]
self.assertEqual(original_param_names, exported_param_names)
def test_export_compiled_model_with_nested_dynamic_shapes(self):
class M(torch.nn.Module):
def forward(self, data_batch):
return data_batch["a1"] + data_batch["a2"]
m = M()
compiled_m = torch.compile(m)
example_args = (
{
"a1": torch.ones(3, 3),
"a2": torch.ones(3, 3),
},
)
dynamic_shapes = (
{
"a1": {0: Dim.DYNAMIC},
"a2": {0: Dim.DYNAMIC},
},
)
ep = export(
compiled_m, example_args, dynamic_shapes=dynamic_shapes, strict=True
)
gm = ep.module()
self.assertEqual(gm(*example_args), compiled_m(*example_args))
def test_export_model_with_nested_dynamic_shapes(self):
class M(torch.nn.Module):
def forward(self, data_batch):
return data_batch["a1"] + data_batch["a2"]
m = M()
example_args = (
{
"a1": torch.ones(3, 3),
"a2": torch.ones(3, 3),
},
)
B = torch.export.Dim("batch", min=1, max=65536)
dynamic_shapes = (
{
"a1": {0: B},
"a2": {0: B},
},
)
ep = export(m, example_args, dynamic_shapes=dynamic_shapes, strict=True)
gm = ep.module()
self.assertEqual(gm(*example_args), m(*example_args))
def test_export_compiled_model_with_kwargs_dynamic_shapes(self):
class M(torch.nn.Module):
def forward(self, a1, a2):
return a1 + a2
m = M()
compiled_m = torch.compile(m)
example_args = ()
example_kwargs = {
"a1": torch.ones(3, 3),
"a2": torch.ones(3, 3),
}
dynamic_shapes = {
"a1": {0: Dim.DYNAMIC},
"a2": {0: Dim.DYNAMIC},
}
ep = export(
compiled_m,
example_args,
kwargs=example_kwargs,
dynamic_shapes=dynamic_shapes,
strict=True,
)
gm = ep.module()
self.assertEqual(gm(**example_kwargs), compiled_m(**example_kwargs))
def test_export_model_with_kwargs_dynamic_shapes(self):
class M(torch.nn.Module):
def forward(self, a1, a2):
return a1 + a2
m = M()
example_args = ()
example_kwargs = {
"a1": torch.ones(3, 3),
"a2": torch.ones(3, 3),
}
dynamic_shapes = {
"a1": {0: Dim.DYNAMIC},
"a2": {0: Dim.DYNAMIC},
}
ep = export(
m,
example_args,
kwargs=example_kwargs,
dynamic_shapes=dynamic_shapes,
strict=True,
)
gm = ep.module()
self.assertEqual(gm(**example_kwargs), m(**example_kwargs))
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
class TestExportCustomClass(TorchTestCase):

View File

@ -15,7 +15,7 @@ test_classes = {}
def mocked_strict_export_v2(*args, **kwargs):
# If user already specified strict, don't make it strict
with config.patch(use_new_tracer_experimental=True):
with config.patch(use_legacy_dynamo_graph_capture=False):
if "strict" in kwargs:
return export(*args, **kwargs)
return export(*args, **kwargs, strict=True)

View File

@ -2640,7 +2640,7 @@ def forward(self, primals_1, primals_2):
return grad_output * x, grad_output * x
def f(a, b):
return FwBwMutation.apply(a, b)
return FwBwMutation.apply(a, b).sin_().clone()
inps = [
torch.ones(3, 3, requires_grad=True),
@ -2689,17 +2689,22 @@ def forward(self, primals_1, primals_2):
add = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None
_foreach_mul__1 = torch.ops.aten._foreach_mul_.ScalarList([add], [3]); _foreach_mul__1 = None
mul = torch.ops.aten.mul.Tensor(add, primals_1); primals_1 = None
return (mul, add)""",
clone = torch.ops.aten.clone.default(mul)
sin_ = torch.ops.aten.sin_.default(mul); mul = None
clone_1 = torch.ops.aten.clone.default(sin_); sin_ = None
return (clone_1, add, clone)""",
)
# important bit: there is 1 mutation in the bw
self.assertExpectedInline(
bw_graph[0].code.strip(),
"""\
def forward(self, add, tangents_1):
def forward(self, add, clone, tangents_1):
cos = torch.ops.aten.cos.default(clone); clone = None
mul_1 = torch.ops.aten.mul.Tensor(tangents_1, cos); tangents_1 = cos = None
_foreach_mul__2 = torch.ops.aten._foreach_mul_.ScalarList([add], [4]); _foreach_mul__2 = None
mul_1 = torch.ops.aten.mul.Tensor(tangents_1, add); tangents_1 = add = None
return (mul_1, None)""",
mul_2 = torch.ops.aten.mul.Tensor(mul_1, add); mul_1 = add = None
return (mul_2, None)""",
)
def test_fw_bw_mutation_no_functionalization2(self):

View File

@ -927,8 +927,8 @@ class GraphModule(torch.nn.Module):
op="call_function", target=torch.ops.aten.mm.default
)
self.assertEqual(len(mm_nodes), 4)
self.assertNotIn("partitioner_tag", mm_nodes[0].meta)
self.assertNotIn("partitioner_tag", mm_nodes[1].meta)
self.assertEqual(mm_nodes[0].meta["partitioner_tag"], "is_forward")
self.assertEqual(mm_nodes[1].meta["partitioner_tag"], "is_forward")
self.assertEqual(mm_nodes[2].meta["partitioner_tag"], "is_backward")
self.assertEqual(mm_nodes[3].meta["partitioner_tag"], "is_backward")
self.assertEqual(mm_nodes[0].meta["custom"]["inside_local_map"], 0)

View File

@ -4101,53 +4101,6 @@ if HAS_CUDA_AND_TRITON:
compiled_out = compiled_foo(x)
self.assertEqual(eager_out, compiled_out)
# Use autotune_at_compile_time=True to test standalone_compile
@parametrize("autotune_at_compile_time", [True, False])
@config.patch("graph_partition", True)
def test_graph_partition_kernel_reuse(self, autotune_at_compile_time):
def foo(x):
# partition 1
x1 = x @ x
y1 = x1 + 1
z_cpu = y1.cpu() + 1
# partition 2
# partition 2 should reuse the fused triton kernel generated
# in partition 1
x2 = z_cpu.to("cuda") @ z_cpu.to("cuda")
y2 = x2 + 1
return y1, y2
with config.patch(
"triton.autotune_at_compile_time", autotune_at_compile_time
):
compiled_foo = torch.compile(foo)
x = torch.randn((20, 20), device="cuda")
eager_out = foo(x)
compiled_out, code = run_and_get_code(compiled_foo, x)
self.assertEqual(eager_out, compiled_out)
if autotune_at_compile_time:
# auto-tuning block should only appear once. We generate auto-tuning code
# for all the kernels no matter if they are defined in the main graph or
# subgraph, to avoid the overhead of executing multiple auto-tuning code blocks.
FileCheck().check_count(
"Compile-time auto-tuning block", 1, exactly=True
).run(code[0])
# triton_poi_fused_add_ should appear twice, first in the auto-tuning block,
# and then in the main code block
FileCheck().check_count(
"def triton_poi_fused_add_", 2, exactly=True
).run(code[0])
# cpu kernel definition should only appear once, not in the auto-tuning block
FileCheck().check_count(
"cpp_fused__to_copy_add_1 = ", 1, exactly=True
).run(code[0])
else:
# triton_poi_fused_add_ should appear once, because of kernel reuse
FileCheck().check_count(
"def triton_poi_fused_add_", 1, exactly=True
).run(code[0])
def test_meta_tensor(self):
def foobar(x, y):
return x * 2, y * 3

View File

@ -4,9 +4,8 @@ from functools import partial
from unittest import skipIf
import torch
from torch._inductor import config
from torch._inductor.ir import Pointwise
from torch._inductor.lowering import make_fallback, make_pointwise, register_lowering
from torch._inductor.lowering import make_pointwise, register_lowering
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._inductor.virtualized import ops
from torch.testing._internal.common_utils import skipIfRocm, skipIfXpu
@ -238,17 +237,6 @@ class TestCustomLowering(InductorTestCase):
out2 = fn_opt(a, b)
self.assertEqual(out1, out2)
@config.patch(joint_graph_constant_folding=False)
def test_constant_creation(self):
class M(torch.nn.Module):
def forward(self, x):
return x + torch.tensor(1)
make_fallback(torch.ops.aten.lift_fresh_copy.default)
self.assertTrue(
torch.allclose(torch.compile(M())(torch.ones(3)), torch.ones(3) + 1)
)
if __name__ == "__main__":
from torch._inductor.test_case import run_tests

View File

@ -492,36 +492,6 @@ class PackedSequenceTest(TestCase):
torch.randn([0, 1, 10]), torch.randn([11, 14, 14, 2]), True
)
def test_empty_packed_sequence(self):
"""
Regression test for https://github.com/pytorch/pytorch/issues/149622
Tests that pad_packed_sequence and unpack_sequence handle empty tensors
without segmentation fault (CVE-2025-2998, CVE-2025-2999)
"""
# Test case 1: pad_packed_sequence with empty tensors
# Previously caused segmentation fault
empty_data = torch.randn(0, 5)
empty_batch_sizes = torch.tensor([], dtype=torch.int64)
empty_packed = rnn_utils.PackedSequence(
empty_data, empty_batch_sizes, None, None
)
# Should not crash - either return empty result or raise informative error
with self.assertRaises(RuntimeError):
rnn_utils.pad_packed_sequence(empty_packed, batch_first=True)
# Test case 2: unpack_sequence with empty tensors
# Previously caused segmentation fault
empty_data = torch.tensor([])
empty_batch_sizes = torch.tensor([], dtype=torch.int64)
packed = rnn_utils.PackedSequence(
data=empty_data, batch_sizes=empty_batch_sizes
)
# Should not crash - either return empty list or raise informative error
with self.assertRaises(RuntimeError):
rnn_utils.unpack_sequence(packed)
if __name__ == "__main__":
run_tests()

View File

@ -1001,10 +1001,24 @@ def get_traced_fn(mod: Any) -> tuple[FunctionType, Optional[object]]:
import inspect
if isinstance(mod, torch.nn.Module):
if len(mod._forward_pre_hooks) == 0 and len(mod._forward_hooks) == 0:
# Mirrored from NNModuleVariable.call_function:
# https://github.com/pytorch/pytorch/blob/main/torch/_dynamo/variables/nn_module.py#L1035
if (
len(mod._forward_pre_hooks) == 0
and len(mod._forward_hooks) == 0
and len(torch.nn.modules.module._global_forward_pre_hooks) == 0
and len(torch.nn.modules.module._global_forward_hooks) == 0
and len(mod._backward_pre_hooks) == 0
and len(mod._backward_hooks) == 0
and len(torch.nn.modules.module._global_backward_pre_hooks) == 0
and len(torch.nn.modules.module._global_backward_hooks) == 0
):
mod = mod.forward
elif isinstance(mod, torch.fx.GraphModule):
mod = mod._call_impl
else:
mod = mod.__call__
if hasattr(mod, "__self__"):
# pyrefly: ignore [missing-attribute]
return mod.__func__, mod.__self__

View File

@ -637,7 +637,7 @@ def dynamo_graph_capture_for_export(
pyt.in_shuffle_graph,
pyt.out_shuffle_graph,
tree_leaf_names,
pyt.root,
graph_module if isinstance(pyt.root, torch.nn.Module) else pyt.root,
) # type: ignore[attr-defined]
normalize_graph_module(graph_module)
if pyt.root is not None:
@ -648,6 +648,10 @@ def dynamo_graph_capture_for_export(
graph_module._non_persistent_buffers_set = (
pyt.root._non_persistent_buffers_set.copy()
)
annotations = torch.nn.Module.__dict__.get("__annotations__", None)
for name, value in pyt.root.__dict__.items():
if annotations and name not in annotations:
graph_module.__dict__[name] = value
graph_module._in_spec = pyt.in_spec
graph_module._out_spec = pyt.out_spec
assert not hasattr(graph_module, "_in_shuffle_graph")

View File

@ -2320,8 +2320,6 @@ if sys.version_info >= (3, 11):
torch_c_binding_in_graph_functions["math.exp2"] = TorchInGraphFunctionVariable
torch_c_binding_in_graph_functions["math.cbrt"] = TorchInGraphFunctionVariable
if sys.version_info >= (3, 13):
torch_c_binding_in_graph_functions["math.fma"] = TorchInGraphFunctionVariable
# In graph functions (including constant folding) that are not C bindings
# NOTE: [Cacheability of in-graph torch functions]

View File

@ -10,10 +10,7 @@ from torch.fx import has_side_effect, Proxy
from .. import graph_break_hints
from ..bytecode_transformation import create_call_function
from ..exc import TYPE_CHECKING, unimplemented
from ..graph_bytecode_inputs import (
get_external_object_by_index,
register_graph_created_object,
)
from ..graph_bytecode_inputs import get_external_object_by_index
from .base import VariableTracker
from .constant import ConstantVariable
from .ctx_manager import FxTracebackAnnotateVariable
@ -31,26 +28,6 @@ from torch._library.custom_ops import custom_op
Tensor = torch.Tensor
def new_event(*args: Any, **kwargs: Any) -> int:
event = torch.Event(*args, **kwargs)
return register_graph_created_object(
event,
EventVariable.make_construct_in_graph_event_fn(
TupleVariable([]), ConstDictVariable({})
),
)
def new_stream(*args: tuple[Any], **kwargs: Any) -> int:
stream = torch.Stream(*args, **kwargs) # type: ignore[no-matching-overload,call-overload]
return register_graph_created_object(
stream,
StreamVariable.make_construct_in_graph_stream_fn(
TupleVariable([]), ConstDictVariable({})
),
)
def _get_stream_by_index(index: int) -> torch.Stream:
stream = get_external_object_by_index(index)
assert isinstance(stream, torch.Stream), (
@ -138,24 +115,6 @@ def _(
has_side_effect(torch.ops.streams.wait_event.default)
@custom_op("streams::wait_stream", mutates_args=())
def wait_stream(waiting_stream_index: int, waited_on_stream_index: int) -> None:
waiting = _get_stream_by_index(waiting_stream_index)
waited_on = _get_stream_by_index(waited_on_stream_index)
waiting.wait_stream(waited_on)
@wait_stream.register_fake
def _(
event_index: int,
stream_index: int,
) -> None:
pass
has_side_effect(torch.ops.streams.wait_stream.default)
class SymbolicStreamState:
"""Track the currently entered stream if any"""

View File

@ -603,21 +603,6 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
VariableTracker.build(tx, polyfills.radians), args, kwargs
)
if hasattr(math, "fma"): # Python 3.13+
@register(math.fma)
def handle_fma(self, tx: "InstructionTranslator", *args, **kwargs):
if len(args) != 3 or kwargs:
return None
if all(isinstance(arg, variables.TensorVariable) for arg in args):
x, y, z = args
addcmul_fn = TorchInGraphFunctionVariable(torch.addcmul)
return addcmul_fn.call_function(tx, [z, x, y], {})
# Use math.fma if constants
return None
@register(torch.is_inference_mode_enabled)
def handle_is_inference_mode_enabled(self, tx: "InstructionTranslator"):
unimplemented(

View File

@ -33,6 +33,9 @@ error_on_lifted_constant_tensors = True
# being ready to handle auto_functionalized_v2.
enable_auto_functionalized_v2_for_export = not is_fbcode()
use_legacy_dynamo_graph_capture = True
if TYPE_CHECKING:
from torch.utils._config_typing import * # noqa: F401, F403
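
For context, the updated test harness shown earlier in this diff opts into the new capture path by patching this flag off. A minimal usage sketch under that assumption follows; the import path of this config module is not shown in the diff and is assumed here.

# Sketch only: mirrors the config.patch pattern used by
# mocked_strict_export_v2 in test_strict_export_v2.py earlier in this diff.
# The import path of this config module is an assumption.
import torch
from torch._export import config  # assumed location of the flag

class M(torch.nn.Module):
    def forward(self, x):
        return x + 1

# Temporarily disable the stop-gap flag so strict export exercises the
# new (non-legacy) Dynamo graph capture path.
with config.patch(use_legacy_dynamo_graph_capture=False):
    ep = torch.export.export(M(), (torch.ones(2, 2),), strict=True)
print(ep)
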

View File

@ -27,6 +27,7 @@ from torch._guards import detect_fake_mode
from torch._prims_common import CUDARngStateHelper
from torch.fx.experimental.proxy_tensor import (
_proxy_tensor_disable_update_tensor_tracker,
get_proxy_mode,
maybe_disable_thunkify,
maybe_enable_thunkify,
)
@ -295,6 +296,10 @@ def create_joint(
(outs, tangent_mask), (outs_descs, _) = call_and_expect_output_descs(
fn, primals
)
mode = get_proxy_mode()
assert mode is not None
for node in mode.tracer.graph.nodes:
node.meta["partitioner_tag"] = "is_forward"
# TODO: I think this hook can also be eliminated now
if joint_fn_handle and joint_fn_handle.post_forward:
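The tagging in this hunk stamps every node that exists once the forward has been traced with partitioner_tag = "is_forward", so the partitioner can later tell forward-origin nodes apart from nodes created while tracing the backward. The mechanism, shown on a plain fx graph:

import torch
import torch.fx as fx

def f(x):
    return x.sin() + 1

gm = fx.symbolic_trace(f)
for node in gm.graph.nodes:
    node.meta["partitioner_tag"] = "is_forward"

assert all(n.meta["partitioner_tag"] == "is_forward" for n in gm.graph.nodes)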

View File

@ -51,6 +51,7 @@ from ._activation_checkpointing.knapsack import (
)
from ._activation_checkpointing.knapsack_evaluator import KnapsackEvaluator
from ._aot_autograd.descriptors import AOTOutput, SavedForBackwardsAOTOutput
from ._aot_autograd.functional_utils import assert_functional_graph
from ._aot_autograd.logging_utils import get_aot_graph_name
from ._aot_autograd.utils import get_cuda_generator_meta_val, is_with_effects
from .compile_utils import fx_graph_cse, get_aten_target, raise_getitems
@ -297,6 +298,10 @@ def _has_tag_is_backward(node: fx.Node) -> bool:
return node.meta.get("partitioner_tag", None) == "is_backward"
def _has_tag_is_forward(node: fx.Node) -> bool:
return node.meta.get("partitioner_tag", None) == "is_forward"
def _has_tag_must_be_in_forward(node: fx.Node) -> bool:
return node.meta.get("partitioner_tag", None) == "must_be_in_forward"
@ -1021,105 +1026,95 @@ def default_partition(
Returns:
Returns the generated forward and backward Fx graph modules.
"""
if has_recomputable_ops(joint_module):
return min_cut_rematerialization_partition(
joint_module,
_joint_inputs,
num_fwd_outputs=num_fwd_outputs,
static_lifetime_input_indices=static_lifetime_input_indices,
)
primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes))
inputs = primal_inputs + fwd_seed_offset_inputs
fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = (
_extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
)
forward_only_graph = _extract_graph_with_inputs_outputs(
joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward"
)
# Respect the original placement of ops rather than rely on dataflow.
forward_nodes = []
last_node = None
for node in joint_module.graph.nodes:
if _has_tag_is_forward(node) or _is_primal(node) or _is_fwd_seed_offset(node):
last_node = node
assert last_node is not None
for node in joint_module.graph.nodes:
if not _is_tangent(node):
forward_nodes.append(node)
if node is last_node:
break
forward_node_names = OrderedSet(
node.name for node in forward_only_graph.nodes if node.op != "output"
node.name for node in forward_nodes if node.op != "output"
)
order = {node: idx for idx, node in enumerate(joint_module.graph.nodes)}
graph_has_recomputable_ops = has_recomputable_ops(joint_module)
graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module)
if graph_has_recomputable_ops:
assert_functional_graph(joint_module.graph)
joint_module = cleanup_recompute_tags(joint_module, is_default_partition=True)
if not config.unsafe_allow_optimization_of_collectives:
force_save_collectives(joint_module)
force_save_bw_mutation_src(joint_module)
if static_lifetime_input_indices is None:
static_lifetime_input_indices = []
node_info = classify_nodes(
joint_module, static_lifetime_input_indices, num_fwd_outputs
)
saved_values = []
saved_sym_nodes = []
def is_mutated_later_in_fw(node):
if _has_tag_is_backward(node):
return False
tensor_arg_aliases = [
x
for x in node.args
if isinstance(x, fx.Node)
and "val" in x.meta
and isinstance(x.meta["val"], torch.Tensor)
]
while len(tensor_arg_aliases) > 0:
a = tensor_arg_aliases.pop()
for u in a.users:
if not isinstance(u.target, torch._ops.OpOverload):
continue
# If we witness a mutation on our node later, and that mutation is not "must be in backward",
# then our node needs to be computed in the forward (otherwise we will compute it on the mutated values)
if (
# one of the args was mutated
u.target._schema.is_mutable
# and the mutation happens "later"
and order[u] > order[node]
# and the mutation happened during the forward
and not (_has_tag_is_backward(u) or _has_tag_must_be_in_backward(u))
):
for idx, alias_info in enumerate(u.target._schema.arguments):
if alias_info.is_write and u.args[idx] is a:
return True
elif u.target.is_view:
tensor_arg_aliases.append(u)
return False
for node in joint_module.graph.nodes:
if node.name not in forward_node_names:
# if a node isn't "required" to be in the forward, but any of its arguments
# are later mutated in the forward, then it must have been run in the forward
# (if not, and the node's arg was saved for backward, we would have mutated a saved value)
# NB: doesn't handle nodes where the input is a list of tensors and one of those tensors is later mutated
if is_mutated_later_in_fw(node):
saved_values.append(node)
continue
if is_sym_node(node):
# Symints must be kept separate from tensors so that PythonFunction only calls
# save_for_backward on tensors and stashes symints in autograd .ctx
saved_sym_nodes.append(node)
elif (
continue
if node.meta.get("recompute") == CheckpointPolicy.MUST_SAVE:
saved_values.append(node)
continue
if node.is_impure(impure_random=False) and node.op not in (
"placeholder",
"output",
):
# See is_impure in torch/fx/node.py
assert not graph_has_recomputable_ops, (
"Trying to apply AC on a graph with impure op",
node,
node.target,
)
saved_values.append(node)
continue
backward_usages = [n for n in node.users if n.name not in forward_node_names]
if "tensor_meta" in node.meta and all(is_sym_node(n) for n in backward_usages):
# If we have a tensor in the forward, where only its sizes/strides are needed in the backward,
# and not the actual tensor data,
# then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor.
#
# Note that saving the tensor could also cause compilation problems:
# If the user mutated an input in the forward and uses its sizes/strides in the backward,
# then we would be obligated to clone the input before saving it to appease autograd.
# (This is how we originally found this bug).
saved_sym_nodes.extend(backward_usages)
continue
if (
"tensor_meta" not in node.meta
and node.op == "call_function"
and not isinstance(node.meta.get("val"), torch._subclasses.FakeTensor)
):
# Since we can't save tuple of tensor values, we need to flatten out what we're saving
users = node.users
assert all(user.target is operator.getitem for user in users)
saved_values.extend(users)
else:
backward_usages = [
n for n in node.users if n.name not in forward_node_names
]
if "tensor_meta" in node.meta and all(
is_sym_node(n) for n in backward_usages
):
# If we have a tensor in the forward, where only its sizes/strides are needed in the backward,
# and not the actual tensor data,
# then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor.
#
# Note that saving the tensor could also cause compilation problems:
# If the user mutated an input in the forward and uses its sizes/strides in the backward,
# then we would be obligated to clone the input before saving it to appease autograd.
# (This is how we originally found this bug).
saved_sym_nodes.extend(backward_usages)
else:
saved_values.append(node)
assert all(user.target == operator.getitem for user in node.users)
continue
if not must_recompute(node):
saved_values.append(node)
saved_values = list(dict.fromkeys(saved_values).keys())
saved_sym_nodes = list(dict.fromkeys(saved_sym_nodes).keys())
return _extract_fwd_bwd_modules(
if config._sync_decision_cross_ranks:
saved_values = _sync_decision_cross_ranks(joint_module.graph, saved_values)
if static_lifetime_input_nodes is None:
static_lifetime_input_nodes = node_info.static_lifetime_input_nodes
fw_module, bw_module = _extract_fwd_bwd_modules(
joint_module,
saved_values,
saved_sym_nodes=saved_sym_nodes,
@ -1127,6 +1122,24 @@ def default_partition(
static_lifetime_input_nodes=static_lifetime_input_nodes,
)
if graph_has_recomputable_ops:
if graph_has_recomputable_rng_ops:
fw_module, bw_module = functionalize_rng_ops(
joint_module, fw_module, bw_module, len(saved_sym_nodes)
)
bw_module = reordering_to_mimic_autograd_engine(bw_module)
# raise all getitem ops to as early as possible
# this is helpful for memory, especially in the case of aot_eager backend
fw_module = raise_getitems(fw_module)
bw_module = raise_getitems(bw_module)
fw_module = thread_graphsafe_rng_from_hops(fw_module, is_backward=False)
if len(node_info.required_bw_nodes) > 0:
bw_module = thread_graphsafe_rng_from_hops(bw_module, is_backward=True)
return fw_module, bw_module
INT_INF = int(1e6)
@ -1621,7 +1634,9 @@ def force_save_bw_mutation_src(joint_module: fx.GraphModule) -> None:
break
def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule:
def cleanup_recompute_tags(
joint_module: fx.GraphModule, *, is_default_partition: bool
) -> fx.GraphModule:
"""
If there are two consecutive checkpointed blocks with no operator in
between, we would still want to stash the tensor at the boundary of
@ -1658,6 +1673,16 @@ def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule:
# Solution: check whether `out` has a backward hook, and if so, intentionally save `out`
# in forward graph outputs. With this, we can break the above circular dependency.
node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
elif (
"ac_graph_id" not in node.meta
and any(must_recompute(user) for user in node.users)
and is_default_partition
):
# This node is not part of the AC region and a user is marked as recompute.
# This means it's an input to the AC region and we should save it.
# For ease of landing, gate this to default partitioner only, but we should think
# about flipping the switch in general as well.
node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
return joint_module
@ -2765,6 +2790,59 @@ def thread_graphsafe_rng_from_hops(module, is_backward):
return module
def classify_nodes(joint_module, static_lifetime_input_indices, num_fwd_outputs):
name_to_node = get_name_to_node(joint_module.graph)
required_bw_nodes: OrderedSet[fx.Node] = OrderedSet()
for node in joint_module.graph.nodes:
if node.op == "placeholder" and "tangents" in node.target:
required_bw_nodes.add(node)
elif _must_be_in_backward(node):
required_bw_nodes.add(node)
if node in required_bw_nodes:
required_bw_nodes.update(node.users)
primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes))
inputs = primal_inputs + fwd_seed_offset_inputs
fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = (
_extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
)
required_bw_nodes.update(
o for o in bwd_outputs if o is not None and o.op != "output"
)
forward_only_graph = _extract_graph_with_inputs_outputs(
joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward"
)
required_fw_nodes: OrderedSet[fx.Node] = OrderedSet(
name_to_node[node.name]
for node in forward_only_graph.nodes
if node.op != "output"
)
unclaimed_nodes: OrderedSet[fx.Node] = OrderedSet(
node
for node in joint_module.graph.nodes
if node not in required_fw_nodes and node not in required_bw_nodes
)
static_lifetime_input_nodes = OrderedSet(
p for i, p in enumerate(primal_inputs) if i in static_lifetime_input_indices
)
fw_cnt = 0
fw_order = {}
for node in joint_module.graph.nodes:
if node in required_fw_nodes:
fw_order[node] = fw_cnt
fw_cnt += 1
return NodeInfo(
inputs,
required_fw_nodes,
required_bw_nodes,
unclaimed_nodes,
fw_order,
static_lifetime_input_nodes,
)
def min_cut_rematerialization_partition(
joint_module: fx.GraphModule,
_joint_inputs,
@ -2813,68 +2891,16 @@ def min_cut_rematerialization_partition(
graph_has_recomputable_ops = has_recomputable_ops(joint_module)
graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module)
if graph_has_recomputable_ops:
joint_module = cleanup_recompute_tags(joint_module)
joint_module = cleanup_recompute_tags(joint_module, is_default_partition=False)
if not config.unsafe_allow_optimization_of_collectives:
force_save_collectives(joint_module)
force_save_bw_mutation_src(joint_module)
def classify_nodes(joint_module, static_lifetime_input_indices):
name_to_node = get_name_to_node(joint_module.graph)
required_bw_nodes: OrderedSet[fx.Node] = OrderedSet()
for node in joint_module.graph.nodes:
if node.op == "placeholder" and "tangents" in node.target:
required_bw_nodes.add(node)
elif _must_be_in_backward(node):
required_bw_nodes.add(node)
if node in required_bw_nodes:
required_bw_nodes.update(node.users)
primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
fwd_seed_offset_inputs = list(
filter(_is_fwd_seed_offset, joint_module.graph.nodes)
)
inputs = primal_inputs + fwd_seed_offset_inputs
fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = (
_extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
)
required_bw_nodes.update(
o for o in bwd_outputs if o is not None and o.op != "output"
)
forward_only_graph = _extract_graph_with_inputs_outputs(
joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward"
)
required_fw_nodes: OrderedSet[fx.Node] = OrderedSet(
name_to_node[node.name]
for node in forward_only_graph.nodes
if node.op != "output"
)
unclaimed_nodes: OrderedSet[fx.Node] = OrderedSet(
node
for node in joint_module.graph.nodes
if node not in required_fw_nodes and node not in required_bw_nodes
)
static_lifetime_input_nodes = OrderedSet(
p for i, p in enumerate(primal_inputs) if i in static_lifetime_input_indices
)
fw_cnt = 0
fw_order = {}
for node in joint_module.graph.nodes:
if node in required_fw_nodes:
fw_order[node] = fw_cnt
fw_cnt += 1
return NodeInfo(
inputs,
required_fw_nodes,
required_bw_nodes,
unclaimed_nodes,
fw_order,
static_lifetime_input_nodes,
)
if static_lifetime_input_indices is None:
static_lifetime_input_indices = []
node_info = classify_nodes(joint_module, static_lifetime_input_indices)
node_info = classify_nodes(
joint_module, static_lifetime_input_indices, num_fwd_outputs
)
# networkx blows up on graphs with no required backward nodes
# Since there's nothing to partition anyway, and the default partitioner can "handle"
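Two small building blocks that recur throughout this file, pulled out as a standalone sketch: predicates over node.meta["partitioner_tag"], and order-preserving de-duplication of the saved-values list via dict.fromkeys:

import torch.fx as fx

def has_tag_is_forward(node: fx.Node) -> bool:
    return node.meta.get("partitioner_tag", None) == "is_forward"

def has_tag_is_backward(node: fx.Node) -> bool:
    return node.meta.get("partitioner_tag", None) == "is_backward"

# dict.fromkeys drops duplicates while keeping first-seen order:
saved_values = ["sin", "cos", "sin", "mm"]
assert list(dict.fromkeys(saved_values)) == ["sin", "cos", "mm"]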

View File

@ -627,7 +627,7 @@ class ComboKernel(Kernel):
if heuristics == "foreach":
heuristics_line = f"""
@triton_heuristics.foreach(
num_warps={self.num_warps},
filename=__file__,
triton_meta={triton_meta!r},
inductor_meta={inductor_meta!r},
)

View File

@ -2259,7 +2259,7 @@ class PythonWrapperCodegen(CodeGen):
gpu: bool = True,
cpp_definition: Optional[str] = None,
):
if config.triton.autotune_at_compile_time and gpu:
if config.triton.autotune_at_compile_time:
body = self._format_kernel_definition(
kernel_name, kernel_body, metadata=metadata
)
@ -3745,13 +3745,6 @@ class SubgraphPythonWrapperCodegen(PythonWrapperCodegen):
super().__init__()
root = self.get_root_graph()
# Only generate auto-tuning block in the main graph
self.kernel_autotune_defs = root.kernel_autotune_defs
self.kernel_autotune_calls = root.kernel_autotune_calls
# Only store kernel src to name mapping in the main graph
self.src_to_kernel = root.src_to_kernel
def set_launcher_fn_name(self) -> None:
# This sets up the name of the function containing the launcher code of
# the subgraph.
@ -3844,16 +3837,3 @@ class SubgraphPythonWrapperCodegen(PythonWrapperCodegen):
# V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")
# )
self.parent_wrapper.write_get_raw_stream_header_once()
@cache_on_self
def get_root_graph(self) -> PythonWrapperCodegen:
root: PythonWrapperCodegen | SubgraphPythonWrapperCodegen = self
while isinstance(root, SubgraphPythonWrapperCodegen):
root = root.parent_wrapper
assert isinstance(root, PythonWrapperCodegen)
return root
def generate_and_run_autotune_block(self):
# Only execute auto-tuning block in the main graph
pass

View File

@ -64,7 +64,6 @@ from torch.fx.experimental.symbolic_shapes import (
)
from torch.fx.node import Node
from torch.utils._ordered_set import OrderedSet
from torch.utils._python_dispatch import _disable_current_modes
from torch.utils._sympy.functions import CleanDiv, FloorDiv, Mod, ModularIndexing
from torch.utils._sympy.symbol import SymT
@ -6136,12 +6135,9 @@ class ExternKernel(InputsKernel):
if isinstance(x, (Expr, sympy.logic.boolalg.Boolean, int)):
return ShapeAsConstantBuffer(expr=x)
if isinstance(x, Constant):
# We need to unset fake mode, or else the torch.tensor() call will
# turn into a FakeTensor
with _disable_current_modes():
return V.graph.add_tensor_constant(
torch.tensor(x.value, dtype=x.get_dtype(), device=x.get_device())
)
return V.graph.add_tensor_constant(
torch.tensor(x.value, dtype=x.get_dtype(), device=x.get_device())
)
if isinstance(x, ConstantBuffer):
return x
if isinstance(x, TensorBox):
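One side of this hunk wraps the torch.tensor() call in _disable_current_modes so that a real constant is materialized even while a FakeTensorMode is active. A standalone illustration of that behaviour, using the same private helpers this file imports:

import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.utils._python_dispatch import _disable_current_modes

with FakeTensorMode():
    with _disable_current_modes():
        t = torch.tensor(1.0)  # fake mode temporarily unset: a real tensor

assert not isinstance(t, FakeTensor)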

View File

@ -7099,19 +7099,13 @@ def sym_constrain_range(a, min=None, max=None):
@register_lowering(aten.sym_size.int)
def sym_size(a, dim):
val = V.graph.current_node.meta["val"]
if isinstance(val, torch.SymInt):
return val.node.expr
else:
return int(val)
return val.node.expr
@register_lowering(aten.sym_stride.int)
def sym_stride(a, dim):
val = V.graph.current_node.meta["val"]
if isinstance(val, torch.SymInt):
return val.node.expr
else:
return int(val)
return val.node.expr
@register_lowering(aten.sym_numel)
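For context on the sym_size/sym_stride lowerings above: when the node's example value is symbolic, the sympy expression lives at val.node.expr, while a plain Python int has no such wrapper, which is what the extra branch guarded against. A standalone restatement of that branch:

import torch

def size_expr(val):
    # Symbolic sizes carry a sympy expression on their SymNode; static sizes
    # are plain ints.
    if isinstance(val, torch.SymInt):
        return val.node.expr
    return int(val)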

View File

@ -3607,13 +3607,24 @@ def user_autotune(
)
def foreach(triton_meta, num_warps, filename=None, inductor_meta=None):
def foreach(triton_meta, filename=None, inductor_meta=None):
"""
Compile a triton foreach kernel
"""
configs = []
# Naive autotuning path for num_warps
if not (
inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise")
):
configs.append(triton.Config({}, num_stages=1, num_warps=8))
else:
for warps in [1, 2, 4, 8]:
configs.append(triton.Config({}, num_stages=1, num_warps=warps))
return cached_autotune(
None,
[triton.Config({}, num_stages=1, num_warps=num_warps)],
configs,
triton_meta=triton_meta,
inductor_meta=inductor_meta,
heuristic_type=HeuristicType.TEMPLATE,
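The config-building logic above, restated in isolation: without max-autotune the foreach kernel keeps a single num_warps=8 configuration, otherwise it sweeps a small set of warp counts (pure-Python sketch, no triton import):

def foreach_warp_choices(inductor_meta: dict) -> list[int]:
    if not (
        inductor_meta.get("max_autotune")
        or inductor_meta.get("max_autotune_pointwise")
    ):
        return [8]
    return [1, 2, 4, 8]

assert foreach_warp_choices({}) == [8]
assert foreach_warp_choices({"max_autotune": True}) == [1, 2, 4, 8]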

View File

@ -702,7 +702,7 @@ def exp2(a):
# CompositeImplicitAutograd - don't register decomp
@out_wrapper()
@elementwise_type_promotion_wrapper(
type_promoting_args=("a",),
type_promoting_args=("a,"),
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
)
def fill(a: TensorLikeType, value: NumberType) -> TensorLikeType:

View File

@ -1,5 +1,6 @@
#pragma once
#include <c10/util/Exception.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/stable/c/shim.h>
#include <torch/csrc/stable/device_struct.h>
@ -119,7 +120,7 @@ struct FromImpl<ScalarType> {
case ScalarType::UInt64:
return from(aoti_torch_dtype_uint64());
default:
STD_TORCH_CHECK(
TORCH_CHECK(
false,
"Not yet supported ScalarType, please file an issue describing your use case.");
}
@ -150,7 +151,7 @@ struct FromImpl<DeviceType> {
case DeviceType::PrivateUse1:
return from(aoti_torch_device_type_privateuse1());
default:
STD_TORCH_CHECK(
TORCH_CHECK(
false,
"Not yet supported DeviceType, please file an issue describing your use case.");
}
@ -378,7 +379,7 @@ struct ToImpl<ScalarType> {
} else if (shim_scalartype == aoti_torch_dtype_uint64()) {
return ScalarType::UInt64;
} else {
STD_TORCH_CHECK(
TORCH_CHECK(
false,
"Not yet supported ScalarType ",
std::to_string(shim_scalartype),
@ -408,7 +409,7 @@ struct ToImpl<DeviceType> {
} else if (shim_devicetype == aoti_torch_device_type_privateuse1()) {
return DeviceType::PrivateUse1;
} else {
STD_TORCH_CHECK(
TORCH_CHECK(
false,
"Not yet supported DeviceType ",
std::to_string(shim_devicetype),

View File

@ -2,7 +2,7 @@ from collections.abc import Callable
from copy import deepcopy
from enum import auto, Enum
from functools import partial, wraps
from typing import Any, NamedTuple, Optional, TYPE_CHECKING, TypeVar, Union
from typing import Any, NamedTuple, Optional, TypeVar, Union
from typing_extensions import ParamSpec, TypeVarTuple, Unpack
import torch
@ -17,9 +17,6 @@ from torch.utils._pytree import tree_map_only
from torch.utils.weak import WeakIdKeyDictionary, weakref
if TYPE_CHECKING:
from torch.utils.hooks import RemovableHandle
_TOTAL_KEY = "Total"
__all__ = ["FSDPMemTracker"]
@ -368,28 +365,14 @@ class FSDPMemTracker(MemTracker):
# `FSDPParamGroup.post_forward` because during AC these won't be called.
# TODO(@sanketpurandare): This will need to be modified after this PR (https://github.com/pytorch/pytorch/pull/127786)
# lands. For backward we monkey-patch the `FSDPParamGroup.pre_backward` and `FSDPParamGroup.post_backward`.
# get the unique _MultiHandlers/RemoveHandlers and store in dictionary
# the _MultiHandlers object will only need to be grabbed once.
unique_handlers: dict[RemovableHandle, bool] = {}
# pyrefly: ignore # missing-attribute
for module in self._root_mod.modules():
if isinstance(module, FSDPModule):
fsdp_state = module._get_fsdp_state()
if fsdp_param_group := fsdp_state._fsdp_param_group:
if not unique_handlers.get(fsdp_state._pre_forward_hook_handle):
unique_handlers[fsdp_state._pre_forward_hook_handle] = True
if not unique_handlers.get(fsdp_state._post_forward_hook_handle):
unique_handlers[fsdp_state._post_forward_hook_handle] = True
# call remove on the handles once
for f_hook_handle in unique_handlers.keys():
f_hook_handle.remove()
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
for module in self._root_mod.modules():
if isinstance(module, FSDPModule):
fsdp_state = module._get_fsdp_state()
if fsdp_param_group := fsdp_state._fsdp_param_group:
self._instrument_fsdp_sharded_params_grads(fsdp_param_group)
fsdp_state._pre_forward_hook_handle.remove()
fsdp_state._post_forward_hook_handle.remove()
fsdp_state._pre_forward_hook_handle = (
# pyrefly: ignore [missing-attribute]
module.register_forward_pre_hook(
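One side of this hunk collects the pre/post-forward hook handles into a dictionary first so that a handle shared by several FSDP modules is removed exactly once. The same de-duplication pattern with stand-in handle objects:

class _Handle:
    def __init__(self) -> None:
        self.removed = 0

    def remove(self) -> None:
        self.removed += 1

shared = _Handle()
handles = [shared, shared, _Handle()]

unique: dict[_Handle, bool] = {}
for h in handles:
    if not unique.get(h):
        unique[h] = True
for h in unique:
    h.remove()

assert shared.removed == 1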

View File

@ -194,10 +194,6 @@ else:
_rank_map: Optional[torch.Tensor] = None,
_root_mesh: Optional["DeviceMesh"] = None,
) -> None:
# no-op in OSS, logs API usage metrics in meta-internal runs
torch._C._log_api_usage_once(
"torch.distributed.device_mesh.DeviceMesh.__init__"
)
if mesh is not None:
if _layout is not None or _rank_map is not None:
raise TypeError(
@ -259,13 +255,14 @@ else:
)
# private field to pre-generate DeviceMesh's hash
self._flatten_rank_map = tuple(self._rank_map.tolist())
self._flatten_mesh_list = tuple(self.mesh.flatten().tolist())
self._thread_id = None
# Initialize instance-specific flatten mapping
self._flatten_mapping = {}
# Skip process group initialization if xla device or init backend is False
# TODO(yeounoh) implement DeviceMesh backend and register XLA backend.
self._thread_id = None
if device_type != "xla":
# always try to create default (world) pg, even if it is not initialized
# already. The world pg is used for device mesh identity (rank) on each
@ -296,6 +293,11 @@ else:
rank_coords[0].tolist() if rank_coords.size(0) > 0 else None
)
# private field to pre-generate DeviceMesh's hash
self._flatten_rank_map = tuple(self._rank_map.tolist())
# Initialize instance-specific flatten mapping
self._flatten_mapping = {}
@property
def device_type(self) -> str:
"""Returns the device type of the mesh."""

View File

@ -359,10 +359,6 @@ class ShardingPropagator:
"""
Propagate the sharding for an operator given the op_schema.
"""
# no-op in OSS, logs API usage metrics in meta-internal runs
torch._C._log_api_usage_once(
"torch.distributed.tensor._sharding_prop.ShardingPropagator.propogate_op_sharding_non_cached"
)
# special case op, we don't need to propagate for local
# scalar. TODO: figure out a better way to handle this
if op_schema.op is aten._local_scalar_dense.default:

View File

@ -398,9 +398,6 @@ def load(
Under active development, saved files may not be usable in newer versions
of PyTorch.
.. warning::
:func:`torch.export.load()` uses pickle under the hood to load models. **Never load data from an untrusted source.**
Loads an :class:`ExportedProgram` previously saved with
:func:`torch.export.save <torch.export.save>`.

View File

@ -12,6 +12,7 @@ from collections.abc import Callable
from contextlib import contextmanager, ExitStack, nullcontext
from itertools import chain
from typing import Any, Optional, TYPE_CHECKING, TypeAlias, Union
from unittest import mock
if TYPE_CHECKING:
@ -274,6 +275,24 @@ def _extract_fake_inputs(gm, args, kwargs):
else:
fake_vals.append(node.meta.get("example_value"))
if in_shuffle_graph := getattr(gm, "_in_shuffle_graph", None):
flat_args = pytree.tree_leaves((args, kwargs))
node_map = {
node: i
for i, node in enumerate(
next(iter(reversed(in_shuffle_graph.graph.nodes))).args[0]
)
if node.op == "placeholder"
}
new_fake_inps: list[Any] = []
for i, node in enumerate(
in_shuffle_graph.graph.find_nodes(op="placeholder")[1:]
):
if node in node_map:
new_fake_inps.append(fake_inps[node_map[node]])
else:
new_fake_inps.append(flat_args[i])
fake_inps = new_fake_inps
# We get both because now we might have a combination of symint and tensor
# inputs, and we want to check that the shape env is consistent between
# both. Unfortunately we can't see what fake mode is attached to the shape
@ -798,6 +817,16 @@ def _export_to_torch_ir(
prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
)
def use_legacy_dynamo_graph_capture() -> bool:
return bool(
constraints # dynamic shape
or dynamic_shapes # dynamic shape
or isinstance(f, torch.fx.GraphModule) # retracing
or preserve_module_call_signature # unflatten
or torch._functorch.config.fake_tensor_propagate_real_tensors # draft
or torch._export.config.use_legacy_dynamo_graph_capture
)
with torch._dynamo.config.patch(dataclasses.asdict(dynamo_cfg)):
try:
module_call_specs: dict[str, dict[str, pytree.TreeSpec]] = (
@ -812,11 +841,20 @@ def _export_to_torch_ir(
if torch._export.config.use_new_tracer_experimental:
from torch._dynamo.functional_export import (
_dynamo_graph_capture_for_export,
dynamo_graph_capture_for_export,
)
gm_torch_level = _dynamo_graph_capture_for_export(
f, constraints=constraints, dynamic_shapes=dynamic_shapes
)(*args, **kwargs)
if use_legacy_dynamo_graph_capture():
dynamo_graph_capture = _dynamo_graph_capture_for_export(
f, constraints=constraints, dynamic_shapes=dynamic_shapes
)
else:
dynamo_graph_capture = dynamo_graph_capture_for_export(f)
# We can't serialize entire fake mode yet, so this is to make sure
# things like copy.deepcopy(ep.graph_module) not crash.
# see test_export.py::test_custom_tag_metadata_re_export
# Once we delete the old strict export, we can use
gm_torch_level = dynamo_graph_capture(*args, **kwargs)
# We can't serialize entire fake mode yet, so this is to make sure
# things like copy.deepcopy(ep.graph_module) not crash.
# see test_export.py::test_custom_tag_metadata_re_export
@ -1568,7 +1606,11 @@ def _strict_export(
}
tx = TracingContext(dynamo_fake_mode)
with dynamo_fake_mode, tracing(tx):
with (
dynamo_fake_mode,
tracing(tx),
mock.patch.object(dynamo_fake_mode, "allow_non_fake_inputs", True),
):
aten_export_artifact = _to_aten_func(
gm_torch_level,
# NOTE: graph module expects only positional args
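The tracer selection above boils down to a single predicate: the legacy dynamo graph capture stays in use whenever a feature the new capture path does not yet support is requested, or when the global flag is left at its default of True. Restated as a standalone function (argument names hypothetical):

import torch
import torch._export.config
import torch._functorch.config

def should_use_legacy_tracer(
    f, constraints, dynamic_shapes, preserve_module_call_signature
) -> bool:
    return bool(
        constraints                                   # dynamic shapes via constraints
        or dynamic_shapes                             # dynamic shapes
        or isinstance(f, torch.fx.GraphModule)        # retracing
        or preserve_module_call_signature             # module call signature preservation
        or torch._functorch.config.fake_tensor_propagate_real_tensors  # draft mode
        or torch._export.config.use_legacy_dynamo_graph_capture        # defaults to True
    )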

View File

@ -3,7 +3,7 @@ import dataclasses
import inspect
import logging
import sys
from collections import defaultdict, OrderedDict
from collections import defaultdict
from collections.abc import Callable
from enum import auto, Enum
from typing import Any, Optional, TYPE_CHECKING, Union
@ -721,18 +721,7 @@ def _combine_args(f, args, kwargs) -> dict[str, Any]:
else inspect.signature(f)
)
kwargs = kwargs if kwargs is not None else {}
combined_args = signature.bind(*args, **kwargs).arguments
# if `args` is in the key, flatten it into args_0, args_1, ...
if "args" in combined_args:
flattened_args = {f"args_{i}": v for i, v in enumerate(combined_args["args"])}
combined_args = OrderedDict({**combined_args, **flattened_args})
del combined_args["args"]
# flatten kwargs into combined_args
if "kwargs" in combined_args:
for k, v in combined_args["kwargs"].items():
combined_args[k] = v
del combined_args["kwargs"]
return combined_args
return signature.bind(*args, **kwargs).arguments
class ShapesCollection:
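What the flattening shown in this hunk does, illustrated on a toy signature: inspect's bind() groups variadic positionals under "args" and extra keywords under "kwargs", and the helper spread those entries back out into top-level names:

import inspect

def f(a, *args, **kwargs):
    pass

bound = dict(inspect.signature(f).bind(1, 2, 3, b=4).arguments)
# bound == {"a": 1, "args": (2, 3), "kwargs": {"b": 4}}

bound.update({f"args_{i}": v for i, v in enumerate(bound.pop("args", ()))})
bound.update(bound.pop("kwargs", {}))
assert bound == {"a": 1, "args_0": 2, "args_1": 3, "b": 4}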

View File

@ -1709,8 +1709,11 @@ def _convert_guards_to_code(graph_module):
py_printer = torch.fx.experimental.symbolic_shapes.ShapeGuardPythonPrinter(
shape_env.var_to_sources, lambda s: s.name(), shape_env.var_to_sources
)
return [
ret = [
py_printer.doprint(guard.expr)
for guard in shape_env.guards
if guard.expr.free_symbols.issubset(local_vars)
]
# TODO Figure out how to resolve guards containing weight sizes.
# This is not a big deal as _guards_code is mostly empty today.
return [guard for guard in ret if "L['self']" not in guard]

View File

@ -19,13 +19,8 @@ __all__: list[str] = [
"SDPBackend",
"sdpa_kernel",
"WARN_FOR_UNFUSED_KERNELS",
"register_flash_attention_impl",
"activate_flash_attention_impl",
"list_flash_attention_impls",
"current_flash_attention_impl",
]
# Note: [SDPA warnings]
# TODO: Consider using this for sdpa regardless of subclasses
# This only affects users of bias subclasses
@ -167,23 +162,3 @@ def _sdpa_kernel_variadic(*backends: SDPBackend):
def _get_flash_version() -> str:
"""This returns the closest matching tag for the flash attention backend"""
return "2.5.7"
from . import _registry
# Re-export registry types and functions for public API
_FlashAttentionImpl = _registry._FlashAttentionImpl
_RegisterFn = _registry._RegisterFn
register_flash_attention_impl = _registry.register_flash_attention_impl
activate_flash_attention_impl = _registry.activate_flash_attention_impl
list_flash_attention_impls = _registry.list_flash_attention_impls
current_flash_attention_impl = _registry.current_flash_attention_impl
register_flash_attention_impl.__module__ = __name__
activate_flash_attention_impl.__module__ = __name__
list_flash_attention_impls.__module__ = __name__
current_flash_attention_impl.__module__ = __name__
# Import built-in implementations to trigger self-registration
from . import _fa4 # noqa: F401

View File

@ -1,444 +0,0 @@
"""UBER PROTOTYPE!!!"""
# mypy: allow-untyped-defs
from __future__ import annotations
import importlib
from dataclasses import dataclass
from functools import cache
from typing import Any, TYPE_CHECKING
from typing_extensions import TypeVarTuple, Unpack
from . import _registry
if TYPE_CHECKING:
from types import ModuleType
import torch
from torch.library import Library
__all__ = [
"register_flash_attention_fa4",
]
_FA4_MODULE_PATH: str | None = None
@dataclass
class _FA4Handle:
library: Library | None
def remove(self) -> None:
self.library = None
@cache
def _get_device_major(device: torch.device) -> int:
major, _ = torch.cuda.get_device_capability(device)
return major
def register_flash_attention_fa4(
module_path: str = "flash_attn.cute.interface",
) -> _FA4Handle:
"""
Register FA4 flash attention kernels with the PyTorch dispatcher.
Args:
module_path: Python module path to the FA4 implementation.
"""
global _FA4_MODULE_PATH
_ = _fa4_import_module(module_path)
_FA4_MODULE_PATH = module_path
return _FA4Handle(_fa4_register_kernels())
@cache
def _fa4_import_module(module_path: str) -> ModuleType:
module = importlib.import_module(module_path)
if not hasattr(module, "_flash_attn_fwd") or not hasattr(module, "_flash_attn_bwd"):
raise RuntimeError(f"Module '{module_path}' does not expose FA4 kernels")
return module
def _fa4_register_kernels() -> Library:
lib = Library("aten", "IMPL", "CUDA") # noqa: TOR901
lib.impl("_flash_attention_forward", _fa4_flash_attention_forward_impl, "CUDA")
lib.impl("_flash_attention_backward", _fa4_flash_attention_backward_impl, "CUDA")
lib.impl(
"_scaled_dot_product_flash_attention",
_fa4_scaled_dot_product_flash_attention_forward_impl,
"CUDA",
)
lib.impl(
"_scaled_dot_product_flash_attention_backward",
_fa4_scaled_dot_product_flash_attention_backward_impl,
"CUDA",
)
return lib
def _fa4_common_support_error(
query: torch.Tensor,
tensors: tuple[torch.Tensor, ...],
cum_seq_q: torch.Tensor | None,
require_fp32: tuple[tuple[str, torch.Tensor], ...] = (),
) -> str | None:
if not all(t.is_cuda for t in tensors):
return "inputs must be CUDA tensors"
if len({t.device for t in tensors}) != 1:
return "inputs must share device"
if query.dtype not in (torch.float16, torch.bfloat16):
return "query dtype must be float16 or bfloat16"
for name, tensor in require_fp32:
if tensor.dtype != torch.float32:
return f"{name} dtype must be float32"
if cum_seq_q is None and query.dim() != 4:
return "dense query must be 4D"
if cum_seq_q is not None and query.dim() != 3:
return "ragged query must be 3D"
if not torch.cuda.is_available():
return "CUDA not available"
if _get_device_major(query.device) not in (9, 10):
return "FA4 requires compute capability 9.0 or 10.0"
return None
def _fa4_forward_support_error(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
dropout_p: float,
return_debug_mask: bool,
alibi_slopes: torch.Tensor | None,
seqused_k: torch.Tensor | None,
cum_seq_q: torch.Tensor | None,
) -> str | None:
if dropout_p != 0.0:
return "dropout_p must be 0"
if return_debug_mask:
return "return_debug_mask must be False"
if alibi_slopes is not None:
return "alibi_slopes not supported"
if seqused_k is not None:
if seqused_k.dtype != torch.int32:
return "seqused_k must be int32"
if not seqused_k.is_cuda:
return "seqused_k must be CUDA"
error = _fa4_common_support_error(
query,
(query, key, value),
cum_seq_q,
)
if error is not None:
if error == "inputs must share device":
return "query, key, value must be on same device"
return error
return None
def _fa4_backward_support_error(
grad_out: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
out: torch.Tensor,
logsumexp: torch.Tensor,
dropout_p: float,
cum_seq_q: torch.Tensor | None,
window_size_left: int | None,
window_size_right: int | None,
) -> str | None:
if dropout_p != 0.0:
return "dropout_p must be 0"
if window_size_left is not None or window_size_right is not None:
return "windowed attention not supported"
error = _fa4_common_support_error(
query,
(grad_out, query, key, value, out, logsumexp),
cum_seq_q,
require_fp32=(("logsumexp", logsumexp),),
)
if error is not None:
return error
return None
Ts = TypeVarTuple("Ts")
def _transpose_dense(*tensors: Unpack[Ts]) -> tuple[Unpack[Ts]]:
return tuple(t.transpose(1, 2) for t in tensors) # type: ignore[attr-defined]
def _fa4_run_forward(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seq_q: torch.Tensor | None,
cu_seq_k: torch.Tensor | None,
scale: float | None,
is_causal: bool,
window_size_left: int | None,
window_size_right: int | None,
seqused_k: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor]:
if _FA4_MODULE_PATH is None:
raise RuntimeError("FA4 not registered")
module = _fa4_import_module(_FA4_MODULE_PATH)
kwargs: dict[str, Any] = {
"softmax_scale": scale,
"causal": is_causal,
"window_size_left": window_size_left,
"window_size_right": window_size_right,
"return_lse": True,
"cu_seqlens_q": cu_seq_q,
"cu_seqlens_k": cu_seq_k,
"seqused_k": seqused_k.contiguous() if seqused_k is not None else None,
}
out, lse = module._flash_attn_fwd(query, key, value, **kwargs)
return out, lse.contiguous()
def _fa4_run_backward(
grad_out: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
out: torch.Tensor,
logsumexp: torch.Tensor,
cu_seq_q: torch.Tensor | None,
cu_seq_k: torch.Tensor | None,
scale: float | None,
is_causal: bool,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if _FA4_MODULE_PATH is None:
raise RuntimeError("FA4 not registered")
module = _fa4_import_module(_FA4_MODULE_PATH)
dq, dk, dv = module._flash_attn_bwd(
query,
key,
value,
out,
grad_out,
logsumexp.contiguous(),
softmax_scale=scale,
causal=is_causal,
cu_seqlens_q=cu_seq_q,
cu_seqlens_k=cu_seq_k,
)
return dq, dk, dv
def _fa4_flash_attention_forward_impl(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cum_seq_q: torch.Tensor | None,
cum_seq_k: torch.Tensor | None,
max_q: int,
max_k: int,
dropout_p: float,
is_causal: bool,
return_debug_mask: bool,
*,
scale: float | None = None,
window_size_left: int | None = None,
window_size_right: int | None = None,
seqused_k: torch.Tensor | None = None,
alibi_slopes: torch.Tensor | None = None,
):
error = _fa4_forward_support_error(
query,
key,
value,
dropout_p,
return_debug_mask,
alibi_slopes,
seqused_k,
cum_seq_q,
)
if error is not None:
raise RuntimeError(f"FA4 flash_attention forward unsupported: {error}")
out, lse = _fa4_run_forward(
query,
key,
value,
cum_seq_q,
cum_seq_k,
scale,
is_causal,
window_size_left,
window_size_right,
seqused_k,
)
rng_state = torch.zeros((2,), dtype=torch.uint64, device=query.device)
philox_offset = torch.zeros((), dtype=torch.uint64, device=query.device)
debug_mask = torch.empty(0, dtype=query.dtype, device=query.device)
return out, lse, rng_state, philox_offset, debug_mask
def _fa4_flash_attention_backward_impl(
grad_out: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
out: torch.Tensor,
logsumexp: torch.Tensor,
cum_seq_q: torch.Tensor | None,
cum_seq_k: torch.Tensor | None,
max_q: int,
max_k: int,
dropout_p: float,
is_causal: bool,
rng_state: torch.Tensor,
unused: torch.Tensor,
*,
scale: float | None = None,
window_size_left: int | None = None,
window_size_right: int | None = None,
):
error = _fa4_backward_support_error(
grad_out,
query,
key,
value,
out,
logsumexp,
dropout_p,
cum_seq_q,
window_size_left,
window_size_right,
)
if error is not None:
raise RuntimeError(f"FA4 flash_attention backward unsupported: {error}")
dq, dk, dv = _fa4_run_backward(
grad_out,
query,
key,
value,
out,
logsumexp,
cum_seq_q,
cum_seq_k,
scale,
is_causal,
)
return dq, dk, dv
def _fa4_scaled_dot_product_flash_attention_forward_impl(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
dropout_p: float = 0.0,
is_causal: bool = False,
return_debug_mask: bool = False,
*,
scale: float | None = None,
):
error = _fa4_forward_support_error(
query,
key,
value,
dropout_p,
return_debug_mask,
None,
None,
None,
)
if error is not None:
raise RuntimeError(f"FA4 SDPA forward unsupported: {error}")
q, k, v = _transpose_dense(query, key, value)
max_q_flash = q.size(1)
max_k_flash = k.size(1)
out, lse, rng_state, philox_offset, debug_mask = _fa4_flash_attention_forward_impl(
q,
k,
v,
None,
None,
max_q_flash,
max_k_flash,
dropout_p,
is_causal,
return_debug_mask,
scale=scale,
)
(out,) = _transpose_dense(out)
max_q = query.size(2)
max_k = key.size(2)
return (
out,
lse,
None,
None,
max_q,
max_k,
rng_state,
philox_offset,
debug_mask,
)
def _fa4_scaled_dot_product_flash_attention_backward_impl(
grad_out: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
out: torch.Tensor,
logsumexp: torch.Tensor,
cum_seq_q: torch.Tensor | None,
cum_seq_k: torch.Tensor | None,
max_q: int,
max_k: int,
dropout_p: float,
is_causal: bool,
philox_seed: torch.Tensor,
philox_offset: torch.Tensor,
*,
scale: float | None = None,
):
error = _fa4_backward_support_error(
grad_out,
query,
key,
value,
out,
logsumexp,
dropout_p,
None,
None,
None,
)
if error is not None:
raise RuntimeError(f"FA4 SDPA backward unsupported: {error}")
q, k, v, o, go = _transpose_dense(query, key, value, out, grad_out)
max_q = query.size(2)
max_k = key.size(2)
dq, dk, dv = _fa4_flash_attention_backward_impl(
go,
q,
k,
v,
o,
logsumexp,
None,
None,
max_q,
max_k,
dropout_p,
is_causal,
philox_seed,
philox_offset,
scale=scale,
)
dq, dk, dv = _transpose_dense(dq, dk, dv)
return dq, dk, dv
_registry.register_flash_attention_impl("FA4", register_fn=register_flash_attention_fa4)

View File

@ -1,108 +0,0 @@
# mypy: allow-untyped-defs
"""Registry for flash attention implementations.
This module contains the registration system for flash attention implementations.
It has no torch dependencies to avoid circular imports during initialization.
"""
from typing import Callable, Literal, Protocol
class FlashAttentionHandle(Protocol):
def remove(self) -> None: ...
_RegisterFn = Callable[..., FlashAttentionHandle | None]
_FlashAttentionImpl = Literal["FA4"]
_FLASH_ATTENTION_IMPLS: dict[str, _RegisterFn] = {}
_FLASH_ATTENTION_ACTIVE: str | None = None
_FLASH_ATTENTION_HANDLES: dict[str, FlashAttentionHandle] = {}
def register_flash_attention_impl(
impl: str | _FlashAttentionImpl,
*,
register_fn: _RegisterFn,
) -> None:
"""
Register the callable that activates a flash attention impl.
.. note::
This function is intended for SDPA backend providers to register their
implementations. End users should use :func:`activate_flash_attention_impl`
to activate a registered implementation.
Args:
impl: Implementation identifier (e.g., ``"FA4"``).
register_fn: Callable that performs the actual dispatcher registration.
This function will be invoked by :func:`activate_flash_attention_impl`
and should register custom kernels with the PyTorch dispatcher.
It may optionally return a handle implementing
:class:`FlashAttentionHandle` to keep any necessary state alive.
Example:
>>> def my_impl_register(module_path: str = "my_flash_impl"):
... # Register custom kernels with torch dispatcher
... pass # doctest: +SKIP
>>> register_flash_attention_impl(
... "MyImpl", register_fn=my_impl_register
... ) # doctest: +SKIP
"""
_FLASH_ATTENTION_IMPLS[impl] = register_fn
def activate_flash_attention_impl(
impl: str | _FlashAttentionImpl,
) -> None:
"""
Activate into the dispatcher a previously registered flash attention impl.
.. note::
Backend providers should NOT automatically activate their implementation
on import. Users should explicitly opt-in by calling this function or via
environment variables to ensure multiple provider libraries can coexist.
Args:
impl: Implementation identifier to activate. See
:func:`~torch.nn.attention.list_flash_attention_impls` for available
implementations.
If the backend's :func:`register_flash_attention_impl` callable
returns a :class:`FlashAttentionHandle`, the registry keeps that
handle alive for the lifetime of the process (until explicit
uninstall support exists).
Example:
>>> activate_flash_attention_impl("FA4") # doctest: +SKIP
"""
global _FLASH_ATTENTION_ACTIVE
register_fn = _FLASH_ATTENTION_IMPLS.get(impl)
if register_fn is None:
raise ValueError(
f"Unknown flash attention impl '{impl}'. "
f"Available implementations: {list_flash_attention_impls()}"
)
# TODO: The only way to actually register a new impl is to unregister the current impl
# reinstall the default impl and then register the new impl
if _FLASH_ATTENTION_ACTIVE == impl:
return
handle = register_fn()
if handle is not None:
_FLASH_ATTENTION_HANDLES[impl] = handle
_FLASH_ATTENTION_ACTIVE = impl
def list_flash_attention_impls() -> list[str]:
"""Return the names of all available flash attention implementations."""
return sorted(_FLASH_ATTENTION_IMPLS.keys())
def current_flash_attention_impl() -> str | None:
"""
Return the currently activated flash attention impl name, if any.
``None`` indicates that no custom impl has been activated.
"""
return _FLASH_ATTENTION_ACTIVE
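Putting the registry API above together, an end-to-end sketch with a dummy implementation; the name "MyImpl" and my_register are hypothetical (the real entry point registered by this file set is "FA4"), and the functions are assumed to be in scope, e.g. imported from torch.nn.attention._registry:

def my_register():
    # Would register kernels with the dispatcher; returning None is allowed,
    # or an object with .remove() to keep registration state alive.
    print("kernels registered")
    return None

register_flash_attention_impl("MyImpl", register_fn=my_register)
activate_flash_attention_impl("MyImpl")          # invokes my_register once
assert current_flash_attention_impl() == "MyImpl"
assert "MyImpl" in list_flash_attention_impls()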