Mirror of https://github.com/pytorch/pytorch.git
Synced 2025-11-13 09:34:51 +08:00

Compare commits
32 Commits
zhxchen17/ ... ciflow/tru
| Author | SHA1 | Date |
|---|---|---|
| | 5c30fb7061 | |
| | e8d411e7f7 | |
| | 2e5233d7bd | |
| | 514dd96376 | |
| | 9ae62fcc18 | |
| | ae71b0e163 | |
| | 5b6ff8148d | |
| | 1f7e4343e7 | |
| | b21856f5fc | |
| | 259ba0ecab | |
| | 051f1fe8e3 | |
| | ee387c43fe | |
| | 3a944661d6 | |
| | 56034074ca | |
| | 8def619bbe | |
| | 61883a5787 | |
| | d8ada1ee76 | |
| | fe841a1db4 | |
| | b65829b84f | |
| | b0e0ae97ba | |
| | f44a1ddcb2 | |
| | 184e2cbc89 | |
| | 416421c7c4 | |
| | bd99ae3315 | |
| | ce8672c24f | |
| | 402c465030 | |
| | 573a79fffa | |
| | 4945180468 | |
| | 1df723e6f5 | |
| | f9b81e23e4 | |
| | ffe6cc39c7 | |
| | db1f3f6901 | |
@@ -1 +1 @@
bfeb066872bc1e8b2d2bc0a3b295b99dd77206e7
1070cd530573098dc8375422a1f1918d3815af3f

@@ -96,7 +96,6 @@ function pip_build_and_install() {
python3 -m pip wheel \
--no-build-isolation \
--no-deps \
--no-use-pep517 \
-w "${wheel_dir}" \
"${build_target}"
fi

2 .github/actionlint.yaml vendored
@@ -63,7 +63,7 @@ self-hosted-runner:
- linux.rocm.gpu.gfx942.1
- linux.rocm.gpu.gfx942.2
- linux.rocm.gpu.gfx942.4
- rocm-docker
- linux.rocm.gfx942.docker-cache
# Org wise AWS `mac2.metal` runners (2020 Mac mini hardware powered by Apple silicon M1 processors)
- macos-m1-stable
- macos-m1-14

55 .github/workflows/docker-cache-mi300.yml vendored
@@ -1,55 +0,0 @@
name: docker-cache-mi300

on:
  # run every 6 hours
  schedule:
    - cron: 0 0,6,12,18 * * *
  workflow_dispatch:

concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name }}
  cancel-in-progress: true

permissions:
  id-token: write
  contents: read

jobs:
  docker-cache:
    if: github.repository_owner == 'pytorch'
    runs-on: rocm-docker
    steps:
      - name: Checkout PyTorch
        uses: pytorch/pytorch/.github/actions/checkout-pytorch@main
        with:
          no-sudo: true

      - name: configure aws credentials
        id: aws_creds
        uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0
        with:
          role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
          aws-region: us-east-1
          role-duration-seconds: 18000

      - name: Login to Amazon ECR
        id: login-ecr
        continue-on-error: false
        uses: aws-actions/amazon-ecr-login@062b18b96a7aff071d4dc91bc00c4c1a7945b076 # v2.0.1

      - name: Calculate docker image
        id: calculate-docker-image
        uses: pytorch/test-infra/.github/actions/calculate-docker-image@main
        with:
          docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
          push: false

      - name: Pull docker image
        uses: pytorch/test-infra/.github/actions/pull-docker-image@main
        with:
          docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }}

      - name: Tar and upload to S3 bucket
        run: |
          sudo docker save -o ~/docker-data/pytorch/pytorch_docker_image.tar ${{ steps.calculate-docker-image.outputs.docker-image }}
          sudo rclone copy -P --s3-upload-concurrency 64 --s3-chunk-size 200M --s3-upload-cutoff 300M ~/docker-data/pytorch/pytorch_docker_image.tar oci:pytorchbucket0002/pytorch_docker_image --progress

108 .github/workflows/docker-cache-rocm.yml vendored Normal file
@@ -0,0 +1,108 @@
name: docker-cache-rocm

on:
  workflow_run:
    workflows: [docker-builds]
    # TODO: Uncomment before merging
    #branches: [main, release]
    types:
      - completed
  workflow_dispatch:

concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name }}
  cancel-in-progress: true

permissions:
  id-token: write
  contents: read
  actions: read

jobs:
  download-docker-builds-artifacts:
    if: github.repository_owner == 'pytorch'
    name: download-docker-builds-artifacts
    runs-on: ubuntu-latest
    outputs:
      pytorch-linux-jammy-rocm-n-py3: ${{ steps.process-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3 }}
      pytorch-linux-noble-rocm-n-py3: ${{ steps.process-artifacts.outputs.pytorch-linux-noble-rocm-n-py3 }}
      pytorch-linux-jammy-rocm-n-py3-benchmarks: ${{ steps.process-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3-benchmarks }}
    steps:
      - name: Download artifacts
        uses: actions/download-artifact@v4.1.7
        with:
          run-id: ${{ github.event.workflow_run.id }}
          path: ./docker-builds-artifacts
          merge-multiple: true
          github-token: ${{ secrets.GITHUB_TOKEN }}

      - name: Process artifacts
        id: process-artifacts
        run: |
          ls -R ./docker-builds-artifacts
          cat ./docker-builds-artifacts/*txt >> "${GITHUB_OUTPUT}"
          cat "${GITHUB_OUTPUT}"

  docker-cache:
    if: github.repository_owner == 'pytorch'
    needs: download-docker-builds-artifacts
    strategy:
      fail-fast: false
      matrix:
        runner: [linux.rocm.gfx942.docker-cache]
        docker-image: [
          "${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3 }}",
          "${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-noble-rocm-n-py3 }}",
          "${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3-benchmarks }}"
        ]
    runs-on: "${{ matrix.runner }}"
    steps:
      - name: debug
        run: |
          JSON_STRINGIFIED="${{ toJSON(needs.download-docker-builds-artifacts.outputs) }}"
          echo "Outputs of download-docker-builds-artifacts job: ${JSON_STRINGIFIED}"

      - name: configure aws credentials
        id: aws_creds
        uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0
        with:
          role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
          aws-region: us-east-1
          role-duration-seconds: 18000

      - name: Login to Amazon ECR
        id: login-ecr
        continue-on-error: false
        uses: aws-actions/amazon-ecr-login@062b18b96a7aff071d4dc91bc00c4c1a7945b076 # v2.0.1

      - name: Generate ghrc.io tag
        id: ghcr-io-tag
        run: |
          ecr_image="${{ matrix.docker-image }}"
          ghcr_image="ghcr.io/pytorch/ci-image:${ecr_image##*:}"
          echo "ghcr_image=${ghcr_image}" >> "$GITHUB_OUTPUT"

      - name: Pull docker image
        uses: pytorch/test-infra/.github/actions/pull-docker-image@main
        with:
          docker-image: ${{ steps.ghcr-io-tag.outputs.ghcr_image }}

      - name: Save as tarball
        run: |
          docker_image_tag=${{ matrix.docker-image }}
          docker_image_tag="${docker_image_tag#*:}" # Remove everything before and including first ":"
          docker_image_tag="${docker_image_tag%-*}" # Remove everything after and including last "-"
          ref_name=${{ github.event.workflow_run.head_branch }}
          if [[ $ref_name =~ "release/" ]]; then
            ref_suffix="release"
          elif [[ $ref_name == "main" ]]; then
            ref_suffix="main"
          else
            # TODO: Remove below
            ref_suffix="main"
            # echo "Unexpected branch in ref_name: ${ref_name}" && exit 1
          fi
          docker tag ${{ steps.ghcr-io-tag.outputs.ghcr_image }} ${{ matrix.docker-image }}
          # mv is atomic operation, so we use intermediate tar.tmp file to prevent read-write contention
          docker save -o ~/pytorch-data/docker/${docker_image_tag}.tar.tmp ${{ matrix.docker-image }}
          mv ~/pytorch-data/docker/${docker_image_tag}.tar.tmp ~/pytorch-data/docker/${docker_image_tag}_${ref_suffix}.tar

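A minimal Python sketch (illustration only, not part of this diff) of the string handling the "Generate ghrc.io tag" and "Save as tarball" steps perform; the sample ECR image name below is an assumed placeholder.

```python
# Illustration only: mirrors the shell parameter expansions used in the workflow above.
ecr_image = "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/ci-image:pytorch-linux-jammy-rocm-n-py3-abc123"  # assumed example

# ${ecr_image##*:} keeps everything after the last ":"
ghcr_image = "ghcr.io/pytorch/ci-image:" + ecr_image.rsplit(":", 1)[-1]

# ${var#*:} drops through the first ":", then ${var%-*} drops from the last "-"
tag = ecr_image.split(":", 1)[1].rsplit("-", 1)[0]

head_branch = "main"  # github.event.workflow_run.head_branch
ref_suffix = "release" if "release/" in head_branch else "main"
tarball = f"~/pytorch-data/docker/{tag}_{ref_suffix}.tar"

print(ghcr_image)  # ghcr.io/pytorch/ci-image:pytorch-linux-jammy-rocm-n-py3-abc123
print(tarball)     # ~/pytorch-data/docker/pytorch-linux-jammy-rocm-n-py3_main.tar
```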
@@ -142,6 +142,7 @@ Tensor _pack_padded_sequence_backward_symint(const Tensor& grad, c10::SymIntArra
std::tuple<Tensor, Tensor> _pad_packed_sequence(const Tensor& data, const Tensor& _batch_sizes, bool batch_first, const Scalar& padding_value, int64_t total_length) {
  auto batch_sizes_t = _batch_sizes.contiguous();
  checkLongTensor(batch_sizes_t);
  TORCH_CHECK(batch_sizes_t.numel() > 0, "batch_sizes can not be empty");

  int64_t * batch_sizes = batch_sizes_t.data_ptr<int64_t>();
  int64_t max_batch_size = batch_sizes[0];

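A hedged sketch of what the added TORCH_CHECK guards against (the regression test further down exercises the same path): an empty `batch_sizes` tensor now raises instead of reading `batch_sizes[0]` out of bounds.

```python
import torch
from torch.nn.utils import rnn as rnn_utils

# Assumed repro shape; previously this could read past the end of batch_sizes.
packed = rnn_utils.PackedSequence(
    torch.randn(0, 5), torch.tensor([], dtype=torch.int64), None, None
)
try:
    rnn_utils.pad_packed_sequence(packed, batch_first=True)
except RuntimeError as err:
    print("rejected:", err)  # message includes "batch_sizes can not be empty"
```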
@@ -669,9 +669,12 @@ std::optional<c10::ScalarType> out_dtype) {
  // _scaled_mm_allowed_device is used here within _grouped_mm_cuda which seems incorrect since scale is not used.
  // the _grouped_mm_fallback should be safe for any ROCm GPU since it's just calling typical mm/bmm
  bool use_fast_path = false;
  // On non CK system(w/ ROCm), make sure use_fast_path is false
#if defined(USE_ROCM_CK_GEMM)
  if (at::detail::getCUDAHooks().isGPUArch({"gfx942", "gfx950"})) {
    use_fast_path = true;
  }
#endif //USE_ROCM_CK_GEMM
#endif
  const auto out_dtype_ = _resolve_grouped_mm_out_dtype(mat_a, mat_b, out_dtype);
  Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);

@@ -680,7 +683,11 @@ std::optional<c10::ScalarType> out_dtype) {
#ifndef USE_ROCM
    at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out);
#else
#if defined(USE_ROCM_CK_GEMM)
    at::hip::detail::group_gemm_ck(mat_a, mat_b, offs, bias, out);
#else
    TORCH_WARN("ROCm: Group Gemm through CK not selected.");
#endif //USE_ROCM_CK_GEMM
#endif
  } else {
    _grouped_mm_fallback(mat_a, mat_b, offs, bias, out_dtype, out);

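As the comments in this hunk note, the non-CK path falls back to ordinary mm/bmm. A rough Python reference of what a grouped matmul over `offs` boundaries computes (a sketch for intuition only, not the actual `_grouped_mm_fallback`, and assuming one common 2D-by-3D layout):

```python
import torch

def grouped_mm_reference(mat_a, mat_b, offs):
    # mat_a: (total_m, k) with rows grouped by offs; mat_b: (num_groups, k, n)
    outs, start = [], 0
    for g, end in enumerate(offs.tolist()):
        outs.append(mat_a[start:end] @ mat_b[g])  # plain mm per group
        start = end
    return torch.cat(outs)

a = torch.randn(6, 4)
b = torch.randn(2, 4, 3)
offs = torch.tensor([2, 6])
print(grouped_mm_reference(a, b, offs).shape)  # torch.Size([6, 3])
```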
@@ -47,6 +47,7 @@
#include <c10/macros/Macros.h>
#include <thrust/copy.h>
#include <thrust/device_ptr.h>
#include <thrust/distance.h>
#include <thrust/for_each.h>
#include <thrust/functional.h>
#include <thrust/gather.h>

@@ -50,7 +50,7 @@ nfnet_l0,pass,7



repvgg_a2,fail_accuracy,7
repvgg_a2,pass,7



@@ -10,7 +10,7 @@ tp2_dir="$top_dir/third_party"
pip install ninja

# Install onnx
pip install --no-use-pep517 -e "$tp2_dir/onnx"
pip install -e "$tp2_dir/onnx"

# Install caffe2 and pytorch
pip install -r "$top_dir/caffe2/requirements.txt"

@@ -140,6 +140,11 @@ static void initDeviceStreamState(DeviceIndex device_index) {
static void initOpenRegStreamsOnce() {
  c10::call_once(init_flag, initGlobalStreamState);

  for (const auto i : c10::irange(num_devices)) {
    c10::call_once(
        device_flags[i], initDeviceStreamState, static_cast<DeviceIndex>(i));
  }

  if (current_streams) {
    return;
  }

@@ -202,8 +207,6 @@ OpenRegStream getStreamFromPool(const int priority, DeviceIndex device_index) {
  if (device_index == -1) {
    device_index = current_device();
  }
  c10::call_once(
      device_flags[device_index], initDeviceStreamState, device_index);
  auto pri_idx =
      std::clamp(priority, 0, max_compile_time_stream_priorities - 1);
  const auto idx = get_idx(priority_counters[device_index][pri_idx]);

@@ -180,6 +180,47 @@ class TestTrackerFullyShard1DTrainingCore(FSDPTest):
        del model
        del optim

    def _test_tracker_multihandler_hook(self):
        """Should run without KeyError."""

        class TestModule(nn.Module):
            def __init__(self, dim: int):
                super().__init__()
                self.norm1 = nn.RMSNorm(dim)
                self.output1 = nn.Linear(dim, dim)
                self.norm2 = nn.RMSNorm(dim)
                self.output2 = nn.Linear(dim, dim)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                x = self.norm1(x)
                x = self.output1(x)
                x = self.norm2(x)
                x = self.output2(x)
                return x

        gc.collect()
        torch.manual_seed(42)
        dev = torch.device(torch.accelerator.current_device_index())

        with torch.device(dev):
            model = TestModule(128)

        mesh = init_device_mesh(dev.type, (self.world_size,))
        fully_shard([model.norm1, model.output1], mesh=mesh)
        fully_shard([model.norm2, model.output2], mesh=mesh)
        fully_shard(model, mesh=mesh)

        fmt = FSDPMemTracker(model)

        with fmt:
            inp = torch.randn(16, 128, device=dev)
            y = model(inp)
            loss = y.sum()
            loss.backward()

        del inp
        del model


class TestTrackerFullyShard1DTrainingCompose(FSDPTest):
    @property

@@ -1,6 +1,7 @@
# Owner(s): ["oncall: distributed"]

import contextlib
import unittest

import torch
import torch.distributed as dist
@@ -371,6 +372,7 @@ class DTensorExportTest(TestCase):

    # aot_export_joint_with_descriptors on strict-exported exported_program.module()
    # is producing a joint graph with backward region missing
    @unittest.expectedFailure
    def test_strict_export_parallelize_module_with_dtensor_input(self):
        self._run_test(strict_export_and_aot_export_joint_with_descriptors)

@@ -15,7 +15,7 @@ import torch._functorch.config
import torch.distributed as dist
import torch.nn as nn
import torch.utils.checkpoint
from functorch.compile import default_partition, min_cut_rematerialization_partition
from functorch.compile import min_cut_rematerialization_partition
from torch._dynamo.backends.common import aot_autograd
from torch._dynamo.testing import (
    AotEagerAndRecordGraphs,
@@ -24,7 +24,7 @@ from torch._dynamo.testing import (
)
from torch._higher_order_ops.wrap import tag_activation_checkpoint
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import IS_WINDOWS, parametrize, skipIfHpu
from torch.testing._internal.common_utils import IS_WINDOWS, skipIfHpu
from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON
from torch.testing._internal.triton_utils import requires_cuda_and_triton
from torch.testing._internal.two_tensor import TwoTensor
@@ -281,14 +281,7 @@ class ActivationCheckpointingViaTagsTests(

        run(export_compiler)

    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_tags_function(self, device, partition_fn):
    def test_tags_function(self, device):
        def gn(x, y):
            return torch.sigmoid(torch.matmul(x, y))

@@ -304,22 +297,11 @@ class ActivationCheckpointingViaTagsTests(
        bw_compiler = functools.partial(
            count_ops, freq=3, op=torch.ops.aten.mm.default
        ) # mm recomputed in the bwd
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=partition_fn,
        )
        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
        self._validate(fn, backend, x, y)

    @requires_cuda_and_triton
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_tags_function_via_global_checkpoint(self, device, partition_fn):
    def test_tags_function_via_global_checkpoint(self, device):
        def gn(x, y):
            return torch.sigmoid(torch.matmul(x, y))

@@ -334,28 +316,17 @@ class ActivationCheckpointingViaTagsTests(
        bw_compiler = functools.partial(
            count_ops, freq=3, op=torch.ops.aten.mm.default
        ) # mm recomputed in the bwd
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=partition_fn,
        )
        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
        self._validate(fn, backend, x, y)

    @requires_cuda_and_triton
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_tags_function_with_kwargs(self, device, partition_fn):
    def test_tags_function_with_kwargs(self, device):
        def gn(x, y):
            return torch.sigmoid(torch.matmul(x, y))

        def fn(x, y):
            return torch.utils.checkpoint.checkpoint(
                gn, torch.sin(x), y, use_reentrant=False
                gn, torch.sin(x), y, use_reentrant=True, preserve_rng_state=False
            )

        x = torch.randn(4, 4, device=device, requires_grad=True)
@@ -365,22 +336,11 @@ class ActivationCheckpointingViaTagsTests(
        bw_compiler = functools.partial(
            count_ops, freq=3, op=torch.ops.aten.mm.default
        ) # mm recomputed in the bwd
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=partition_fn,
        )
        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
        self._validate(fn, backend, x, y)

    @requires_cuda_and_triton
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_tags_sequential_layers(self, device, partition_fn):
    def test_tags_sequential_layers(self, device):
        def gn(x):
            x = x.cos()
            for _ in range(3):
@@ -401,22 +361,11 @@ class ActivationCheckpointingViaTagsTests(
            freqs=[2, 18],
            ops=[torch.ops.aten.cos.default, torch.ops.aten.mm.default],
        ) # mm recomputed in the bwd
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=partition_fn,
        )
        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
        self._validate(fn, backend, x)

    @requires_cuda_and_triton
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_tags_multiple_checkpoints(self, device, partition_fn):
    def test_tags_multiple_checkpoints(self, device):
        def gn(x, y):
            return torch.sigmoid(torch.matmul(x, y))

@@ -434,22 +383,11 @@ class ActivationCheckpointingViaTagsTests(
        bw_compiler = functools.partial(
            count_ops, freq=6, op=torch.ops.aten.mm.default
        ) # mm recomputed in the bwd
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=partition_fn,
        )
        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
        self._validate(fn, backend, x, y)

    @requires_cuda_and_triton
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_tags_module(self, device, partition_fn):
    def test_tags_module(self, device):
        class MockModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
@@ -473,22 +411,11 @@ class ActivationCheckpointingViaTagsTests(
        bw_compiler = functools.partial(
            count_ops, freq=1, op=torch.ops.aten.sigmoid.default
        )
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=partition_fn,
        )
        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
        self._validate(fn, backend, x)

    @requires_cuda_and_triton
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_tags_decomps(self, device, partition_fn):
    def test_tags_decomps(self, device):
        # Ensures that tags are passed on through decompositions as well
        class MockModule(torch.nn.Module):
            def __init__(self) -> None:
@@ -516,7 +443,6 @@ class ActivationCheckpointingViaTagsTests(
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=partition_fn,
            decompositions=lambda: import_module(
                "torch._inductor.compile_fx"
            ).select_decomp_table(),
@@ -776,14 +702,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no

    @requires_cuda_and_triton
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_compile_selective_checkpoint_must_recompute(self, device, partition_fn):
    def test_compile_selective_checkpoint_must_recompute(self, device):
        def context_fn_must_recompute_mm():
            must_recompute_list = [
                torch.ops.aten.mm.default,
@@ -804,9 +723,9 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
            ),
        )

        def _test(context_fn, bw_compiler, partition_fn):
        def _test(context_fn, bw_compiler):
            def gn(x):
                return torch.cos(torch.sin(torch.matmul(x, x) @ x))
                return torch.sigmoid(torch.matmul(x, x))

            def fn(x):
                return torch.utils.checkpoint.checkpoint(
@@ -820,14 +739,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no

            fw_compiler = functools.partial(
                count_ops,
                freq=2,
                freq=1,
                op=torch.ops.aten.mm.default,
            )

            backend = aot_autograd(
                fw_compiler=fw_compiler,
                bw_compiler=bw_compiler,
                partition_fn=partition_fn,
                partition_fn=min_cut_rematerialization_partition,
            )
            self._validate(fn, backend, x)

@@ -835,19 +754,17 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
            context_fn=context_fn_must_recompute_mm,
            bw_compiler=functools.partial(
                count_ops,
                freq=6, # 1 matmul recompute and 2 bwd mm ops per fwd matmul, so 2 + 2 * 2 = 6)
                freq=3, # 1 matmul recompute and 2 bwd mm ops per fwd matmul, so 1 + 2 * 1 = 3)
                op=torch.ops.aten.mm.default,
            ),
            partition_fn=partition_fn,
        )
        _test(
            context_fn=context_fn_no_recompute_mm,
            bw_compiler=functools.partial(
                count_ops,
                freq=4, # 2 bwd mm ops per fwd matmul
                freq=2, # 2 bwd mm ops per fwd matmul
                op=torch.ops.aten.mm.default,
            ),
            partition_fn=partition_fn,
        )

    def test_sac_with_partial_context_fn(self):
@@ -884,16 +801,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no

    @requires_cuda_and_triton
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_compile_selective_checkpoint_must_not_recompute_gemm(
        self, device, partition_fn
    ):
    def test_compile_selective_checkpoint_must_not_recompute_gemm(self, device):
        def selective_checkpointing_context_fn():
            no_recompute_list = [
                torch.ops.aten.mm.default,
@@ -933,22 +841,15 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=partition_fn,
            partition_fn=min_cut_rematerialization_partition,
        )
        self._validate(fn, backend, x, y)
        self._compare_orig_and_checkpointed_fns(gn, fn, x, y)

    @requires_cuda_and_triton
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_compile_selective_checkpoint_must_not_recompute_gemm_no_functionalization(
        self, device, partition_fn
        self, device
    ):
        def selective_checkpointing_context_fn():
            no_recompute_list = [
@@ -988,7 +889,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=partition_fn,
            partition_fn=min_cut_rematerialization_partition,
            disable_functionalization=True,
        )
        self._validate(fn, backend, x, y)
@@ -996,14 +897,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no

    @requires_cuda_and_triton
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_compile_selective_checkpoint_triton_kernel(self, device, partition_fn):
    def test_compile_selective_checkpoint_triton_kernel(self, device):
        # Copy of the above test, but make sure that having a triton kernel in the
        # region does not error.
        def add_one(x):
@@ -1063,21 +957,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=partition_fn,
            partition_fn=min_cut_rematerialization_partition,
        )
        self._validate(fn, backend, x, y)
        self._compare_orig_and_checkpointed_fns(gn, fn, x, y)

    @requires_cuda_and_triton
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_compile_selective_checkpoint_tensor_subclass(self, device, partition_fn):
    def test_compile_selective_checkpoint_tensor_subclass(self, device):
        def selective_checkpointing_context_fn():
            no_recompute_list = [
                torch.ops.aten.mm.default,
@@ -1120,21 +1007,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=partition_fn,
            partition_fn=min_cut_rematerialization_partition,
        )
        self._validate(fn, backend, x, y)
        self._compare_orig_and_checkpointed_fns(gn, fn, x, y)

    @requires_cuda_and_triton
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_compile_selective_checkpoint_custom_rule(self, device, partition_fn):
    def test_compile_selective_checkpoint_custom_rule(self, device):
        def _get_custom_policy(meta):
            no_recompute_list = [
                torch.ops.aten.mm.default,
@@ -1192,21 +1072,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=partition_fn,
            partition_fn=min_cut_rematerialization_partition,
        )
        self._validate(fn, backend, x, y)
        self._compare_orig_and_checkpointed_fns(gn, fn, x, y)

    @requires_cuda_and_triton
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_compile_selective_checkpoint_partial_ctx_fn(self, device, partition_fn):
    def test_compile_selective_checkpoint_partial_ctx_fn(self, device):
        def selective_checkpointing_context_fn(no_recompute_list):
            return create_selective_checkpoint_contexts(
                _get_custom_policy(no_recompute_list=no_recompute_list)
@@ -1245,21 +1118,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=partition_fn,
            partition_fn=min_cut_rematerialization_partition,
        )
        self._validate(fn, backend, x, y)
        self._compare_orig_and_checkpointed_fns(gn, fn, x, y)

    @requires_cuda_and_triton
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_compile_selective_checkpoint_outplace_op(self, device, partition_fn):
    def test_compile_selective_checkpoint_outplace_op(self, device):
        def selective_checkpointing_context_fn():
            no_recompute_list = [
                torch.ops.aten.mm.default,
@@ -1297,21 +1163,14 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=partition_fn,
            partition_fn=min_cut_rematerialization_partition,
        )
        self._validate(fn, backend, x, y)
        self._compare_orig_and_checkpointed_fns(gn, fn, x, y)

    @requires_cuda_and_triton
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_compile_selective_checkpoint_list_ops(self, device, partition_fn):
    def test_compile_selective_checkpoint_list_ops(self, device):
        def selective_checkpointing_context_fn():
            # recompute everything
            no_recompute_list = []
@@ -1347,7 +1206,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=partition_fn,
            partition_fn=min_cut_rematerialization_partition,
        )
        self._validate(fn, backend, x, y)
        self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@@ -1358,14 +1217,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
        "requires TorchDispatchMode + torch.compile work to complete"
    )
    @requires_cuda_and_triton
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_compile_selective_checkpoint_inplace_op(self, device, partition_fn):
    def test_compile_selective_checkpoint_inplace_op(self, device):
        def selective_checkpointing_context_fn():
            no_recompute_list = [
                torch.ops.aten.mm.default,
@@ -1405,7 +1257,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=partition_fn,
            partition_fn=min_cut_rematerialization_partition,
        )
        self._validate(fn, backend, x, y)
        self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
@@ -1413,14 +1265,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
    @requires_cuda_and_triton
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
    @torch._inductor.config.patch(fallback_random=True)
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_compile_selective_checkpoint_random_op(self, device, partition_fn):
    def test_compile_selective_checkpoint_random_op(self, device):
        for preserve_rng_state in [True, False]:

            def selective_checkpointing_context_fn():
@@ -1467,7 +1312,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
            backend = aot_autograd(
                fw_compiler=fw_compiler,
                bw_compiler=bw_compiler,
                partition_fn=partition_fn,
                partition_fn=min_cut_rematerialization_partition,
            )

            # NOTE: when `preserve_rng_state` is False, gradient will mismatch between torch.compile and eager,
@@ -1479,14 +1324,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no

    @requires_cuda_and_triton
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_compile_selective_checkpoint_invalid_context(self, partition_fn):
    def test_compile_selective_checkpoint_invalid_context(self):
        def gn(x, y):
            return torch.sigmoid(torch.matmul(x, y)) * y

@@ -1515,7 +1353,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=partition_fn,
            partition_fn=min_cut_rematerialization_partition,
        )
        with self.assertRaisesRegex(
            Exception, "must generate a tuple of two `TorchDispatchMode`s"
@@ -1524,14 +1362,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no

    @requires_cuda_and_triton
    @torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)
    @parametrize(
        "partition_fn",
        [
            min_cut_rematerialization_partition,
            default_partition,
        ],
    )
    def test_compile_selective_checkpoint_parametrization(self, partition_fn):
    def test_compile_selective_checkpoint_parametrization(self):
        def sac_policy():
            def _recomp_policy():
                def _custom_policy(ctx, func, *args, **kwargs):
@@ -1594,9 +1425,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
        bw_compiler = functools.partial(
            count_ops,
            freqs=[
                # 1 from mul recompute, 1 from mul backward
                # w/o CSE, we have one extra mul
                3 if partition_fn is default_partition else 2,
                2, # 1 from mul recompute, 1 from mul backward
                1,
            ],
            ops=[torch.ops.aten.mul.Tensor, torch.ops.aten.sigmoid.default],
@@ -1605,7 +1434,7 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=partition_fn,
            partition_fn=min_cut_rematerialization_partition,
        )

        model = MLPModule()

@@ -2363,6 +2363,34 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
        self.assertTrue(same(output, expected))
        assert cnt.frame_count == 1

    @unittest.skipIf(sys.version_info < (3, 13), "math.fma introduced in python 3.13")
    def test_math_fma(self):
        def fma_func(a, b, c):
            return math.fma(a, b, c)

        # Test with scalar constants (constant folding path)
        cnt = torch._dynamo.testing.CompileCounter()
        cfma_scalars = torch._dynamo.optimize_assert(cnt)(fma_func)

        assert cnt.frame_count == 0
        expected = fma_func(2.0, 3.0, 4.0)
        output = cfma_scalars(2.0, 3.0, 4.0)
        self.assertEqual(output, expected)
        assert cnt.frame_count == 0

        # Test with tensors (Inductor path)
        cnt2 = torch._dynamo.testing.CompileCounter()
        cfma_tensors = torch._dynamo.optimize_assert(cnt2)(fma_func)

        assert cnt2.frame_count == 0
        x = torch.tensor(2.0)
        y = torch.tensor(3.0)
        z = torch.tensor(4.0)
        expected_tensors = x * y + z
        output_tensors = cfma_tensors(x, y, z)
        torch.testing.assert_close(output_tensors, expected_tensors)
        assert cnt2.frame_count == 1

    @make_test
    def test_numpy_meshgrid(x, y):
        r1, r2 = np.meshgrid(x.numpy(), y.numpy())

@@ -335,6 +335,59 @@ class <lambda>(torch.nn.Module):
""",
        )

    @requires_cuda
    @requires_multigpu()
    def test_new_event_api(self) -> None:
        from torch._dynamo.graph_bytecode_inputs import get_external_object_by_index
        from torch._dynamo.variables.streams import new_event

        def event_generation_backend(gm, *args, **kwargs):  # type: ignore[no-untyped-def]
            e0_ind = new_event()
            with torch.Stream(device="cuda:1"):
                get_external_object_by_index(e0_ind).record()
            e1_ind = new_event()
            self.assertNotEqual(e0_ind, e1_ind)
            self.assertNotEqual(
                get_external_object_by_index(e0_ind),
                get_external_object_by_index(e1_ind),
            )
            with gm.graph.inserting_after(next(iter(gm.graph.nodes))):
                gm.graph.call_function(
                    get_external_object_by_index, args=(1,), kwargs={}
                )
            return gm

        @torch.compile(backend=event_generation_backend)
        def fn(x):
            return x + 1

        fn(torch.ones(2, 2, device="cuda:0"))

    @requires_cuda
    def test_new_stream_api(self) -> None:
        from torch._dynamo.graph_bytecode_inputs import get_external_object_by_index
        from torch._dynamo.variables.streams import new_stream

        def stream_generation_backend(gm, *args, **kwargs):  # type: ignore[no-untyped-def]
            s0_ind = new_stream()
            s1_ind = new_stream()
            self.assertNotEqual(s0_ind, s1_ind)
            self.assertNotEqual(
                get_external_object_by_index(s0_ind),
                get_external_object_by_index(s1_ind),
            )
            with gm.graph.inserting_after(next(iter(gm.graph.nodes))):
                gm.graph.call_function(
                    get_external_object_by_index, args=(1,), kwargs={}
                )
            return gm

        @torch.compile(backend=stream_generation_backend)
        def fn(x):
            return x + 1

        fn(torch.ones(2, 2, device="cuda:0"))

    @requires_cuda
    def test_stream_with_mutation(self):
        def fn(x, y):
@@ -523,6 +576,23 @@ class <lambda>(torch.nn.Module):
        torch.accelerator.set_stream(original_stream)
        reset_user_object_tracking()

    @requires_cuda
    def test_run_opcheck_wait_record_stream(self):
        from torch._dynamo.variables.streams import wait_stream
        from torch.library import opcheck

        s0 = torch.Stream()
        s1 = torch.Stream()
        s2 = torch.Stream()
        store_user_object_weakrefs(s0, s1, s2)

        sample_inputs = [
            (0, 1),
            (2, 0),
        ]
        for args in sample_inputs:
            opcheck(wait_stream, args)

    @requires_cuda
    def test_inductor_lowering(self):
        with patch("torch._inductor.config.implicit_fallbacks", False):

@@ -331,7 +331,12 @@ class TestDynamismExpression(TestCase):
                return torch.ops.aten.slice.Tensor(*args)

        inp = (torch.rand((10, 3, 224, 224)), 0, 0, 9223372036854775807)
        dynamic_shapes = (({0: Dim("dim")}, None, None, None),)
        dynamic_shapes = (
            {0: Dim("dim")},
            None,
            None,
            None,
        )
        torch.export.export(
            Slice(),
            inp,
@@ -5533,21 +5538,11 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):

        w = Wrapped()

        if is_retracebility_test(self._testMethodName):
            with self.assertRaisesRegex(
                torch._dynamo.exc.UserError,
                "Detected mismatch between the structure of `inputs` and `dynamic_shapes`"
                ": `inputs` has 2 elements, but `dynamic_shapes` has 1 elements",
            ):
                export(w, args, dynamic_shapes={"args": ({0: batch}, {0: batch})})
        else:
            compiled = export(
                w, args, dynamic_shapes={"args": ({0: batch}, {0: batch})}
            )
            expected = w(*args)
            mod = compiled.module()
            got = mod(*args)
            self.assertTrue(torch.allclose(expected, got))
        compiled = export(w, args, dynamic_shapes=({0: batch}, {0: batch}))
        expected = w(*args)
        mod = compiled.module()
        got = mod(*args)
        self.assertTrue(torch.allclose(expected, got))

    def test_dynamic_shapes_builder_basic(self):
        class M(torch.nn.Module):
@@ -17504,6 +17499,105 @@ def forward(self, x):
        exported_param_names = [name for name, _ in gm.named_parameters()]
        self.assertEqual(original_param_names, exported_param_names)

    def test_export_compiled_model_with_nested_dynamic_shapes(self):
        class M(torch.nn.Module):
            def forward(self, data_batch):
                return data_batch["a1"] + data_batch["a2"]

        m = M()
        compiled_m = torch.compile(m)
        example_args = (
            {
                "a1": torch.ones(3, 3),
                "a2": torch.ones(3, 3),
            },
        )
        dynamic_shapes = (
            {
                "a1": {0: Dim.DYNAMIC},
                "a2": {0: Dim.DYNAMIC},
            },
        )
        ep = export(
            compiled_m, example_args, dynamic_shapes=dynamic_shapes, strict=True
        )
        gm = ep.module()
        self.assertEqual(gm(*example_args), compiled_m(*example_args))

    def test_export_model_with_nested_dynamic_shapes(self):
        class M(torch.nn.Module):
            def forward(self, data_batch):
                return data_batch["a1"] + data_batch["a2"]

        m = M()
        example_args = (
            {
                "a1": torch.ones(3, 3),
                "a2": torch.ones(3, 3),
            },
        )
        B = torch.export.Dim("batch", min=1, max=65536)
        dynamic_shapes = (
            {
                "a1": {0: B},
                "a2": {0: B},
            },
        )
        ep = export(m, example_args, dynamic_shapes=dynamic_shapes, strict=True)
        gm = ep.module()
        self.assertEqual(gm(*example_args), m(*example_args))

    def test_export_compiled_model_with_kwargs_dynamic_shapes(self):
        class M(torch.nn.Module):
            def forward(self, a1, a2):
                return a1 + a2

        m = M()
        compiled_m = torch.compile(m)
        example_args = ()
        example_kwargs = {
            "a1": torch.ones(3, 3),
            "a2": torch.ones(3, 3),
        }
        dynamic_shapes = {
            "a1": {0: Dim.DYNAMIC},
            "a2": {0: Dim.DYNAMIC},
        }
        ep = export(
            compiled_m,
            example_args,
            kwargs=example_kwargs,
            dynamic_shapes=dynamic_shapes,
            strict=True,
        )
        gm = ep.module()
        self.assertEqual(gm(**example_kwargs), compiled_m(**example_kwargs))

    def test_export_model_with_kwargs_dynamic_shapes(self):
        class M(torch.nn.Module):
            def forward(self, a1, a2):
                return a1 + a2

        m = M()
        example_args = ()
        example_kwargs = {
            "a1": torch.ones(3, 3),
            "a2": torch.ones(3, 3),
        }
        dynamic_shapes = {
            "a1": {0: Dim.DYNAMIC},
            "a2": {0: Dim.DYNAMIC},
        }
        ep = export(
            m,
            example_args,
            kwargs=example_kwargs,
            dynamic_shapes=dynamic_shapes,
            strict=True,
        )
        gm = ep.module()
        self.assertEqual(gm(**example_kwargs), m(**example_kwargs))


@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
class TestExportCustomClass(TorchTestCase):

@@ -2640,7 +2640,7 @@ def forward(self, primals_1, primals_2):
                return grad_output * x, grad_output * x

        def f(a, b):
            return FwBwMutation.apply(a, b).sin_().clone()
            return FwBwMutation.apply(a, b)

        inps = [
            torch.ones(3, 3, requires_grad=True),
@@ -2689,22 +2689,17 @@ def forward(self, primals_1, primals_2):
    add = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None
    _foreach_mul__1 = torch.ops.aten._foreach_mul_.ScalarList([add], [3]); _foreach_mul__1 = None
    mul = torch.ops.aten.mul.Tensor(add, primals_1); primals_1 = None
    clone = torch.ops.aten.clone.default(mul)
    sin_ = torch.ops.aten.sin_.default(mul); mul = None
    clone_1 = torch.ops.aten.clone.default(sin_); sin_ = None
    return (clone_1, add, clone)""",
    return (mul, add)""",
        )

        # important bit: there is 1 mutation in the bw
        self.assertExpectedInline(
            bw_graph[0].code.strip(),
            """\
def forward(self, add, clone, tangents_1):
    cos = torch.ops.aten.cos.default(clone); clone = None
    mul_1 = torch.ops.aten.mul.Tensor(tangents_1, cos); tangents_1 = cos = None
def forward(self, add, tangents_1):
    _foreach_mul__2 = torch.ops.aten._foreach_mul_.ScalarList([add], [4]); _foreach_mul__2 = None
    mul_2 = torch.ops.aten.mul.Tensor(mul_1, add); mul_1 = add = None
    return (mul_2, None)""",
    mul_1 = torch.ops.aten.mul.Tensor(tangents_1, add); tangents_1 = add = None
    return (mul_1, None)""",
        )

    def test_fw_bw_mutation_no_functionalization2(self):

@@ -927,8 +927,8 @@ class GraphModule(torch.nn.Module):
            op="call_function", target=torch.ops.aten.mm.default
        )
        self.assertEqual(len(mm_nodes), 4)
        self.assertEqual(mm_nodes[0].meta["partitioner_tag"], "is_forward")
        self.assertEqual(mm_nodes[1].meta["partitioner_tag"], "is_forward")
        self.assertNotIn("partitioner_tag", mm_nodes[0].meta)
        self.assertNotIn("partitioner_tag", mm_nodes[1].meta)
        self.assertEqual(mm_nodes[2].meta["partitioner_tag"], "is_backward")
        self.assertEqual(mm_nodes[3].meta["partitioner_tag"], "is_backward")
        self.assertEqual(mm_nodes[0].meta["custom"]["inside_local_map"], 0)

@@ -4101,6 +4101,53 @@ if HAS_CUDA_AND_TRITON:
        compiled_out = compiled_foo(x)
        self.assertEqual(eager_out, compiled_out)

    # Use autotune_at_compile_time=True to test standalone_compile
    @parametrize("autotune_at_compile_time", [True, False])
    @config.patch("graph_partition", True)
    def test_graph_partition_kernel_reuse(self, autotune_at_compile_time):
        def foo(x):
            # partition 1
            x1 = x @ x
            y1 = x1 + 1
            z_cpu = y1.cpu() + 1
            # partition 2
            # partition 2 should reuse the fused triton kernel generated
            # in partition 1
            x2 = z_cpu.to("cuda") @ z_cpu.to("cuda")
            y2 = x2 + 1
            return y1, y2

        with config.patch(
            "triton.autotune_at_compile_time", autotune_at_compile_time
        ):
            compiled_foo = torch.compile(foo)
            x = torch.randn((20, 20), device="cuda")
            eager_out = foo(x)
            compiled_out, code = run_and_get_code(compiled_foo, x)
            self.assertEqual(eager_out, compiled_out)

            if autotune_at_compile_time:
                # auto-tuning block should only appear once. We generate auto-tuning code
                # for all the kernels no matter if they are defined in the main graph or
                # subgraph, to avoid the overhead of executing multiple auto-tuning code blocks.
                FileCheck().check_count(
                    "Compile-time auto-tuning block", 1, exactly=True
                ).run(code[0])
                # triton_poi_fused_add_ should appear twice, first in the auto-tuning block,
                # and then in the main code block
                FileCheck().check_count(
                    "def triton_poi_fused_add_", 2, exactly=True
                ).run(code[0])
                # cpu kernel definition should only appence once, not in the auto-tuning block
                FileCheck().check_count(
                    "cpp_fused__to_copy_add_1 = ", 1, exactly=True
                ).run(code[0])
            else:
                # triton_poi_fused_add_ should appear once, because of kernel reuse
                FileCheck().check_count(
                    "def triton_poi_fused_add_", 1, exactly=True
                ).run(code[0])

    def test_meta_tensor(self):
        def foobar(x, y):
            return x * 2, y * 3

@@ -4,8 +4,9 @@ from functools import partial
from unittest import skipIf

import torch
from torch._inductor import config
from torch._inductor.ir import Pointwise
from torch._inductor.lowering import make_pointwise, register_lowering
from torch._inductor.lowering import make_fallback, make_pointwise, register_lowering
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._inductor.virtualized import ops
from torch.testing._internal.common_utils import skipIfRocm, skipIfXpu
@@ -237,6 +238,17 @@ class TestCustomLowering(InductorTestCase):
        out2 = fn_opt(a, b)
        self.assertEqual(out1, out2)

    @config.patch(joint_graph_constant_folding=False)
    def test_constant_creation(self):
        class M(torch.nn.Module):
            def forward(self, x):
                return x + torch.tensor(1)

        make_fallback(torch.ops.aten.lift_fresh_copy.default)
        self.assertTrue(
            torch.allclose(torch.compile(M())(torch.ones(3)), torch.ones(3) + 1)
        )


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

@@ -492,6 +492,36 @@ class PackedSequenceTest(TestCase):
                torch.randn([0, 1, 10]), torch.randn([11, 14, 14, 2]), True
            )

    def test_empty_packed_sequence(self):
        """
        Regression test for https://github.com/pytorch/pytorch/issues/149622
        Tests that pad_packed_sequence and unpack_sequence handle empty tensors
        without segmentation fault (CVE-2025-2998, CVE-2025-2999)
        """
        # Test case 1: pad_packed_sequence with empty tensors
        # Previously caused segmentation fault
        empty_data = torch.randn(0, 5)
        empty_batch_sizes = torch.tensor([], dtype=torch.int64)
        empty_packed = rnn_utils.PackedSequence(
            empty_data, empty_batch_sizes, None, None
        )

        # Should not crash - either return empty result or raise informative error
        with self.assertRaises(RuntimeError):
            rnn_utils.pad_packed_sequence(empty_packed, batch_first=True)

        # Test case 2: unpack_sequence with empty tensors
        # Previously caused segmentation fault
        empty_data = torch.tensor([])
        empty_batch_sizes = torch.tensor([], dtype=torch.int64)
        packed = rnn_utils.PackedSequence(
            data=empty_data, batch_sizes=empty_batch_sizes
        )

        # Should not crash - either return empty list or raise informative error
        with self.assertRaises(RuntimeError):
            rnn_utils.unpack_sequence(packed)


if __name__ == "__main__":
    run_tests()

@@ -2320,6 +2320,8 @@ if sys.version_info >= (3, 11):
    torch_c_binding_in_graph_functions["math.exp2"] = TorchInGraphFunctionVariable
    torch_c_binding_in_graph_functions["math.cbrt"] = TorchInGraphFunctionVariable

if sys.version_info >= (3, 13):
    torch_c_binding_in_graph_functions["math.fma"] = TorchInGraphFunctionVariable

# In graph functions (including constant folding) that are not C bindings
# NOTE: [Cacheability of in-graph torch functions]

@@ -10,7 +10,10 @@ from torch.fx import has_side_effect, Proxy
from .. import graph_break_hints
from ..bytecode_transformation import create_call_function
from ..exc import TYPE_CHECKING, unimplemented
from ..graph_bytecode_inputs import get_external_object_by_index
from ..graph_bytecode_inputs import (
    get_external_object_by_index,
    register_graph_created_object,
)
from .base import VariableTracker
from .constant import ConstantVariable
from .ctx_manager import FxTracebackAnnotateVariable
@@ -28,6 +31,26 @@ from torch._library.custom_ops import custom_op
Tensor = torch.Tensor


def new_event(*args: Any, **kwargs: Any) -> int:
    event = torch.Event(*args, **kwargs)
    return register_graph_created_object(
        event,
        EventVariable.make_construct_in_graph_event_fn(
            TupleVariable([]), ConstDictVariable({})
        ),
    )


def new_stream(*args: tuple[Any], **kwargs: Any) -> int:
    stream = torch.Stream(*args, **kwargs)  # type: ignore[no-matching-overload,call-overload]
    return register_graph_created_object(
        stream,
        StreamVariable.make_construct_in_graph_stream_fn(
            TupleVariable([]), ConstDictVariable({})
        ),
    )


def _get_stream_by_index(index: int) -> torch.Stream:
    stream = get_external_object_by_index(index)
    assert isinstance(stream, torch.Stream), (
@@ -115,6 +138,24 @@ def _(
has_side_effect(torch.ops.streams.wait_event.default)


@custom_op("streams::wait_stream", mutates_args=())
def wait_stream(waiting_stream_index: int, waited_on_stream_index: int) -> None:
    waiting = _get_stream_by_index(waiting_stream_index)
    waited_on = _get_stream_by_index(waited_on_stream_index)
    waiting.wait_stream(waited_on)


@wait_stream.register_fake
def _(
    event_index: int,
    stream_index: int,
) -> None:
    pass


has_side_effect(torch.ops.streams.wait_stream.default)


class SymbolicStreamState:
    """Track the currently entered stream if any"""

@@ -603,6 +603,21 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
                VariableTracker.build(tx, polyfills.radians), args, kwargs
            )

        if hasattr(math, "fma"):  # Python 3.13+

            @register(math.fma)
            def handle_fma(self, tx: "InstructionTranslator", *args, **kwargs):
                if len(args) != 3 or kwargs:
                    return None

                if all(isinstance(arg, variables.TensorVariable) for arg in args):
                    x, y, z = args
                    addcmul_fn = TorchInGraphFunctionVariable(torch.addcmul)
                    return addcmul_fn.call_function(tx, [z, x, y], {})

                # Use math.fma if constants
                return None

        @register(torch.is_inference_mode_enabled)
        def handle_is_inference_mode_enabled(self, tx: "InstructionTranslator"):
            unimplemented(

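The handler above rewrites a tensor-valued `math.fma(x, y, z)` call into `torch.addcmul(z, x, y)`; a quick sanity sketch of that equivalence (the values are arbitrary examples):

```python
import torch

x, y, z = torch.tensor(2.0), torch.tensor(3.0), torch.tensor(4.0)
# addcmul(input, t1, t2) computes input + t1 * t2, i.e. x * y + z here.
assert torch.allclose(torch.addcmul(z, x, y), x * y + z)  # both are 10.0
```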
@@ -27,7 +27,6 @@ from torch._guards import detect_fake_mode
from torch._prims_common import CUDARngStateHelper
from torch.fx.experimental.proxy_tensor import (
    _proxy_tensor_disable_update_tensor_tracker,
    get_proxy_mode,
    maybe_disable_thunkify,
    maybe_enable_thunkify,
)
@@ -296,10 +295,6 @@ def create_joint(
            (outs, tangent_mask), (outs_descs, _) = call_and_expect_output_descs(
                fn, primals
            )
            mode = get_proxy_mode()
            assert mode is not None
            for node in mode.tracer.graph.nodes:
                node.meta["partitioner_tag"] = "is_forward"

            # TODO: I think this hook can also be eliminated now
            if joint_fn_handle and joint_fn_handle.post_forward:

@ -51,7 +51,6 @@ from ._activation_checkpointing.knapsack import (
|
||||
)
|
||||
from ._activation_checkpointing.knapsack_evaluator import KnapsackEvaluator
|
||||
from ._aot_autograd.descriptors import AOTOutput, SavedForBackwardsAOTOutput
|
||||
from ._aot_autograd.functional_utils import assert_functional_graph
|
||||
from ._aot_autograd.logging_utils import get_aot_graph_name
|
||||
from ._aot_autograd.utils import get_cuda_generator_meta_val, is_with_effects
|
||||
from .compile_utils import fx_graph_cse, get_aten_target, raise_getitems
|
||||
@ -298,10 +297,6 @@ def _has_tag_is_backward(node: fx.Node) -> bool:
|
||||
return node.meta.get("partitioner_tag", None) == "is_backward"
|
||||
|
||||
|
||||
def _has_tag_is_forward(node: fx.Node) -> bool:
|
||||
return node.meta.get("partitioner_tag", None) == "is_forward"
|
||||
|
||||
|
||||
def _has_tag_must_be_in_forward(node: fx.Node) -> bool:
|
||||
return node.meta.get("partitioner_tag", None) == "must_be_in_forward"
|
||||
|
||||
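Aside (illustrative, not part of the diff): the partitioner-tag helpers above are plain lookups into fx node metadata. A self-contained sketch of tagging and filtering nodes the same way on a toy traced graph (the function f is an invented example):

import torch
import torch.fx as fx


def f(x):
    return x.sin().cos()


gm = fx.symbolic_trace(f)

# Tag every node the way the joint graph is tagged, then filter with the same
# meta lookup that _has_tag_is_forward performs.
for node in gm.graph.nodes:
    node.meta["partitioner_tag"] = "is_forward"

forward_nodes = [
    n for n in gm.graph.nodes if n.meta.get("partitioner_tag", None) == "is_forward"
]
print([n.name for n in forward_nodes])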
@ -1026,95 +1021,105 @@ def default_partition(
    Returns:
        Returns the generated forward and backward Fx graph modules.
    """
    # Respect the original placement of ops rather than rely on dataflow.
    forward_nodes = []
    last_node = None
    for node in joint_module.graph.nodes:
        if _has_tag_is_forward(node) or _is_primal(node) or _is_fwd_seed_offset(node):
            last_node = node
    assert last_node is not None
    for node in joint_module.graph.nodes:
        if not _is_tangent(node):
            forward_nodes.append(node)
        if node is last_node:
            break
    if has_recomputable_ops(joint_module):
        return min_cut_rematerialization_partition(
            joint_module,
            _joint_inputs,
            num_fwd_outputs=num_fwd_outputs,
            static_lifetime_input_indices=static_lifetime_input_indices,
        )
    primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
    fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes))
    inputs = primal_inputs + fwd_seed_offset_inputs
    fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = (
        _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
    )
    forward_only_graph = _extract_graph_with_inputs_outputs(
        joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward"
    )
    forward_node_names = OrderedSet(
        node.name for node in forward_nodes if node.op != "output"
        node.name for node in forward_only_graph.nodes if node.op != "output"
    )
    graph_has_recomputable_ops = has_recomputable_ops(joint_module)
    graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module)
    if graph_has_recomputable_ops:
        assert_functional_graph(joint_module.graph)
        joint_module = cleanup_recompute_tags(joint_module, is_default_partition=True)

        if not config.unsafe_allow_optimization_of_collectives:
            force_save_collectives(joint_module)

        force_save_bw_mutation_src(joint_module)

    if static_lifetime_input_indices is None:
        static_lifetime_input_indices = []
    node_info = classify_nodes(
        joint_module, static_lifetime_input_indices, num_fwd_outputs
    )

    order = {node: idx for idx, node in enumerate(joint_module.graph.nodes)}
    saved_values = []
    saved_sym_nodes = []

    def is_mutated_later_in_fw(node):
        if _has_tag_is_backward(node):
            return False
        tensor_arg_aliases = [
            x
            for x in node.args
            if isinstance(x, fx.Node)
            and "val" in x.meta
            and isinstance(x.meta["val"], torch.Tensor)
        ]
        while len(tensor_arg_aliases) > 0:
            a = tensor_arg_aliases.pop()
            for u in a.users:
                if not isinstance(u.target, torch._ops.OpOverload):
                    continue
                # If we witness a mutation on our node later, and that mutation is not "must be in backward",
                # then our node needs to be computed in the forward (otherwise we will compute it on the mutated values)
                if (
                    # one of the args was mutated
                    u.target._schema.is_mutable
                    # and the mutation happens "later"
                    and order[u] > order[node]
                    # and the mutation happened during the forward
                    and not (_has_tag_is_backward(u) or _has_tag_must_be_in_backward(u))
                ):
                    for idx, alias_info in enumerate(u.target._schema.arguments):
                        if alias_info.is_write and u.args[idx] is a:
                            return True
                elif u.target.is_view:
                    tensor_arg_aliases.append(u)
        return False

    for node in joint_module.graph.nodes:
        if node.name not in forward_node_names:
            # if a node isn't "required" to be in the forward, but any of its arguments
            # are later mutated in the forward, then it must have been run in the forward
            # (if not, and the node's arg was saved for backward, we would have mutated a saved value)
            # NB: doesn't handle nodes where the input is a list of tensors and one of those tensors is later mutated
            if is_mutated_later_in_fw(node):
                saved_values.append(node)
            continue
        if is_sym_node(node):
            # Symints must be kept separate from tensors so that PythonFunction only calls
            # save_for_backward on tensors and stashes symints in autograd .ctx
            saved_sym_nodes.append(node)
            continue
        if node.meta.get("recompute") == CheckpointPolicy.MUST_SAVE:
            saved_values.append(node)
            continue
        if node.is_impure(impure_random=False) and node.op not in (
            "placeholder",
            "output",
        ):
            # See is_impure in torch/fx/node.py
            assert not graph_has_recomputable_ops, (
                "Trying to apply AC on a graph with impure op",
                node,
                node.target,
            )
            saved_values.append(node)
            continue
        backward_usages = [n for n in node.users if n.name not in forward_node_names]
        if "tensor_meta" in node.meta and all(is_sym_node(n) for n in backward_usages):
            # If we have a tensor in the forward, where only its sizes/strides are needed in the backward,
            # and not the actual tensor data,
            # then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor.
            #
            # Note that saving the tensor could also cause compilation problems:
            # If the user mutated an input in the forward and uses its sizes/strides in the backward,
            # then we would be obligated to clone the input before saving it to appease autograd.
            # (This is how we originally found this bug).
            saved_sym_nodes.extend(backward_usages)
            continue
        if (
        elif (
            "tensor_meta" not in node.meta
            and node.op == "call_function"
            and not isinstance(node.meta.get("val"), torch._subclasses.FakeTensor)
        ):
            assert all(user.target == operator.getitem for user in node.users)
            continue
        if not must_recompute(node):
            saved_values.append(node)

            # Since we can't save tuple of tensor values, we need to flatten out what we're saving
            users = node.users
            assert all(user.target is operator.getitem for user in users)
            saved_values.extend(users)
        else:
            backward_usages = [
                n for n in node.users if n.name not in forward_node_names
            ]
            if "tensor_meta" in node.meta and all(
                is_sym_node(n) for n in backward_usages
            ):
                # If we have a tensor in the forward, where only its sizes/strides are needed in the backward,
                # and not the actual tensor data,
                # then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor.
                #
                # Note that saving the tensor could also cause compilation problems:
                # If the user mutated an input in the forward and uses its sizes/strides in the backward,
                # then we would be obligated to clone the input before saving it to appease autograd.
                # (This is how we originally found this bug).
                saved_sym_nodes.extend(backward_usages)
            else:
                saved_values.append(node)
    saved_values = list(dict.fromkeys(saved_values).keys())
    saved_sym_nodes = list(dict.fromkeys(saved_sym_nodes).keys())

    if config._sync_decision_cross_ranks:
        saved_values = _sync_decision_cross_ranks(joint_module.graph, saved_values)

    if static_lifetime_input_nodes is None:
        static_lifetime_input_nodes = node_info.static_lifetime_input_nodes
    fw_module, bw_module = _extract_fwd_bwd_modules(
    return _extract_fwd_bwd_modules(
        joint_module,
        saved_values,
        saved_sym_nodes=saved_sym_nodes,
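Aside (illustrative, not part of the diff): the list(dict.fromkeys(...)) idiom used above deduplicates while preserving first-seen order, which matters because the saved values become positional outputs of the forward graph. A quick demonstration with made-up node names:

saved = ["primals_1", "sin", "primals_1", "cos", "sin"]
# dict keys are unique and, since Python 3.7, keep insertion order,
# so this drops duplicates without reshuffling the survivors.
deduped = list(dict.fromkeys(saved).keys())
print(deduped)  # ['primals_1', 'sin', 'cos']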
@ -1122,24 +1127,6 @@ def default_partition(
        static_lifetime_input_nodes=static_lifetime_input_nodes,
    )

    if graph_has_recomputable_ops:
        if graph_has_recomputable_rng_ops:
            fw_module, bw_module = functionalize_rng_ops(
                joint_module, fw_module, bw_module, len(saved_sym_nodes)
            )
        bw_module = reordering_to_mimic_autograd_engine(bw_module)

    # raise all getitem ops to as early as possible
    # this is helpful for memory, especially in the case of aot_eager backend
    fw_module = raise_getitems(fw_module)
    bw_module = raise_getitems(bw_module)

    fw_module = thread_graphsafe_rng_from_hops(fw_module, is_backward=False)
    if len(node_info.required_bw_nodes) > 0:
        bw_module = thread_graphsafe_rng_from_hops(bw_module, is_backward=True)

    return fw_module, bw_module


INT_INF = int(1e6)

@ -1634,9 +1621,7 @@ def force_save_bw_mutation_src(joint_module: fx.GraphModule) -> None:
                break


def cleanup_recompute_tags(
    joint_module: fx.GraphModule, *, is_default_partition: bool
) -> fx.GraphModule:
def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule:
    """
    If there are two consecutive checkpointed blocks with no operator in
    between, we would still want to stash the tensor at the boundary of
@ -1673,16 +1658,6 @@ def cleanup_recompute_tags(
            # Solution: check whether `out` has a backward hook, and if so, intentionally save `out`
            # in forward graph outputs. With this, we can break the above circular dependency.
            node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
        elif (
            "ac_graph_id" not in node.meta
            and any(must_recompute(user) for user in node.users)
            and is_default_partition
        ):
            # This node is not part of the AC region and a user is marked as recompute.
            # This means it's an input to the AC region and we should save it.
            # For ease of landing, gate this to default partitioner only, but we should think
            # about flipping the switch in general as well.
            node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
    return joint_module

@ -2790,59 +2765,6 @@ def thread_graphsafe_rng_from_hops(module, is_backward):
    return module


def classify_nodes(joint_module, static_lifetime_input_indices, num_fwd_outputs):
    name_to_node = get_name_to_node(joint_module.graph)
    required_bw_nodes: OrderedSet[fx.Node] = OrderedSet()
    for node in joint_module.graph.nodes:
        if node.op == "placeholder" and "tangents" in node.target:
            required_bw_nodes.add(node)
        elif _must_be_in_backward(node):
            required_bw_nodes.add(node)

        if node in required_bw_nodes:
            required_bw_nodes.update(node.users)

    primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
    fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes))
    inputs = primal_inputs + fwd_seed_offset_inputs
    fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = (
        _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
    )
    required_bw_nodes.update(
        o for o in bwd_outputs if o is not None and o.op != "output"
    )
    forward_only_graph = _extract_graph_with_inputs_outputs(
        joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward"
    )
    required_fw_nodes: OrderedSet[fx.Node] = OrderedSet(
        name_to_node[node.name]
        for node in forward_only_graph.nodes
        if node.op != "output"
    )
    unclaimed_nodes: OrderedSet[fx.Node] = OrderedSet(
        node
        for node in joint_module.graph.nodes
        if node not in required_fw_nodes and node not in required_bw_nodes
    )
    static_lifetime_input_nodes = OrderedSet(
        p for i, p in enumerate(primal_inputs) if i in static_lifetime_input_indices
    )
    fw_cnt = 0
    fw_order = {}
    for node in joint_module.graph.nodes:
        if node in required_fw_nodes:
            fw_order[node] = fw_cnt
            fw_cnt += 1
    return NodeInfo(
        inputs,
        required_fw_nodes,
        required_bw_nodes,
        unclaimed_nodes,
        fw_order,
        static_lifetime_input_nodes,
    )


def min_cut_rematerialization_partition(
    joint_module: fx.GraphModule,
    _joint_inputs,
@ -2891,16 +2813,68 @@ def min_cut_rematerialization_partition(
    graph_has_recomputable_ops = has_recomputable_ops(joint_module)
    graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module)
    if graph_has_recomputable_ops:
        joint_module = cleanup_recompute_tags(joint_module, is_default_partition=False)
        joint_module = cleanup_recompute_tags(joint_module)
        if not config.unsafe_allow_optimization_of_collectives:
            force_save_collectives(joint_module)
        force_save_bw_mutation_src(joint_module)

    def classify_nodes(joint_module, static_lifetime_input_indices):
        name_to_node = get_name_to_node(joint_module.graph)
        required_bw_nodes: OrderedSet[fx.Node] = OrderedSet()
        for node in joint_module.graph.nodes:
            if node.op == "placeholder" and "tangents" in node.target:
                required_bw_nodes.add(node)
            elif _must_be_in_backward(node):
                required_bw_nodes.add(node)

            if node in required_bw_nodes:
                required_bw_nodes.update(node.users)

        primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
        fwd_seed_offset_inputs = list(
            filter(_is_fwd_seed_offset, joint_module.graph.nodes)
        )
        inputs = primal_inputs + fwd_seed_offset_inputs
        fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = (
            _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
        )
        required_bw_nodes.update(
            o for o in bwd_outputs if o is not None and o.op != "output"
        )
        forward_only_graph = _extract_graph_with_inputs_outputs(
            joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward"
        )
        required_fw_nodes: OrderedSet[fx.Node] = OrderedSet(
            name_to_node[node.name]
            for node in forward_only_graph.nodes
            if node.op != "output"
        )
        unclaimed_nodes: OrderedSet[fx.Node] = OrderedSet(
            node
            for node in joint_module.graph.nodes
            if node not in required_fw_nodes and node not in required_bw_nodes
        )
        static_lifetime_input_nodes = OrderedSet(
            p for i, p in enumerate(primal_inputs) if i in static_lifetime_input_indices
        )
        fw_cnt = 0
        fw_order = {}
        for node in joint_module.graph.nodes:
            if node in required_fw_nodes:
                fw_order[node] = fw_cnt
                fw_cnt += 1
        return NodeInfo(
            inputs,
            required_fw_nodes,
            required_bw_nodes,
            unclaimed_nodes,
            fw_order,
            static_lifetime_input_nodes,
        )

    if static_lifetime_input_indices is None:
        static_lifetime_input_indices = []
    node_info = classify_nodes(
        joint_module, static_lifetime_input_indices, num_fwd_outputs
    )
    node_info = classify_nodes(joint_module, static_lifetime_input_indices)

    # networkx blows up on graphs with no required backward nodes
    # Since there's nothing to partition anyway, and the default partitioner can "handle"

@ -627,7 +627,7 @@ class ComboKernel(Kernel):
        if heuristics == "foreach":
            heuristics_line = f"""
                @triton_heuristics.foreach(
                    filename=__file__,
                    num_warps={self.num_warps},
                    triton_meta={triton_meta!r},
                    inductor_meta={inductor_meta!r},
                )

@ -2259,7 +2259,7 @@ class PythonWrapperCodegen(CodeGen):
        gpu: bool = True,
        cpp_definition: Optional[str] = None,
    ):
        if config.triton.autotune_at_compile_time:
        if config.triton.autotune_at_compile_time and gpu:
            body = self._format_kernel_definition(
                kernel_name, kernel_body, metadata=metadata
            )
@ -3745,6 +3745,13 @@ class SubgraphPythonWrapperCodegen(PythonWrapperCodegen):

        super().__init__()

        root = self.get_root_graph()
        # Only generate auto-tuning block in the main graph
        self.kernel_autotune_defs = root.kernel_autotune_defs
        self.kernel_autotune_calls = root.kernel_autotune_calls
        # Only store kernel src to name mapping in the main graph
        self.src_to_kernel = root.src_to_kernel

    def set_launcher_fn_name(self) -> None:
        # This sets up the name of the function containing the launcher code of
        # the subgraph.
@ -3837,3 +3844,16 @@ class SubgraphPythonWrapperCodegen(PythonWrapperCodegen):
        # V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")
        # )
        self.parent_wrapper.write_get_raw_stream_header_once()

    @cache_on_self
    def get_root_graph(self) -> PythonWrapperCodegen:
        root: PythonWrapperCodegen | SubgraphPythonWrapperCodegen = self
        while isinstance(root, SubgraphPythonWrapperCodegen):
            root = root.parent_wrapper

        assert isinstance(root, PythonWrapperCodegen)
        return root

    def generate_and_run_autotune_block(self):
        # Only execute auto-tuning block in the main graph
        pass

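Aside (illustrative, not part of the diff): get_root_graph above simply follows parent_wrapper links until it leaves the subgraph wrappers. The same walk-to-the-root pattern in a self-contained toy form (Root and Child are invented stand-ins):

class Root:
    pass


class Child:
    def __init__(self, parent):
        self.parent = parent


def get_root(node):
    # Follow parent links until we reach an object that is not a wrapper.
    while isinstance(node, Child):
        node = node.parent
    assert isinstance(node, Root)
    return node


root = Root()
leaf = Child(Child(root))
print(get_root(leaf) is root)  # True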
@ -64,6 +64,7 @@ from torch.fx.experimental.symbolic_shapes import (
)
from torch.fx.node import Node
from torch.utils._ordered_set import OrderedSet
from torch.utils._python_dispatch import _disable_current_modes
from torch.utils._sympy.functions import CleanDiv, FloorDiv, Mod, ModularIndexing
from torch.utils._sympy.symbol import SymT

@ -6135,9 +6136,12 @@ class ExternKernel(InputsKernel):
        if isinstance(x, (Expr, sympy.logic.boolalg.Boolean, int)):
            return ShapeAsConstantBuffer(expr=x)
        if isinstance(x, Constant):
            return V.graph.add_tensor_constant(
                torch.tensor(x.value, dtype=x.get_dtype(), device=x.get_device())
            )
            # We need to unset fake mode, or else the torch.tensor() call will
            # turn into a FakeTensor
            with _disable_current_modes():
                return V.graph.add_tensor_constant(
                    torch.tensor(x.value, dtype=x.get_dtype(), device=x.get_device())
                )
        if isinstance(x, ConstantBuffer):
            return x
        if isinstance(x, TensorBox):

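Aside (illustrative, not part of the diff): the _disable_current_modes() guard matters because a torch.tensor(...) call made while a FakeTensorMode is active produces a FakeTensor rather than a real constant. A minimal sketch of the difference, assuming a recent PyTorch build:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.utils._python_dispatch import _disable_current_modes

with FakeTensorMode():
    fake = torch.tensor(1.0)  # created under fake mode -> FakeTensor
    with _disable_current_modes():
        real = torch.tensor(1.0)  # fake mode suspended -> real tensor

print(type(fake).__name__)  # FakeTensor
print(type(real).__name__)  # Tensor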
@ -3607,24 +3607,13 @@ def user_autotune(
    )


def foreach(triton_meta, filename=None, inductor_meta=None):
def foreach(triton_meta, num_warps, filename=None, inductor_meta=None):
    """
    Compile a triton foreach kernel
    """
    configs = []

    # Naive autotuning path for num_warps
    if not (
        inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise")
    ):
        configs.append(triton.Config({}, num_stages=1, num_warps=8))
    else:
        for warps in [1, 2, 4, 8]:
            configs.append(triton.Config({}, num_stages=1, num_warps=warps))

    return cached_autotune(
        None,
        configs,
        [triton.Config({}, num_stages=1, num_warps=num_warps)],
        triton_meta=triton_meta,
        inductor_meta=inductor_meta,
        heuristic_type=HeuristicType.TEMPLATE,

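Aside (illustrative, not part of the diff; assumes triton is installed): the change above replaces the config sweep built inside foreach() with a single config for the caller-provided num_warps (passed through the generated @triton_heuristics.foreach(...) line). The two config lists side by side, as a sketch:

import triton

# Old behavior: either a fixed num_warps=8 config or a small sweep when
# max-autotune is enabled.
sweep = [triton.Config({}, num_stages=1, num_warps=w) for w in (1, 2, 4, 8)]

# New behavior: exactly one config, with num_warps chosen by the caller.
num_warps = 8
single = [triton.Config({}, num_stages=1, num_warps=num_warps)]

print(len(sweep), len(single))  # 4 1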
@ -702,7 +702,7 @@ def exp2(a):
# CompositeImplicitAutograd - don't register decomp
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a,"),
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
)
def fill(a: TensorLikeType, value: NumberType) -> TensorLikeType:

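Aside (illustrative, not part of the diff): the one-character fix above matters because ("a,") is just a parenthesized string, while ("a",) is a one-element tuple; only the latter is a valid sequence of argument names for the wrapper. Quick check:

print(type(("a,")))  # <class 'str'>  -> a single string "a,"
print(type(("a",)))  # <class 'tuple'> -> the intended 1-tuple of names
print(list(("a,")))  # ['a', ','] -> iterating the string yields characters
print(list(("a",)))  # ['a']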
@ -1,6 +1,5 @@
#pragma once

#include <c10/util/Exception.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/stable/c/shim.h>
#include <torch/csrc/stable/device_struct.h>
@ -120,7 +119,7 @@ struct FromImpl<ScalarType> {
      case ScalarType::UInt64:
        return from(aoti_torch_dtype_uint64());
      default:
        TORCH_CHECK(
        STD_TORCH_CHECK(
            false,
            "Not yet supported ScalarType, please file an issue describing your use case.");
    }
@ -151,7 +150,7 @@ struct FromImpl<DeviceType> {
      case DeviceType::PrivateUse1:
        return from(aoti_torch_device_type_privateuse1());
      default:
        TORCH_CHECK(
        STD_TORCH_CHECK(
            false,
            "Not yet supported DeviceType, please file an issue describing your use case.");
    }
@ -379,7 +378,7 @@ struct ToImpl<ScalarType> {
    } else if (shim_scalartype == aoti_torch_dtype_uint64()) {
      return ScalarType::UInt64;
    } else {
      TORCH_CHECK(
      STD_TORCH_CHECK(
          false,
          "Not yet supported ScalarType ",
          std::to_string(shim_scalartype),
@ -409,7 +408,7 @@ struct ToImpl<DeviceType> {
    } else if (shim_devicetype == aoti_torch_device_type_privateuse1()) {
      return DeviceType::PrivateUse1;
    } else {
      TORCH_CHECK(
      STD_TORCH_CHECK(
          false,
          "Not yet supported DeviceType ",
          std::to_string(shim_devicetype),

@ -2,7 +2,7 @@ from collections.abc import Callable
from copy import deepcopy
from enum import auto, Enum
from functools import partial, wraps
from typing import Any, NamedTuple, Optional, TypeVar, Union
from typing import Any, NamedTuple, Optional, TYPE_CHECKING, TypeVar, Union
from typing_extensions import ParamSpec, TypeVarTuple, Unpack

import torch
@ -17,6 +17,9 @@ from torch.utils._pytree import tree_map_only
from torch.utils.weak import WeakIdKeyDictionary, weakref


if TYPE_CHECKING:
    from torch.utils.hooks import RemovableHandle

_TOTAL_KEY = "Total"

__all__ = ["FSDPMemTracker"]
@ -365,14 +368,28 @@ class FSDPMemTracker(MemTracker):
        # `FSDPParamGroup.post_forward` because during AC these won't be called.
        # TODO(@sanketpurandare): This will need to be modified after this PR (https://github.com/pytorch/pytorch/pull/127786)
        # lands. For backward we monkey-patch the `FSDPParamGroup.pre_backward` and `FSDPParamGroup.post_backward`.
        # pyrefly: ignore [missing-attribute]

        # get the unique _MultiHandlers/RemoveHandlers and store in dictionary
        # the _MultiHandlers object will only need to be grabbed once.
        unique_handlers: dict[RemovableHandle, bool] = {}
        # pyrefly: ignore # missing-attribute
        for module in self._root_mod.modules():
            if isinstance(module, FSDPModule):
                fsdp_state = module._get_fsdp_state()
                if fsdp_param_group := fsdp_state._fsdp_param_group:
                    if not unique_handlers.get(fsdp_state._pre_forward_hook_handle):
                        unique_handlers[fsdp_state._pre_forward_hook_handle] = True
                    if not unique_handlers.get(fsdp_state._post_forward_hook_handle):
                        unique_handlers[fsdp_state._post_forward_hook_handle] = True
        # call remove on the handles once
        for f_hook_handle in unique_handlers.keys():
            f_hook_handle.remove()
        # pyrefly: ignore # missing-attribute
        for module in self._root_mod.modules():
            if isinstance(module, FSDPModule):
                fsdp_state = module._get_fsdp_state()
                if fsdp_param_group := fsdp_state._fsdp_param_group:
                    self._instrument_fsdp_sharded_params_grads(fsdp_param_group)
                    fsdp_state._pre_forward_hook_handle.remove()
                    fsdp_state._post_forward_hook_handle.remove()
                    fsdp_state._pre_forward_hook_handle = (
                        # pyrefly: ignore [missing-attribute]
                        module.register_forward_pre_hook(

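Aside (illustrative, not part of the diff): the unique_handlers dictionary above exists so that a hook handle shared by several FSDP modules is removed only once. A self-contained sketch of the same dedup-then-remove pattern with ordinary nn.Module hooks (the model and hook are invented examples):

import torch.nn as nn

model = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
shared_handle = model.register_forward_pre_hook(lambda mod, args: None)

# Pretend several wrappers all hand us the same RemovableHandle; keying a dict
# by the handle keeps one entry per distinct handle.
handles = [shared_handle, shared_handle, shared_handle]
unique = {}
for h in handles:
    if not unique.get(h):
        unique[h] = True

for h in unique:  # each distinct handle is removed exactly once
    h.remove()
print(len(unique))  # 1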
@ -194,6 +194,10 @@ else:
            _rank_map: Optional[torch.Tensor] = None,
            _root_mesh: Optional["DeviceMesh"] = None,
        ) -> None:
            # no-op in OSS, logs API usage metrics in meta-internal runs
            torch._C._log_api_usage_once(
                "torch.distributed.device_mesh.DeviceMesh.__init__"
            )
            if mesh is not None:
                if _layout is not None or _rank_map is not None:
                    raise TypeError(

@ -359,6 +359,10 @@ class ShardingPropagator:
        """
        Propagate the sharding for an operator given the op_schema.
        """
        # no-op in OSS, logs API usage metrics in meta-internal runs
        torch._C._log_api_usage_once(
            "torch.distributed.tensor._sharding_prop.ShardingPropagator.propogate_op_sharding_non_cached"
        )
        # special case op, we don't need to propagate for local
        # scalar. TODO: figure out a better way to handle this
        if op_schema.op is aten._local_scalar_dense.default:

@ -398,6 +398,9 @@ def load(
        Under active development, saved files may not be usable in newer versions
        of PyTorch.

    .. warning::
        :func:`torch.export.load()` uses pickle under the hood to load models. **Never load data from an untrusted source.**

    Loads an :class:`ExportedProgram` previously saved with
    :func:`torch.export.save <torch.export.save>`.

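Aside (illustrative, not part of the diff): the warning above concerns the usual save/load round trip, which looks roughly like the sketch below ("m.pt2" is just a placeholder filename); only load archives you produced or trust.

import torch


class M(torch.nn.Module):
    def forward(self, x):
        return x + 1


ep = torch.export.export(M(), (torch.randn(2),))
torch.export.save(ep, "m.pt2")

# torch.export.load unpickles data from the archive, so only point it at
# files from a trusted source.
loaded = torch.export.load("m.pt2")
print(loaded.module()(torch.zeros(2)))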
@ -3,7 +3,7 @@ import dataclasses
import inspect
import logging
import sys
from collections import defaultdict
from collections import defaultdict, OrderedDict
from collections.abc import Callable
from enum import auto, Enum
from typing import Any, Optional, TYPE_CHECKING, Union
@ -721,7 +721,18 @@ def _combine_args(f, args, kwargs) -> dict[str, Any]:
        else inspect.signature(f)
    )
    kwargs = kwargs if kwargs is not None else {}
    return signature.bind(*args, **kwargs).arguments
    combined_args = signature.bind(*args, **kwargs).arguments
    # if `args` is in the key, flatten it into args_0, args_1, ...
    if "args" in combined_args:
        flattened_args = {f"args_{i}": v for i, v in enumerate(combined_args["args"])}
        combined_args = OrderedDict({**combined_args, **flattened_args})
        del combined_args["args"]
    # flatten kwargs into combined_args
    if "kwargs" in combined_args:
        for k, v in combined_args["kwargs"].items():
            combined_args[k] = v
        del combined_args["kwargs"]
    return combined_args


class ShapesCollection:

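Aside (illustrative, not part of the diff): the new flattening handles functions declared with *args/**kwargs, where signature.bind() groups the extra values under single "args" and "kwargs" entries. A stdlib-only sketch of the shape of the result before and after flattening (the function f and its arguments are invented):

import inspect
from collections import OrderedDict


def f(x, *args, **kwargs):
    pass


bound = inspect.signature(f).bind(1, 2, 3, flag=True).arguments
print(bound)  # {'x': 1, 'args': (2, 3), 'kwargs': {'flag': True}}

# Flatten the variadic groups, analogous to what the new _combine_args does.
combined = OrderedDict(bound)
if "args" in combined:
    combined.update({f"args_{i}": v for i, v in enumerate(combined.pop("args"))})
if "kwargs" in combined:
    combined.update(combined.pop("kwargs"))
print(combined)  # OrderedDict([('x', 1), ('args_0', 2), ('args_1', 3), ('flag', True)])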