Compare commits

...

19 Commits

Author SHA1 Message Date
27e0a198be fix hipify docstring 2025-11-10 07:56:31 -08:00
256b61734f [BE] documenting more functions 2025-11-10 07:52:33 -08:00
59307ca1bc [BE] adding documentation (#167334)
`torch.ao.quantization` and `torch.fx.experimental`

<img width="833" height="518" alt="Screenshot 2025-11-07 at 3 20 54 PM" src="https://github.com/user-attachments/assets/47b72f28-29bd-4bab-b41f-24d97419e411" />
<img width="892" height="560" alt="Screenshot 2025-11-07 at 3 20 45 PM" src="https://github.com/user-attachments/assets/129825ab-6706-41f2-964d-8774debab18c" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167334
Approved by: https://github.com/janeyx99
2025-11-10 14:46:42 +00:00
c28475db7c Update slow tests (#166844)
This PR is auto-generated weekly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/weekly.yml).
Update the list of slow tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166844
Approved by: https://github.com/pytorchbot
2025-11-10 12:39:27 +00:00
74aec83841 [xla hash update] update the pinned xla hash (#167452)
This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml).
Update the pinned xla hash.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167452
Approved by: https://github.com/pytorchbot
2025-11-10 12:03:01 +00:00
52e744d68a [DTensor] Support converting StridedShard to shard order and vice versa (#166740)
We plan to use `StridedShard` to express `shard_order`. This PR adds the functions needed to convert between `StridedShard` and `shard_order`.

I moved some test-related functions into torch/testing/_internal/common_utils.py. For review, only **_dtensor_spec.py** and **test_utils.py** matter in this PR.

### How to convert shard order to StridedShard:
Considering the example:
- placements = $[x_0, x_1, x_2, x_3, x_4]$, where every $x_i$ shards the same tensor dim.

Let's see how the shard order impacts the split_factor (sf). We loop from right to left over the placements and construct the split_factor under each possible shard order. Starting from $x_4$: this is a normal shard.

Then $x_3$. There are two possibilities: if $x_3$'s order is before $x_4$'s, then $x_3$'s sf=1, because $x_3$ already precedes $x_4$ in the placements. If $x_3$'s order is after $x_4$'s, then $x_3$'s sf should be the mesh dim size of $x_4$, which is $T(x_4)$:
<img width="820" height="431" alt="image" src="https://github.com/user-attachments/assets/f53b4b24-2523-42cc-ad6f-41f3c280db70" />

We can apply the same method to decide the split factors for $x_2$, $x_1$, and so on.

### How to convert StridedShard to shard order:
This follows the same method as above. We check all possible orderings and use the real split_factor to see which one matches. If none matches, the StridedShard cannot be converted to a shard order.
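
Below is a rough, self-contained sketch of this rule (not the PR's actual implementation in `_dtensor_spec.py`); `mesh_sizes`, `order`, and `split_factors` are illustrative names, and extending the two-placement case above to a product over several right-hand placements is an assumption:

```python
from math import prod

def split_factors(mesh_sizes: list[int], order: list[int]) -> list[int]:
    """mesh_sizes[i] = T(x_i), the mesh dim size behind placement x_i (all x_i
    shard the same tensor dim); order[i] = position of x_i in the shard order.
    Returns the split_factor each placement would carry as a _StridedShard."""
    n = len(mesh_sizes)
    return [
        # x_i is "strided" by every placement to its right that is applied
        # before it in the shard order; with no such placement, sf = 1,
        # i.e. a plain Shard.
        prod(mesh_sizes[j] for j in range(i + 1, n) if order[j] < order[i])
        for i in range(n)
    ]

# placements [x_0, x_1] with mesh dim sizes (2, 4), sharded right-to-left
# (x_1 applied first): x_0 needs split_factor T(x_1) = 4, x_1 stays a Shard.
print(split_factors([2, 4], order=[1, 0]))  # -> [4, 1]
```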

---

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166740
Approved by: https://github.com/ezyang
2025-11-10 09:35:10 +00:00
3cfbf98ea9 [xpu][feature] Add XPU support on torch.accelerator.get_memory_info (#162564)
# Motivation
Support XPU for `torch.accelerator.get_memory_info`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162564
Approved by: https://github.com/albanD
ghstack dependencies: #156812
2025-11-10 05:34:49 +00:00
47db55258b [MPS] sparse sparse mm (#167013)
Sparse sparse mm op implementation

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167013
Approved by: https://github.com/malfet
2025-11-10 05:27:49 +00:00
50af6f3393 [MPS] erfinv for sparse mps (#166711)
Should be merged after #166708
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166711
Approved by: https://github.com/Skylion007, https://github.com/malfet
2025-11-10 05:25:31 +00:00
e545ba2d34 [DTensor] Fix Conv behavior for replicate strategy (#167402)
Pass `dim_map` to `_requires_data_exchange` and return False if both the spatial and channel dimensions are replicated.

Modify `test_conv1d` and `test_conv3d` to check values rather than just shapes, and replicate `conv3d` across the batch dimension.

In general, it feels like the current convolution implementation was written to work only if the tensor is sharded across the last dimension.
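
As a rough illustration of the case this fixes (a hedged sketch, assuming a 2-rank process group is already initialized; `distribute_module` with no partition function simply replicates the conv parameters):

```python
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard, distribute_module

mesh = init_device_mesh("cuda", (2,))
# Replicate the Conv3d parameters across the mesh.
model = distribute_module(nn.Conv3d(64, 64, 3, padding=1).cuda(), mesh)
x = torch.randn(1, 64, 8, 8, 8, device="cuda")
# Shard only the batch dimension; the spatial and channel dims stay
# replicated, so the convolution should not need any inter-rank exchange.
x_dt = DTensor.from_local(x, mesh, [Shard(0)])
out = model(x_dt)
print(out.full_tensor().shape)  # torch.Size([2, 64, 8, 8, 8])
```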

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167402
Approved by: https://github.com/ezyang
2025-11-10 05:13:42 +00:00
a058bbdd6f [xpu][test] Enable profiler test for XPU (#165423)
Fixes #165130

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165423
Approved by: https://github.com/EikanWang, https://github.com/atalman, https://github.com/mlazos
2025-11-10 04:02:59 +00:00
2c78080ec0 Register functorch XPU/HPU dispatch keys (#167095)
Fixes TestOperatorsXPU.test_data_write_errors_under_transform_xpu https://github.com/intel/torch-xpu-ops/issues/2237

Tests on other devices throw the runtime error "_mutating directly with `.data` inside functorch transform is not allowed._", but XPU/HPU fails earlier, in `_has_compatible_shallow_copy_type`. This check only fails when `tensor.data` is accessed inside a functorch call.

```cpp
bool _has_compatible_shallow_copy_type(const Tensor& self, const Tensor& from) {
  return self.unsafeGetTensorImpl()->has_compatible_shallow_copy_type(
      from.key_set());
}
```

### t.data
| Tensor | Device | Dispatch Keys |
|--------|---------|---------------|
| `self` | `xpu` | `XPU, ADInplaceOrView, AutogradXPU, AutocastXPU` |
| `from` | `cpu` | `CPU, ADInplaceOrView, AutogradCPU, AutocastCPU` |

### t.data inside functorch transform
| Tensor | Device | Dispatch Keys |
|--------|---------|---------------|
| `self` | `xpu` | `ADInplaceOrView, AutogradOther, FuncTorchGradWrapper` |
| `from` | `cpu` | `CPU, ADInplaceOrView, AutogradCPU, AutocastCPU, FuncTorchGradWrapper` |

### t.data inside functorch transform + XPU dispatch key
| Tensor | Device | Dispatch Keys |
|--------|---------|---------------|
| `self` | `xpu` | `XPU, ADInplaceOrView, AutogradXPU, AutocastXPU, FuncTorchGradWrapper` |
| `from` | `cpu` | `CPU, ADInplaceOrView, AutogradCPU, AutocastCPU, FuncTorchGradWrapper` |
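
For context, a rough reproducer in the spirit of the linked test (a sketch only; the `xpu` device string and error text are taken from the description above, not from the test source):

```python
import torch
from torch.func import grad

def f(x):
    # Writing to `.data` inside a functorch transform should raise
    # "mutating directly with `.data` inside functorch transform is not allowed".
    x.data = torch.ones_like(x)
    return x.sum()

try:
    grad(f)(torch.randn(3, device="xpu"))
except RuntimeError as e:
    print(e)
# Before this change, XPU/HPU tensors failed earlier, inside
# _has_compatible_shallow_copy_type, because the XPU/HPU dispatch keys were
# not propagated to the functorch wrapper.
```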
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167095
Approved by: https://github.com/guangyey, https://github.com/albanD
2025-11-10 03:10:22 +00:00
fe6615e397 Swap pallas test shard to 12.8 (#167428)
Getting some weird failures building CUDA 13, so let's stick to what we know works.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167428
Approved by: https://github.com/jansel
2025-11-10 02:42:35 +00:00
abf31db2cc Introduce a new API torch.accelerator.get_memory_info (#156812)
# Motivation
`torch.cuda.mem_get_info` and `torch.xpu.mem_get_info` are widely used in other popular repos, such as
- 076313bd09/python/sglang/srt/utils.py (L378),
- 7ecc2d7f39/src/accelerate/utils/modeling.py (L822),
- 7ba34b1241/vllm/worker/worker.py (L150).
This PR introduces a unified API `torch.accelerator.get_memory_info` to cover this scenario.
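
A minimal usage sketch, assuming the Python binding mirrors `torch.cuda.mem_get_info` and returns a `(free, total)` pair in bytes for the current accelerator device:

```python
import torch

if torch.accelerator.is_available():
    free, total = torch.accelerator.get_memory_info()
    print(f"{free / 2**30:.2f} GiB free of {total / 2**30:.2f} GiB")
```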

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156812
Approved by: https://github.com/albanD
2025-11-10 01:57:39 +00:00
a4c7856112 [Inductor][Grouped Gemm] Add Blackwell CuTeDSL Kernel (#167340)
Summary: This is a reland of https://github.com/pytorch/pytorch/pull/165036, which previously contained a minor bug in the logic that determined whether the kernel should be enabled. As a result, it was incorrectly activated on non-Blackwell GPUs.

Test Plan:
Inductor test (fbcode):
`INDUCTOR_TEST_DISABLE_FRESH_CACHE=1 TORCHINDUCTOR_CACHE_DIR=~/cutetest buck2 run mode/opt //caffe2/test/inductor:cutedsl_grouped_mm -c fbcode.nvcc_arch=b200a -c fbcode.enable_gpu_sections=true -c fbcode.platform010_cuda_version=12.8 -m "ovr_config//third-party/pypi/nvidia-cutlass-dsl/constraints:4.2.1"`

Tritonbench (fbcode):
`clear; CUDA_VISIBLE_DEVICES=7 TRITON_PRINT_AUTOTUNING=1 TRITON_ALWAYS_COMPILE=1 TORCH_LOGS=+inductor TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 buck2 run mode/opt //pytorch/tritonbench:run -c fbcode.nvcc_arch=b200a -c fbcode.enable_gpu_sections=true -c fbcode.platform010_cuda_version=12.8 -m "ovr_config//third-party/pypi/nvidia-cutlass-dsl/constraints:4.2.1" -- --op grouped_gemm --only aten_grouped_mm,preprocessed_pt2_cute_grouped_mm --precision bf16  --num-inputs 1 --metrics tflops,accuracy`

Tritonbench(oss):
`clear; CUDA_VISIBLE_DEVICES=2 TRITON_PRINT_AUTOTUNING=1 TRITON_ALWAYS_COMPILE=1 TORCH_LOGS=+inductor TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 python run.py --op grouped_gemm --only aten_grouped_mm,preprocessed_pt2_triton_grouped_mm --precision bf16  --num-inputs 1 --metrics tflops,accuracy`

Unit Tests(oss):
`clear; python test/inductor/test_cutedsl_grouped_mm.py`

Differential Revision: D86537373

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167340
Approved by: https://github.com/jananisriram
2025-11-10 00:29:07 +00:00
afb014541b Separately handle null data_ptr storages when creating unique ID (#167405)
## Summary
Previously, fake/functionalized tensors that have a `null` storage_ptr could segfault when checking `.expired()` on the weak storage ref, so `nullptr` storages are now handled separately, without checking their weakrefs.

Diagnosis and PR created by codex
------
[Codex Task](https://chatgpt.com/codex/tasks/task_e_690ea8790054832f90eaffb37ee0d8c8)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167405
Approved by: https://github.com/Skylion007
2025-11-09 23:13:56 +00:00
b91a2ab892 [2/N] Use context managers (#167404)
This PR fixes more context manager usage in Python code.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167404
Approved by: https://github.com/mlazos
2025-11-09 13:38:14 +00:00
14a845a4ec [2/N] Use Python 3.10 typing (#167167)
This PR applies new `Union` and `Optional` typing syntax to some files.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167167
Approved by: https://github.com/XuehaiPan, https://github.com/mlazos
2025-11-09 12:11:45 +00:00
5135ace3a3 Enable ruff UP035 rule (#167307)
This PR enables `UP035` rule of ruff.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167307
Approved by: https://github.com/Lucaskabela
2025-11-09 06:40:03 +00:00
144 changed files with 2692 additions and 1786 deletions

View File

@ -260,8 +260,8 @@ case "$tag" in
HALIDE=yes
TRITON=yes
;;
pytorch-linux-jammy-cuda13.0-py3.12-pallas)
CUDA_VERSION=13.0.0
pytorch-linux-jammy-cuda12.8-py3.12-pallas)
CUDA_VERSION=12.8.1
ANACONDA_PYTHON_VERSION=3.12
GCC_VERSION=11
PALLAS=yes

View File

@ -8,9 +8,11 @@ from abc import ABC, abstractmethod
try:
from typing import Any, Callable, Required, TypedDict # Python 3.11+
from collections.abc import Callable # Python 3.11+
from typing import Any, Required, TypedDict
except ImportError:
from typing import Any, Callable, TypedDict
from collections.abc import Callable
from typing import Any, TypedDict
from typing_extensions import Required # Fallback for Python <3.11

View File

@ -168,14 +168,16 @@ if [[ "$BUILD_ENVIRONMENT" == *xpu* ]]; then
# shellcheck disable=SC1091
source /opt/intel/oneapi/compiler/latest/env/vars.sh
# shellcheck disable=SC1091
source /opt/intel/oneapi/umf/latest/env/vars.sh
# shellcheck disable=SC1091
source /opt/intel/oneapi/ccl/latest/env/vars.sh
# shellcheck disable=SC1091
source /opt/intel/oneapi/mpi/latest/env/vars.sh
# shellcheck disable=SC1091
source /opt/intel/oneapi/pti/latest/env/vars.sh
# Enable XCCL build
export USE_XCCL=1
export USE_MPI=0
# XPU kineto feature dependencies are not fully ready, disable kineto build as temp WA
export USE_KINETO=0
export TORCH_XPU_ARCH_LIST=pvc
fi

View File

@ -208,6 +208,8 @@ if [[ "$BUILD_ENVIRONMENT" == *xpu* ]]; then
source /opt/intel/oneapi/ccl/latest/env/vars.sh
# shellcheck disable=SC1091
source /opt/intel/oneapi/mpi/latest/env/vars.sh
# shellcheck disable=SC1091
source /opt/intel/oneapi/pti/latest/env/vars.sh
# Check XPU status before testing
timeout 30 xpu-smi discovery || true
fi
@ -337,7 +339,7 @@ test_python() {
test_python_smoke() {
# Smoke tests for H100/B200
time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune inductor/test_cutedsl_grouped_mm $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
assert_git_not_dirty
}

View File

@ -1 +1 @@
c8b09f5f77d6bf6fb7ed7a9aa83e5d8156b3a5e9
e4d25697f9dc5eedaf8f0a5bf085c62c5455a53a

View File

@ -1,10 +1,11 @@
# Delete old branches
import os
import re
from collections.abc import Callable
from datetime import datetime
from functools import lru_cache
from pathlib import Path
from typing import Any, Callable
from typing import Any
from github_utils import gh_fetch_json_dict, gh_graphql
from gitutils import GitRepo

View File

@ -8,10 +8,11 @@ import re
import subprocess
import sys
import warnings
from collections.abc import Callable
from enum import Enum
from functools import cache
from logging import info
from typing import Any, Callable, Optional
from typing import Any, Optional
from urllib.request import Request, urlopen
import yaml

View File

@ -11,7 +11,8 @@ import sys
import time
import urllib
import urllib.parse
from typing import Any, Callable, Optional
from collections.abc import Callable
from typing import Any, Optional
from urllib.request import Request, urlopen

View File

@ -3,8 +3,9 @@
import json
import os
import warnings
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any, Callable, cast, Optional, Union
from typing import Any, cast, Optional, Union
from urllib.error import HTTPError
from urllib.parse import quote
from urllib.request import Request, urlopen

View File

@ -4,10 +4,10 @@ import os
import re
import tempfile
from collections import defaultdict
from collections.abc import Iterator
from collections.abc import Callable, Iterator
from datetime import datetime
from functools import wraps
from typing import Any, Callable, cast, Optional, TypeVar, Union
from typing import Any, cast, Optional, TypeVar, Union
T = TypeVar("T")

View File

@ -17,12 +17,12 @@ import re
import time
import urllib.parse
from collections import defaultdict
from collections.abc import Iterable
from collections.abc import Callable, Iterable
from dataclasses import dataclass
from functools import cache
from pathlib import Path
from re import Pattern
from typing import Any, Callable, cast, NamedTuple, Optional
from typing import Any, cast, NamedTuple, Optional
from warnings import warn
import yaml

View File

@ -67,7 +67,7 @@ jobs:
pytorch-linux-jammy-py3.10-gcc11,
pytorch-linux-jammy-py3-gcc11-inductor-benchmarks,
pytorch-linux-jammy-py3.12-halide,
pytorch-linux-jammy-cuda13.0-py3.12-pallas,
pytorch-linux-jammy-cuda12.8-py3.12-pallas,
pytorch-linux-jammy-xpu-n-1-py3,
pytorch-linux-noble-xpu-n-py3,
pytorch-linux-noble-xpu-n-py3-inductor-benchmarks,

View File

@ -86,8 +86,8 @@ jobs:
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
build-environment: linux-jammy-py3.12-gcc11
docker-image-name: ci-image:pytorch-linux-jammy-cuda13.0-py3.12-pallas
build-environment: linux-jammy-cuda12.8-py3.12-gcc11
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-py3.12-pallas
cuda-arch-list: '8.9'
runner: linux.8xlarge.memory
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"

1
.gitignore vendored
View File

@ -127,6 +127,7 @@ torch/test/
torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h
torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h
torch/version.py
torch/_inductor/kernel/vendored_templates/*
minifier_launcher.py
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_fwd_d*
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd_d*

View File

@ -94,6 +94,11 @@ TORCH_API inline void resetPeakStats(c10::DeviceIndex device_index) {
at::getDeviceAllocator(device_type)->resetPeakStats(device_index);
}
TORCH_API inline std::pair<size_t, size_t> getMemoryInfo(
c10::DeviceIndex device_index) {
const auto device_type = getAccelerator(true).value();
return at::getDeviceAllocator(device_type)->getMemoryInfo(device_index);
}
} // namespace at::accelerator
namespace at {

View File

@ -157,6 +157,8 @@ constexpr DispatchKeySet kKeysToPropagateToWrapper({
DispatchKey::Negative,
DispatchKey::Conjugate,
DispatchKey::XLA,
DispatchKey::XPU,
DispatchKey::HPU,
DispatchKey::CUDA,
DispatchKey::CPU,
DispatchKey::PrivateUse1,

View File

@ -4292,6 +4292,7 @@
dispatch:
SparseCPU: sparse_sparse_matmul_cpu
SparseCUDA: sparse_sparse_matmul_cuda
SparseMPS: sparse_sparse_matmul_mps
autogen: _sparse_sparse_matmul.out
- func: mode(Tensor self, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices)
@ -9832,7 +9833,7 @@
structured_delegate: erfinv.out
variants: method, function
dispatch:
SparseCPU, SparseCUDA: erfinv_sparse
SparseCPU, SparseCUDA, SparseMPS: erfinv_sparse
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: erfinv_sparse_csr
tags: pointwise
@ -9841,7 +9842,7 @@
structured_delegate: erfinv.out
variants: method
dispatch:
SparseCPU, SparseCUDA: erfinv_sparse_
SparseCPU, SparseCUDA, SparseMPS: erfinv_sparse_
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: erfinv_sparse_csr_
tags: pointwise
@ -9851,7 +9852,7 @@
structured_inherits: TensorIteratorBase
dispatch:
CPU, CUDA, MPS: erfinv_out
SparseCPU, SparseCUDA: erfinv_sparse_out
SparseCPU, SparseCUDA, SparseMPS: erfinv_sparse_out
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: erfinv_sparse_csr_out
tags: pointwise

View File

@ -10,6 +10,10 @@
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_coalesce_native.h>
#include <ATen/ops/repeat_interleave_native.h>
#include <ATen/ops/cumsum.h>
#include <ATen/ops/_sparse_sparse_matmul_native.h>
#include <ATen/ops/_sparse_coo_tensor_unsafe.h>
#include <ATen/ops/_sparse_coo_tensor_unsafe_native.h>
#include <ATen/ops/cat.h>
#include <ATen/ops/add_native.h>
@ -888,5 +892,114 @@ static void sparse_mask_intersection_out_mps_kernel(
/*coalesce_mask=*/false);
}
Tensor sparse_sparse_matmul_mps(const Tensor& mat1_, const Tensor& mat2_) {
TORCH_CHECK(mat1_.is_sparse() && mat2_.is_sparse(),
"sparse_sparse_matmul_mps: both inputs must be sparse COO tensors");
TORCH_CHECK(mat1_.is_mps() && mat2_.is_mps(),
"sparse_sparse_matmul_mps: both inputs must be on MPS device");
TORCH_CHECK(mat1_.dim() == 2 && mat2_.dim() == 2,
"sparse_sparse_matmul_mps: both inputs must be 2D matrices");
TORCH_CHECK(mat1_.dense_dim() == 0 && mat2_.dense_dim() == 0,
"sparse_sparse_matmul_mps: only scalar values supported (dense_dim == 0)");
TORCH_CHECK(mat1_.size(1) == mat2_.size(0),
"mat1 and mat2 shapes cannot be multiplied (", mat1_.size(0), "x", mat1_.size(1), " and ", mat2_.size(0), "x", mat2_.size(1), ")");
TORCH_CHECK(mat1_.scalar_type() == mat2_.scalar_type(),
"sparse_sparse_matmul_mps: mat1 dtype ", mat1_.scalar_type(),
" does not match mat2 dtype ", mat2_.scalar_type());
const auto device = mat1_.device();
auto A = mat1_.coalesce();
auto B = mat2_.coalesce();
const auto I = A.size(0);
const auto K = A.size(1);
const auto N = B.size(1);
const auto nnzA = A._nnz();
const auto nnzB = B._nnz();
// Early empty result, return an empty, coalesced tensor
if (I == 0 || N == 0 || K == 0 || nnzA == 0 || nnzB == 0) {
auto empty_idx = at::empty({2, 0}, at::device(device).dtype(at::kLong));
auto empty_val = at::empty({0}, at::device(device).dtype(mat1_.scalar_type()));
auto out = _sparse_coo_tensor_unsafe(empty_idx, empty_val, {I, N}, mat1_.options());
out._coalesced_(true);
return out;
}
const auto computeDtype = at::result_type(mat1_, mat2_);
auto A_idx = A._indices().contiguous();
auto A_val = A._values().to(computeDtype).contiguous();
auto A_i = A_idx.select(0, 0).contiguous();
auto A_k = A_idx.select(0, 1).contiguous();
auto B_idx = B._indices().contiguous();
auto B_val = B._values().to(computeDtype).contiguous();
auto B_k = B_idx.select(0, 0).contiguous();
auto B_j = B_idx.select(0, 1).contiguous();
// csr-style row pointers for B by k (the shared dimension)
Tensor row_ptr_B;
{
auto batch_ptr = at::tensor({0LL, nnzB}, at::device(device).dtype(at::kLong));
row_ptr_B = at::empty({K + 1}, at::device(device).dtype(at::kLong));
build_row_ptr_per_batch_mps(B_k, batch_ptr, /*B=*/1, /*I=*/K, row_ptr_B);
}
auto row_ptr_B_lo = row_ptr_B.narrow(0, 0, K);
auto row_ptr_B_hi = row_ptr_B.narrow(0, 1, K);
auto deg_B = row_ptr_B_hi.sub(row_ptr_B_lo);
auto counts = deg_B.index_select(0, A_k);
const int64_t P = counts.sum().item<int64_t>();
if (P == 0) {
auto empty_idx = at::empty({2, 0}, at::device(device).dtype(at::kLong));
auto empty_val = at::empty({0}, at::device(device).dtype(mat1_.scalar_type()));
auto out = _sparse_coo_tensor_unsafe(empty_idx, empty_val, {I, N}, mat1_.options());
out._coalesced_(true);
return out;
}
auto group_ids = repeat_interleave_mps(counts);
// exclusive cumsum of counts
auto offsets = cumsum(counts, /*dim=*/0).sub(counts);
auto offsets_gather = offsets.index_select(0, group_ids);
auto within = at::arange(P, at::device(device).dtype(at::kLong)).sub(offsets_gather);
// Map each output element to its source B row and position
auto k_per_out = A_k.index_select(0, group_ids);
auto start_in_B = row_ptr_B.index_select(0, k_per_out);
auto seg_index = start_in_B.add(within);
// Assemble candidate coo pairs and values
auto i_out = A_i.index_select(0, group_ids).contiguous();
auto j_out = B_j.index_select(0, seg_index).contiguous();
auto vA_out = A_val.index_select(0, group_ids).contiguous();
auto vB_out = B_val.index_select(0, seg_index).contiguous();
auto v_out = vA_out.mul(vB_out);
// build (2, P) indices
auto out_indices = at::empty({2, P}, at::device(device).dtype(at::kLong)).contiguous();
out_indices.select(0, 0).copy_(i_out);
out_indices.select(0, 1).copy_(j_out);
auto result = _sparse_coo_tensor_unsafe(
out_indices, v_out, {I, N}, mat1_.options().dtype(computeDtype));
result = result.coalesce();
if (result.scalar_type() != mat1_.scalar_type()) {
auto cast_vals = result._values().to(mat1_.scalar_type());
auto out = _sparse_coo_tensor_unsafe(result._indices(), cast_vals, {I, N}, mat1_.options());
out._coalesced_(true);
return out;
}
return result;
}
REGISTER_MPS_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_mps_kernel);
} // namespace at::native

View File

@ -96,6 +96,10 @@ struct C10_API DeviceAllocator : public c10::Allocator {
// Resets peak memory usage statistics for the specified device
virtual void resetPeakStats(c10::DeviceIndex device) = 0;
// Return the free memory size and total memory size in bytes for the
// specified device.
virtual std::pair<size_t, size_t> getMemoryInfo(c10::DeviceIndex device) = 0;
};
// This function is used to get the DeviceAllocator for a specific device type

View File

@ -345,6 +345,13 @@ class CUDAAllocator : public DeviceAllocator {
c10::DeviceIndex device,
std::shared_ptr<AllocatorState> pps) = 0;
virtual std::string name() = 0;
std::pair<size_t, size_t> getMemoryInfo(c10::DeviceIndex device) override {
c10::DeviceGuard device_guard({at::kCUDA, device});
size_t free = 0;
size_t total = 0;
C10_CUDA_CHECK(cudaMemGetInfo(&free, &total));
return {free, total};
}
};
// Allocator object, statically initialized

View File

@ -926,15 +926,14 @@ class DeviceCachingAllocator {
(release_cached_blocks() && alloc_block(params, true));
}
if (!block_found) {
c10::xpu::DeviceProp device_prop;
c10::xpu::get_device_properties(&device_prop, device);
auto device_total = device_prop.global_mem_size;
const auto& raw_device = c10::xpu::get_raw_device(device);
const auto device_total =
raw_device.get_info<sycl::info::device::global_mem_size>();
// Estimate the available device memory when the SYCL runtime does not
// support the corresponding aspect (ext_intel_free_memory).
size_t device_free = device_prop.global_mem_size -
size_t device_free = device_total -
stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)]
.current;
auto& raw_device = c10::xpu::get_raw_device(device);
// TODO: Remove the aspect check once the SYCL runtime bug is fixed on
// affected devices.
if (raw_device.has(sycl::aspect::ext_intel_free_memory)) {
@ -1052,21 +1051,37 @@ class DeviceCachingAllocator {
}
}
std::pair<size_t, size_t> getMemoryInfo() {
const auto& device = c10::xpu::get_raw_device(device_index);
const size_t total = device.get_info<sycl::info::device::global_mem_size>();
TORCH_CHECK(
device.has(sycl::aspect::ext_intel_free_memory),
"The device (",
device.get_info<sycl::info::device::name>(),
") doesn't support querying the available free memory. ",
"You can file an issue at https://github.com/pytorch/pytorch/issues ",
"to help us prioritize its implementation.");
const size_t free =
device.get_info<sycl::ext::intel::info::device::free_memory>();
return {free, total};
}
double getMemoryFraction() {
if (!set_fraction) {
return 1.0;
}
c10::xpu::DeviceProp device_prop;
c10::xpu::get_device_properties(&device_prop, device_index);
const auto device_total =
xpu::get_raw_device(device_index)
.get_info<sycl::info::device::global_mem_size>();
return static_cast<double>(allowed_memory_maximum) /
static_cast<double>(device_prop.global_mem_size);
static_cast<double>(device_total);
}
void setMemoryFraction(double fraction) {
c10::xpu::DeviceProp device_prop;
c10::xpu::get_device_properties(&device_prop, device_index);
auto device_total = device_prop.global_mem_size;
const auto device_total =
xpu::get_raw_device(device_index)
.get_info<sycl::info::device::global_mem_size>();
allowed_memory_maximum = static_cast<size_t>(fraction * device_total);
set_fraction = true;
}
@ -1240,6 +1255,11 @@ class XPUAllocator : public DeviceAllocator {
c10::xpu::get_raw_device(dev_to_access));
}
std::pair<size_t, size_t> getMemoryInfo(DeviceIndex device) override {
assertValidDevice(device);
return device_allocators[device]->getMemoryInfo();
}
double getMemoryFraction(DeviceIndex device) {
assertValidDevice(device);
return device_allocators[device]->getMemoryFraction();

View File

@ -40,6 +40,7 @@
:nosignatures:
empty_cache
get_memory_info
max_memory_allocated
max_memory_reserved
memory_allocated

View File

@ -382,20 +382,6 @@ coverage_ignore_functions = [
# torch.ao.quantization.backend_config.tensorrt
"get_tensorrt_backend_config",
"get_tensorrt_backend_config_dict",
# torch.ao.quantization.backend_config.utils
"entry_to_pretty_str",
"get_fused_module_classes",
"get_fuser_method_mapping",
"get_fusion_pattern_to_extra_inputs_getter",
"get_fusion_pattern_to_root_node_getter",
"get_module_to_qat_module",
"get_pattern_to_dtype_configs",
"get_pattern_to_input_type_to_index",
"get_qat_module_classes",
"get_root_module_to_quantized_reference_module",
"pattern_to_human_readable",
"remove_boolean_dispatch_from_name",
# torch.ao.quantization.backend_config.x86
"get_x86_backend_config",
# torch.ao.quantization.fuse_modules
"fuse_known_modules",
@ -426,25 +412,6 @@ coverage_ignore_functions = [
"insert_observers_for_model",
"prepare",
"propagate_dtypes_for_known_nodes",
# torch.ao.quantization.fx.utils
"all_node_args_except_first",
"all_node_args_have_no_tensors",
"assert_and_get_unique_device",
"collect_producer_nodes",
"create_getattr_from_value",
"create_node_from_old_node_preserve_meta",
"get_custom_module_class_keys",
"get_linear_prepack_op_for_dtype",
"get_new_attr_name_with_prefix",
"get_non_observable_arg_indexes_and_types",
"get_qconv_prepack_op",
"get_skipped_module_name_and_classes",
"graph_module_from_producer_nodes",
"maybe_get_next_module",
"node_arg_is_bias",
"node_arg_is_weight",
"return_arg_list",
# torch.ao.quantization.pt2e.graph_utils
"bfs_trace_with_node_process",
"find_sequential_partitions",
"get_equivalent_types",
@ -860,80 +827,10 @@ coverage_ignore_functions = [
"get_latency_of_one_partition",
"get_latency_of_partitioned_graph",
"get_partition_to_latency_mapping",
# torch.fx.experimental.proxy_tensor
"decompose",
"disable_autocast_cache",
"disable_proxy_modes_tracing",
"dispatch_trace",
"extract_val",
"fake_signature",
"fetch_sym_proxy",
"fetch_object_proxy",
"get_innermost_proxy_mode",
"get_isolated_graphmodule",
"get_proxy_slot",
"get_torch_dispatch_modes",
"has_proxy_slot",
"is_sym_node",
"maybe_handle_decomp",
"proxy_call",
"set_meta",
"set_original_aten_op",
"set_proxy_slot",
"snapshot_fake",
"thunkify",
"track_tensor",
"track_tensor_tree",
"wrap_key",
"wrapper_and_args_for_make_fx",
# torch.fx.experimental.recording
"record_shapeenv_event",
"replay_shape_env_events",
"shape_env_check_state_equal",
# torch.fx.experimental.sym_node
"ceil_impl",
"floor_ceil_helper",
"floor_impl",
"method_to_operator",
"sympy_is_channels_last_contiguous_2d",
"sympy_is_channels_last_contiguous_3d",
"sympy_is_channels_last_strides_2d",
"sympy_is_channels_last_strides_3d",
"sympy_is_channels_last_strides_generic",
"sympy_is_contiguous",
"sympy_is_contiguous_generic",
"to_node",
"wrap_node",
"sym_sqrt",
# torch.fx.experimental.symbolic_shapes
"bind_symbols",
"cast_symbool_to_symint_guardless",
"create_contiguous",
"error",
"eval_guards",
"eval_is_non_overlapping_and_dense",
"expect_true",
"find_symbol_binding_fx_nodes",
"free_symbols",
"free_unbacked_symbols",
"fx_placeholder_targets",
"fx_placeholder_vals",
"guard_bool",
"guard_float",
"guard_int",
"guard_scalar",
"has_hint",
"has_symbolic_sizes_strides",
"is_channels_last_contiguous_2d",
"is_channels_last_contiguous_3d",
"is_channels_last_strides_2d",
"is_channels_last_strides_3d",
"is_contiguous",
"is_non_overlapping_and_dense_indicator",
"is_nested_int",
"is_symbol_binding_fx_node",
"is_symbolic",
# torch.fx.experimental.unification.core
"reify",
# torch.fx.experimental.unification.match
"edge",
@ -971,24 +868,6 @@ coverage_ignore_functions = [
"reverse_dict",
# torch.fx.experimental.unification.multipledispatch.variadic
"isvariadic",
# torch.fx.experimental.unification.unification_tools
"assoc",
"assoc_in",
"dissoc",
"first",
"get_in",
"getter",
"groupby",
"itemfilter",
"itemmap",
"keyfilter",
"keymap",
"merge",
"merge_with",
"update_in",
"valfilter",
"valmap",
# torch.fx.experimental.unification.utils
"freeze",
"hashable",
"raises",
@ -1429,319 +1308,8 @@ coverage_ignore_functions = [
# torch.onnx.symbolic_opset7
"max",
"min",
# torch.onnx.symbolic_opset8
"addmm",
"bmm",
"empty",
"empty_like",
"flatten",
"full",
"full_like",
"gt",
"lt",
"matmul",
"mm",
"ones",
"ones_like",
"prelu",
"repeat",
"zeros",
"zeros_like",
# torch.onnx.symbolic_opset9
"abs",
"acos",
"adaptive_avg_pool1d",
"adaptive_avg_pool2d",
"adaptive_avg_pool3d",
"adaptive_max_pool1d",
"adaptive_max_pool2d",
"adaptive_max_pool3d",
"add",
"addcmul",
"addmm",
"alias",
"amax",
"amin",
"aminmax",
"arange",
"argmax",
"argmin",
"as_strided",
"as_tensor",
"asin",
"atan",
"atan2",
"avg_pool1d",
"avg_pool2d",
"avg_pool3d",
"baddbmm",
"batch_norm",
"bernoulli",
"bitwise_not",
"bitwise_or",
"bmm",
"broadcast_tensors",
"broadcast_to",
"bucketize",
"cat",
"cdist",
"ceil",
"clamp",
"clamp_max",
"clamp_min",
"clone",
"constant_pad_nd",
"contiguous",
"conv1d",
"conv2d",
"conv3d",
"conv_tbc",
"conv_transpose1d",
"conv_transpose2d",
"conv_transpose3d",
"convert_element_type",
"convolution",
"cos",
"cosine_similarity",
"cross",
"cumsum",
"detach",
"dim",
"div",
"dot",
"dropout",
"elu",
"embedding",
"embedding_bag",
"empty",
"empty_like",
"eq",
"erf",
"exp",
"expand",
"expand_as",
"eye",
"fill",
"flatten",
"floor",
"floor_divide",
"floordiv",
"frobenius_norm",
"full",
"full_like",
"gather",
"ge",
"gelu",
"get_pool_ceil_padding",
"glu",
"group_norm",
"gru",
"gt",
"hann_window",
"hardshrink",
"hardsigmoid",
"hardswish",
"hardtanh",
"index",
"index_add",
"index_copy",
"index_fill",
"index_put",
"index_select",
"instance_norm",
"is_floating_point",
"is_pinned",
"isnan",
"item",
"kl_div",
"layer_norm",
"le",
"leaky_relu",
"lerp",
"lift",
"linalg_cross",
"linalg_matrix_norm",
"linalg_norm",
"linalg_vector_norm",
"linear",
"linspace",
"log",
"log10",
"log1p",
"log2",
"log_sigmoid",
"log_softmax",
"logical_and",
"logical_not",
"logical_or",
"logical_xor",
"logit",
"logsumexp",
"lstm",
"lstm_cell",
"lt",
"masked_fill",
"masked_fill_",
"matmul",
"max",
"max_pool1d",
"max_pool1d_with_indices",
"max_pool2d",
"max_pool2d_with_indices",
"max_pool3d",
"max_pool3d_with_indices",
"maximum",
"meshgrid",
"min",
"minimum",
"mish",
"mm",
"movedim",
"mse_loss",
"mul",
"multinomial",
"mv",
"narrow",
"native_layer_norm",
"ne",
"neg",
"new_empty",
"new_full",
"new_ones",
"new_zeros",
"nonzero",
"nonzero_numpy",
"noop_complex_operators",
"norm",
"numel",
"numpy_T",
"one_hot",
"ones",
"ones_like",
"onnx_placeholder",
"overload_by_arg_count",
"pad",
"pairwise_distance",
"permute",
"pixel_shuffle",
"pixel_unshuffle",
"pow",
"prelu",
"prim_constant",
"prim_constant_chunk",
"prim_constant_split",
"prim_data",
"prim_device",
"prim_dtype",
"prim_if",
"prim_layout",
"prim_list_construct",
"prim_list_unpack",
"prim_loop",
"prim_max",
"prim_min",
"prim_shape",
"prim_tolist",
"prim_tuple_construct",
"prim_type",
"prim_unchecked_cast",
"prim_uninitialized",
"rand",
"rand_like",
"randint",
"randint_like",
"randn",
"randn_like",
"reciprocal",
"reflection_pad",
"relu",
"relu6",
"remainder",
"repeat",
"repeat_interleave",
"replication_pad",
"reshape",
"reshape_as",
"rnn_relu",
"rnn_tanh",
"roll",
"rrelu",
"rsqrt",
"rsub",
"scalar_tensor",
"scatter",
"scatter_add",
"select",
"selu",
"sigmoid",
"sign",
"silu",
"sin",
"size",
"slice",
"softmax",
"softplus",
"softshrink",
"sort",
"split",
"split_with_sizes",
"sqrt",
"square",
"squeeze",
"stack",
"std",
"std_mean",
"sub",
"t",
"take",
"tan",
"tanh",
"tanhshrink",
"tensor",
"threshold",
"to",
"topk",
"transpose",
"true_divide",
"type_as",
"unbind",
"unfold",
"unsafe_chunk",
"unsafe_split",
"unsafe_split_with_sizes",
"unsqueeze",
"unsupported_complex_operators",
"unused",
"upsample_bilinear2d",
"upsample_linear1d",
"upsample_nearest1d",
"upsample_nearest2d",
"upsample_nearest3d",
"upsample_trilinear3d",
"var",
"var_mean",
"view",
"view_as",
"where",
"wrap_logical_op_with_cast_to",
"wrap_logical_op_with_negation",
"zero",
"zeros",
"zeros_like",
# torch.onnx.utils
"disable_apex_o2_state_dict_hook",
"export",
"export_to_pretty_string",
"exporter_context",
"is_in_onnx_export",
"model_signature",
"register_custom_op_symbolic",
"select_model_mode_for_export",
"setup_onnx_logging",
"unconvertible_ops",
"unpack_quantized_tensor",
"warn_on_static_input_change",
# torch.onnx.verification
"check_export_model_diff",
"verify",
"verify_aten_graph",
@ -1832,32 +1400,6 @@ coverage_ignore_functions = [
"noop_context_fn",
"set_checkpoint_early_stop",
"set_device_states",
# torch.utils.collect_env
"check_release_file",
"get_cachingallocator_config",
"get_clang_version",
"get_cmake_version",
"get_conda_packages",
"get_cpu_info",
"get_cuda_module_loading_config",
"get_cudnn_version",
"get_env_info",
"get_gcc_version",
"get_gpu_info",
"get_libc_version",
"get_lsb_version",
"get_mac_version",
"get_nvidia_driver_version",
"get_nvidia_smi",
"get_os",
"get_pip_packages",
"get_platform",
"get_pretty_env_info",
"get_python_platform",
"get_running_cuda_version",
"get_windows_version",
"is_xnnpack_available",
"pretty_str",
# torch.utils.cpp_backtrace
"get_cpp_backtrace",
# torch.utils.cpp_extension
@ -1921,52 +1463,6 @@ coverage_ignore_functions = [
"apply_shuffle_seed",
"apply_shuffle_settings",
"get_all_graph_pipes",
# torch.utils.flop_counter
"addmm_flop",
"baddbmm_flop",
"bmm_flop",
"conv_backward_flop",
"conv_flop",
"conv_flop_count",
"convert_num_with_suffix",
"get_shape",
"get_suffix_str",
"mm_flop",
"normalize_tuple",
"register_flop_formula",
"sdpa_backward_flop",
"sdpa_backward_flop_count",
"sdpa_flop",
"sdpa_flop_count",
"shape_wrapper",
"transpose_shape",
# torch.utils.hipify.hipify_python
"add_dim3",
"compute_stats",
"extract_arguments",
"file_add_header",
"file_specific_replacement",
"find_bracket_group",
"find_closure_group",
"find_parentheses_group",
"fix_static_global_kernels",
"get_hip_file_path",
"hip_header_magic",
"hipify",
"is_caffe2_gpu_file",
"is_cusparse_file",
"is_out_of_place",
"is_pytorch_file",
"is_special_file",
"match_extensions",
"matched_files_iter",
"openf",
"preprocess_file_and_save_result",
"preprocessor",
"processKernelLaunches",
"replace_extern_shared",
"replace_math_functions",
"str2bool",
# torch.utils.hooks
"unserializable_hook",
"warn_if_has_hooks",

View File

@ -12,6 +12,37 @@ These APIs are experimental and subject to change without notice.
.. autoclass:: torch.fx.experimental.sym_node.DynamicInt
```
## torch.fx.experimental.sym_node
```{eval-rst}
.. currentmodule:: torch.fx.experimental.sym_node
```
```{eval-rst}
.. automodule:: torch.fx.experimental.sym_node
```
```{eval-rst}
.. autosummary::
:toctree: generated
:nosignatures:
is_channels_last_contiguous_2d
is_channels_last_contiguous_3d
is_channels_last_strides_2d
is_channels_last_strides_3d
is_contiguous
is_non_overlapping_and_dense_indicator
method_to_operator
sympy_is_channels_last_contiguous_2d
sympy_is_channels_last_contiguous_3d
sympy_is_channels_last_strides_2d
sympy_is_channels_last_strides_3d
sympy_is_channels_last_strides_generic
sympy_is_contiguous
sympy_is_contiguous_generic
```
## torch.fx.experimental.symbolic_shapes
```{eval-rst}
@ -69,6 +100,25 @@ These APIs are experimental and subject to change without notice.
rebind_unbacked
resolve_unbacked_bindings
is_accessor_node
cast_symbool_to_symint_guardless
create_contiguous
error
eval_guards
eval_is_non_overlapping_and_dense
find_symbol_binding_fx_nodes
free_symbols
free_unbacked_symbols
fx_placeholder_targets
fx_placeholder_vals
guard_bool
guard_float
guard_int
guard_scalar
has_hint
has_symbolic_sizes_strides
is_nested_int
is_symbol_binding_fx_node
is_symbolic
```
## torch.fx.experimental.proxy_tensor
@ -91,4 +141,46 @@ These APIs are experimental and subject to change without notice.
get_proxy_mode
maybe_enable_thunkify
maybe_disable_thunkify
decompose
disable_autocast_cache
disable_proxy_modes_tracing
extract_val
fake_signature
fetch_object_proxy
fetch_sym_proxy
has_proxy_slot
is_sym_node
maybe_handle_decomp
proxy_call
set_meta
set_original_aten_op
set_proxy_slot
snapshot_fake
```
## torch.fx.experimental.unification.unification_tools
```{eval-rst}
.. currentmodule:: torch.fx.experimental.unification.unification_tools
```
```{eval-rst}
.. automodule:: torch.fx.experimental.unification.unification_tools
```
```{eval-rst}
.. autosummary::
:toctree: generated
:nosignatures:
assoc
assoc_in
dissoc
first
keyfilter
keymap
merge
merge_with
update_in
valfilter
valmap

View File

@ -1134,7 +1134,6 @@ The set of leaf modules can be customized by overriding
.. py:module:: torch.fx.experimental.refinement_types
.. py:module:: torch.fx.experimental.rewriter
.. py:module:: torch.fx.experimental.schema_type_annotation
.. py:module:: torch.fx.experimental.sym_node
.. py:module:: torch.fx.experimental.unification.core
.. py:module:: torch.fx.experimental.unification.dispatch
.. py:module:: torch.fx.experimental.unification.match
@ -1144,7 +1143,6 @@ The set of leaf modules can be customized by overriding
.. py:module:: torch.fx.experimental.unification.multipledispatch.dispatcher
.. py:module:: torch.fx.experimental.unification.multipledispatch.utils
.. py:module:: torch.fx.experimental.unification.multipledispatch.variadic
.. py:module:: torch.fx.experimental.unification.unification_tools
.. py:module:: torch.fx.experimental.unification.utils
.. py:module:: torch.fx.experimental.unification.variable
.. py:module:: torch.fx.experimental.unify_refinements

View File

@ -134,6 +134,23 @@ Quantization to work with this as well.
ObservationType
```
## torch.ao.quantization.backend_config.utils
```{eval-rst}
.. currentmodule:: torch.ao.quantization.backend_config.utils
```
```{eval-rst}
.. autosummary::
:toctree: generated
:nosignatures:
:template: classtemplate.rst
entry_to_pretty_str
pattern_to_human_readable
remove_boolean_dispatch_from_name
```
## torch.ao.quantization.fx.custom_config
This module contains a few CustomConfig classes that's used in both eager mode and FX graph mode quantization
@ -154,6 +171,30 @@ This module contains a few CustomConfig classes that's used in both eager mode a
StandaloneModuleConfigEntry
```
## torch.ao.quantization.fx.utils
```{eval-rst}
.. currentmodule:: torch.ao.quantization.fx.utils
```
```{eval-rst}
.. autosummary::
:toctree: generated
:nosignatures:
:template: classtemplate.rst
all_node_args_except_first
all_node_args_have_no_tensors
collect_producer_nodes
create_getattr_from_value
create_node_from_old_node_preserve_meta
graph_module_from_producer_nodes
maybe_get_next_module
node_arg_is_bias
node_arg_is_weight
return_arg_list
```
## torch.ao.quantization.quantizer
```{eval-rst}

View File

@ -19,6 +19,91 @@
swap_tensors
```
# torch.utils.collect_env
```{eval-rst}
.. automodule:: torch.utils.collect_env
```
```{eval-rst}
.. currentmodule:: torch.utils.collect_env
```
```{eval-rst}
.. autosummary::
:toctree: generated
:nosignatures:
check_release_file
is_xnnpack_available
pretty_str
```
# torch.utils.flop_counter
```{eval-rst}
.. automodule:: torch.utils.flop_counter
```
```{eval-rst}
.. currentmodule:: torch.utils.flop_counter
```
```{eval-rst}
.. autosummary::
:toctree: generated
:nosignatures:
baddbmm_flop
bmm_flop
conv_backward_flop
conv_flop
conv_flop_count
register_flop_formula
sdpa_backward_flop
sdpa_backward_flop_count
sdpa_flop
sdpa_flop_count
shape_wrapper
```
# torch.utils.hipify.hipify_python
```{eval-rst}
.. automodule:: torch.utils.hipify.hipify_python
```
```{eval-rst}
.. currentmodule:: torch.utils.hipify.hipify_python
```
```{eval-rst}
.. autosummary::
:toctree: generated
:nosignatures:
compute_stats
extract_arguments
file_add_header
file_specific_replacement
find_bracket_group
find_closure_group
find_parentheses_group
fix_static_global_kernels
hip_header_magic
hipify
is_caffe2_gpu_file
is_cusparse_file
is_out_of_place
is_pytorch_file
is_special_file
openf
preprocess_file_and_save_result
preprocessor
processKernelLaunches
replace_extern_shared
replace_math_functions
str2bool
```
<!-- This module needs to be documented. Adding here in the meantime
for tracking purposes -->
```{eval-rst}
@ -43,7 +128,6 @@ for tracking purposes -->
.. py:module:: torch.utils.benchmark.utils.valgrind_wrapper.timer_interface
.. py:module:: torch.utils.bundled_inputs
.. py:module:: torch.utils.checkpoint
.. py:module:: torch.utils.collect_env
.. py:module:: torch.utils.cpp_backtrace
.. py:module:: torch.utils.cpp_extension
.. py:module:: torch.utils.data.backward_compatibility
@ -80,10 +164,8 @@ for tracking purposes -->
.. py:module:: torch.utils.data.sampler
.. py:module:: torch.utils.dlpack
.. py:module:: torch.utils.file_baton
.. py:module:: torch.utils.flop_counter
.. py:module:: torch.utils.hipify.constants
.. py:module:: torch.utils.hipify.cuda_to_hip_mappings
.. py:module:: torch.utils.hipify.hipify_python
.. py:module:: torch.utils.hipify.version
.. py:module:: torch.utils.hooks
.. py:module:: torch.utils.jit.log_extract

View File

@ -184,7 +184,6 @@ ignore = [
"TC006",
# TODO: Remove Python-3.10 specific suppressions
"B905",
"UP035",
]
select = [
"B",

View File

@ -630,6 +630,37 @@ def mirror_files_into_torchgen() -> None:
raise RuntimeError("Check the file paths in `mirror_files_into_torchgen()`")
def mirror_inductor_external_kernels() -> None:
"""
Copy external kernels into Inductor so they are importable.
"""
paths = [
(
CWD / "torch/_inductor/kernel/vendored_templates/cutedsl_grouped_gemm.py",
CWD
/ "third_party/cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py",
),
]
for new_path, orig_path in paths:
# Create the dirs involved in new_path if they don't exist
if not new_path.exists():
new_path.parent.mkdir(parents=True, exist_ok=True)
# Copy the files from the orig location to the new location
if orig_path.is_file():
shutil.copyfile(orig_path, new_path)
continue
if orig_path.is_dir():
if new_path.exists():
# copytree fails if the tree exists already, so remove it.
shutil.rmtree(new_path)
shutil.copytree(orig_path, new_path)
continue
raise RuntimeError(
"Check the file paths in `mirror_inductor_external_kernels()`"
)
# ATTENTION: THIS IS AI SLOP
def extract_variant_from_version(version: str) -> str:
"""Extract variant from version string, defaulting to 'cpu'."""
@ -1615,6 +1646,7 @@ def main() -> None:
mirror_files_into_torchgen()
if RUN_BUILD_DEPS:
build_deps()
mirror_inductor_external_kernels()
(
ext_modules,
@ -1649,6 +1681,7 @@ def main() -> None:
"_inductor/codegen/aoti_runtime/*.cpp",
"_inductor/script.ld",
"_inductor/kernel/flex/templates/*.jinja",
"_inductor/kernel/templates/*.jinja",
"_export/serde/*.yaml",
"_export/serde/*.thrift",
"share/cmake/ATen/*.cmake",

View File

@ -1,6 +1,5 @@
# Owner(s): ["module: unknown"]
import os
import tempfile
from backend import get_custom_backend_library_path, Model, to_custom_backend
@ -41,14 +40,11 @@ class TestCustomBackend(TestCase):
self.test_execute()
# Save and load.
f = tempfile.NamedTemporaryFile(delete=False)
try:
with tempfile.NamedTemporaryFile() as f:
f.close()
torch.jit.save(self.model, f.name)
loaded = torch.jit.load(f.name)
finally:
os.unlink(f.name)
self.model = loaded
self.model = loaded
# Test execution again.
self.test_execute()

View File

@ -1,6 +1,5 @@
# Owner(s): ["module: unknown"]
import os.path
import sys
import tempfile
import unittest
@ -144,16 +143,13 @@ def forward(self, arg0_1):
# Ideally we would like to not have to manually delete the file, but NamedTemporaryFile
# opens the file, and it cannot be opened multiple times in Windows. To support Windows,
# close the file after creation and try to remove it manually.
file = tempfile.NamedTemporaryFile(delete=False)
try:
with tempfile.NamedTemporaryFile() as file:
file.close()
model.save(file.name)
loaded = torch.jit.load(file.name)
finally:
os.unlink(file.name)
output = loaded.forward(torch.ones(5))
self.assertTrue(output.allclose(torch.ones(5) + 1))
output = loaded.forward(torch.ones(5))
self.assertTrue(output.allclose(torch.ones(5) + 1))
if __name__ == "__main__":

View File

@ -204,14 +204,16 @@ class DistConvolutionOpsTest(DTensorTestBase):
self.assertTrue(b_dt.grad is not None)
self.assertTrue(x_dt.grad is None)
def _run_single_arg_fwd(self, model, arg) -> tuple[torch.Tensor, torch.Tensor]:
def _run_single_arg_fwd(
self, model, arg, placements=None
) -> tuple[torch.Tensor, torch.Tensor]:
"""Given model and arg, runs fwd model local and distbuted given device_mesh"""
device_mesh = self.build_device_mesh()
model_copy = copy.deepcopy(model).to(device=self.device_type)
dist_model = distribute_module(model, device_mesh, _conv_fn)
arg_dt = DTensor.from_local(arg, device_mesh, [Replicate()])
arg_dt = DTensor.from_local(arg, device_mesh, placements)
out_dt = dist_model(arg_dt.to(device=self.device_type))
out = model_copy(arg)
out = model_copy(arg_dt.full_tensor())
return (out_dt.full_tensor(), out)
@with_comms
@ -219,22 +221,20 @@ class DistConvolutionOpsTest(DTensorTestBase):
model = nn.Conv1d(64, 64, 3, padding=1)
x = torch.randn(1, 64, 8, device=self.device_type)
out_dt, out = self._run_single_arg_fwd(model, x)
self.assertEqual(out_dt.shape, out.shape)
self.assertEqual(out_dt, out)
@with_comms
def test_conv3d(self):
model = nn.Conv3d(64, 64, 3, padding=1)
x = torch.randn(1, 64, 8, 8, 8, device=self.device_type)
out_dt, out = self._run_single_arg_fwd(model, x)
self.assertEqual(out_dt.shape, out.shape)
out_dt, out = self._run_single_arg_fwd(model, x, [Shard(0)])
self.assertEqual(out_dt, out)
DistConvolutionOpsTestWithLocalTensor = create_local_tensor_test_class(
DistConvolutionOpsTest,
# Send / recv ops are not supported
skipped_tests=[
"test_conv1d",
"test_conv3d",
"test_conv_backward_none_grad_inp",
"test_depthwise_convolution",
"test_downsampling_convolution",

View File

@ -2,7 +2,6 @@
# Owner(s): ["oncall: distributed"]
import contextlib
import copy
import itertools
import unittest
@ -22,9 +21,8 @@ from torch.distributed.tensor import (
)
from torch.distributed.tensor._collective_utils import shard_dim_alltoall
from torch.distributed.tensor._dtensor_spec import ShardOrderEntry
from torch.distributed.tensor._redistribute import redistribute_local_tensor
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.placement_types import _StridedShard
from torch.distributed.tensor.placement_types import _StridedShard, MaskPartial
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
@ -35,7 +33,11 @@ from torch.testing._internal.common_utils import (
from torch.testing._internal.distributed._tensor.common_dtensor import (
create_local_tensor_test_class,
DTensorTestBase,
generate_shard_orders,
make_full_tensor,
map_local_tensor_for_rank,
patched_distribute_tensor as _distribute_tensor,
redistribute,
with_comms,
)
from torch.utils._debug_mode import DebugMode
@ -785,88 +787,6 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
else:
return ""
# TODO(zpcore): remove once the native redistribute supports shard_order arg
def redistribute(
self,
dtensor_input,
device_mesh,
placements,
shard_order,
use_graph_based_transform=True,
):
"""
wrapper function to support shard_order for redistribution
This is a simpler version of Redistribute, only considers the forward.
"""
if placements is None:
placements = self._shard_order_to_placement(shard_order, device_mesh)
placements = tuple(placements)
old_spec = dtensor_input._spec
new_spec = copy.deepcopy(old_spec)
new_spec.placements = placements
if shard_order is not None:
new_spec.shard_order = shard_order
else:
new_spec.shard_order = ()
if old_spec == new_spec:
return dtensor_input
dtensor_input = DTensor.from_local(
redistribute_local_tensor(
dtensor_input.to_local(),
old_spec,
new_spec,
use_graph_based_transform=use_graph_based_transform,
),
device_mesh,
)
dtensor_input._spec = copy.deepcopy(new_spec)
return dtensor_input # returns DTensor
# TODO(zpcore): remove once the native distribute_tensor supports
# shard_order arg
def distribute_tensor(
self,
input_tensor,
device_mesh,
placements,
shard_order,
use_graph_based_transform=True,
):
"""wrapper function to support shard_order for tensor distribution"""
if placements is None:
placements = self._shard_order_to_placement(shard_order, device_mesh)
placements = tuple(placements)
tensor_dt = distribute_tensor(input_tensor, device_mesh, placements)
# fix the shard order
return self.redistribute(
tensor_dt, device_mesh, placements, shard_order, use_graph_based_transform
)
# TODO(zpcore): remove once the native redistribute supports shard_order arg
def full_tensor(self, dtensor_input):
"""wrapper function to support DTensor.full_tensor"""
return self.redistribute(
dtensor_input, dtensor_input.device_mesh, placements=None, shard_order=()
).to_local()
def _shard_order_to_placement(self, shard_order, mesh):
"""convert shard_order to placement with only Replicate() and Shard()"""
placements = [Replicate() for _ in range(mesh.ndim)]
if shard_order is not None:
for entry in shard_order:
tensor_dim = entry.tensor_dim
mesh_dims = entry.mesh_dims
for mesh_dim in mesh_dims:
placements[mesh_dim] = Shard(tensor_dim)
return tuple(placements)
def _convert_shard_order_dict_to_ShardOrder(self, shard_order):
"""Convert shard_order dict to ShardOrder"""
return tuple(
ShardOrderEntry(tensor_dim=tensor_dim, mesh_dims=tuple(mesh_dims))
for tensor_dim, mesh_dims in shard_order.items()
)
@with_comms
def test_ordered_redistribute(self):
"""Test ordered redistribution with various sharding syntaxes"""
@ -927,13 +847,11 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
for idx, ((src_placement, src_order), (dst_placement, dst_order)) in enumerate(
sharding_src_dst_pairs_with_expected_trace
):
sharded_dt = self.distribute_tensor(
sharded_dt = _distribute_tensor(
input_data.clone(), mesh, src_placement, shard_order=src_order
)
with DebugMode(record_torchfunction=False) as debug_mode:
sharded_dt = self.redistribute(
sharded_dt, mesh, dst_placement, dst_order
)
sharded_dt = redistribute(sharded_dt, mesh, dst_placement, dst_order)
trace_str = self._extract_redistribute_trace_from_debug_mode(
debug_mode.debug_string()
)
@ -957,49 +875,11 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
trace_str,
"""S(0)[0]S(0)[1]R->S(0)S(1)R->RS(1)R->RS(1)S(0)""",
)
expected_dt = self.distribute_tensor(
expected_dt = _distribute_tensor(
input_data.clone(), mesh, dst_placement, shard_order=dst_order
)
self.assertEqual(sharded_dt.to_local(), expected_dt.to_local())
def generate_shard_orders(self, mesh, tensor_rank):
# Generate all possible sharding placement of tensor with rank
# `tensor_rank` over mesh.
def _split_list(lst: list, N: int):
def compositions(n, k):
if k == 1:
yield [n]
else:
for i in range(1, n - k + 2):
for tail in compositions(n - i, k - 1):
yield [i] + tail
length = len(lst)
for comp in compositions(length, N):
result = []
start = 0
for size in comp:
result.append(lst[start : start + size])
start += size
yield result
all_mesh = list(range(mesh.ndim))
all_device_order = list(itertools.permutations(all_mesh))
for device_order in all_device_order:
# split on device orders, and assign each device order segment to a tensor dim
for num_split in range(1, mesh.ndim + 1):
for splitted_list in _split_list(list(range(mesh.ndim)), num_split):
for tensor_dims in itertools.combinations(
range(tensor_rank), len(splitted_list)
):
shard_order = {}
assert len(tensor_dims) == len(splitted_list)
for tensor_dim, mesh_dims in zip(tensor_dims, splitted_list):
shard_order[tensor_dim] = device_order[
mesh_dims[0] : mesh_dims[-1] + 1
]
yield self._convert_shard_order_dict_to_ShardOrder(shard_order)
@with_comms
def test_generate_shard_orders(self):
"""Check if `generate_shard_orders` generates unique sharding combinations"""
@ -1012,7 +892,7 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
]
for test_input in test_inputs:
all_combinations = []
for shard_order in self.generate_shard_orders(
for shard_order in generate_shard_orders(
test_input["mesh"], test_input["tensor_rank"]
):
all_combinations.append(shard_order) # noqa: PERF402
@ -1062,12 +942,12 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
input_data = torch.randn(tensor_shape, device=self.device_type)
tensor_rank = input_data.ndim
with maybe_disable_local_tensor_mode():
shard_orders = self.generate_shard_orders(mesh, tensor_rank)
shard_orders = generate_shard_orders(mesh, tensor_rank)
for shard_order in shard_orders:
sharded_dt = self.distribute_tensor(
sharded_dt = _distribute_tensor(
input_data.clone(), mesh, placements=None, shard_order=shard_order
)
self.assertEqual(self.full_tensor(sharded_dt), input_data)
self.assertEqual(make_full_tensor(sharded_dt), input_data)
# 2. Verify the correctness of redistribution from DTensor to DTensor.
# This test repeatedly redistributes a DTensor to various ordered
@ -1078,20 +958,20 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
tensor_rank = input_data.ndim
prev_sharded_dt = None
with maybe_disable_local_tensor_mode():
shard_orders = self.generate_shard_orders(mesh, tensor_rank)
shard_orders = generate_shard_orders(mesh, tensor_rank)
for shard_order in shard_orders:
if prev_sharded_dt is None:
prev_sharded_dt = self.distribute_tensor(
prev_sharded_dt = _distribute_tensor(
input_data.clone(),
mesh,
placements=None,
shard_order=shard_order,
)
else:
sharded_dt = self.redistribute(
sharded_dt = redistribute(
prev_sharded_dt, mesh, placements=None, shard_order=shard_order
)
self.assertEqual(self.full_tensor(sharded_dt), input_data)
self.assertEqual(make_full_tensor(sharded_dt), input_data)
prev_sharded_dt = sharded_dt
@with_comms
@ -1136,13 +1016,13 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
local_tensor = torch.randn(shape, device=self.device_type)
full_tensor = DTensor.from_local(local_tensor, mesh, placements)
with maybe_disable_local_tensor_mode():
shard_orders = self.generate_shard_orders(mesh, len(shape))
shard_orders = generate_shard_orders(mesh, len(shape))
for shard_order in shard_orders:
sharded_dt = self.redistribute(
sharded_dt = redistribute(
full_tensor, mesh, placements=None, shard_order=shard_order
)
self.assertEqual(
self.full_tensor(sharded_dt), self.full_tensor(full_tensor)
make_full_tensor(sharded_dt), make_full_tensor(full_tensor)
)
@unittest.skip(
@ -1152,24 +1032,20 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
@with_comms
def test_ordered_redistribute_for_special_placement(self):
"""Test ordered redistribution with special placement"""
from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
torch.manual_seed(21)
mesh = init_device_mesh(self.device_type, (8,))
input_data = torch.randn((8, 8), device=self.device_type)
src_placement = [Shard(1)]
tgt_placement = [
(_MaskPartial(offset_shape=torch.Size([10, 20]), offset_dim=0),)
(MaskPartial(offset_shape=torch.Size([10, 20]), offset_dim=0),)
]
sharded_dt = self.distribute_tensor(
sharded_dt = _distribute_tensor(
input_data.clone(),
mesh,
src_placement,
shard_order=(ShardOrderEntry(tensor_dim=1, mesh_dims=(0,)),),
)
sharded_dt = self.redistribute(
sharded_dt, mesh, tgt_placement, shard_order=None
)
sharded_dt = redistribute(sharded_dt, mesh, tgt_placement, shard_order=None)
@with_comms
def test_shard_order_same_data_as_strided_shard(self):
@ -1179,7 +1055,7 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
strided_placement = [_StridedShard(-2, split_factor=2), Shard(-2)]
x_strided_dt = distribute_tensor(x, device_mesh, strided_placement)
# specify right-to-left order using ordered shard
x_ordered_dt = self.distribute_tensor(
x_ordered_dt = _distribute_tensor(
x,
device_mesh,
placements=[Shard(0), Shard(0)],

View File

@ -34,6 +34,10 @@ from torch.distributed.tensor.placement_types import (
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
generate_shard_orders,
LocalDTensorTestBase,
patched_distribute_tensor as _distribute_tensor,
shard_order_to_placement,
with_comms,
)
@ -774,6 +778,63 @@ class TestStridedSharding(DTensorTestBase):
self.assertEqual(dtensor.full_tensor(), tensor)
class Test_StridedShard_with_shard_order(LocalDTensorTestBase):
@property
def world_size(self) -> int:
return 32
@with_comms
def test_StridedShard_to_shard_order(self):
with LocalTensorMode(ranks=self.world_size):
mesh = DeviceMesh("cpu", torch.arange(self.world_size).view(2, 2, 2, 2, 2))
shard_iter = generate_shard_orders(mesh, 3)
# It takes ~4.8h to run all 2520 shard order combinations here
# using LocalTensor, so we randomly pick only 25 shard orders to test.
all_shard_order = list(shard_iter)
import random
random.seed(42)
shard_order_choices = random.sample(
all_shard_order, min(25, len(all_shard_order))
)
x = torch.randn(32, 32, 32)
for shard_order in shard_order_choices:
a = _distribute_tensor(x, mesh, None, shard_order)
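# Round-trip check: shard_order -> plain Shard placements -> StridedShard
# placements -> shard_order again; both distributions should yield identical local shards.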
placement_without_stridedshard = shard_order_to_placement(
shard_order, mesh
)
placements_with_stridedshard = (
DTensorSpec._convert_shard_order_to_StridedShard(
shard_order, placement_without_stridedshard, mesh
)
)
b = distribute_tensor(x, mesh, placements_with_stridedshard)
shard_order_from_stridedshard = (
DTensorSpec._maybe_convert_StridedShard_to_shard_order(
placements_with_stridedshard, mesh
)
)
self.assertEqual(shard_order, shard_order_from_stridedshard)
self.assertEqual(a.to_local(), b.to_local())
@with_comms
def test_StridedShard_not_convertible_to_shard_order(self):
with LocalTensorMode(ranks=self.world_size):
mesh = DeviceMesh("cpu", torch.arange(self.world_size).view(4, 8))
unconvertible_placements_list = [
[_StridedShard(0, split_factor=2), _StridedShard(1, split_factor=2)],
[_StridedShard(0, split_factor=2), Shard(1)],
[_StridedShard(1, split_factor=16), Shard(1)],
]
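# None of these placements corresponds to a valid shard order, so the
# conversion helper is expected to return None.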
for placements in unconvertible_placements_list:
shard_order = DTensorSpec._maybe_convert_StridedShard_to_shard_order(
tuple(placements), mesh
)
self.assertIsNone(shard_order)
class Test2DStridedLocalShard(DTensorTestBase):
@property
def world_size(self):

View File

@ -861,7 +861,7 @@ TRACE FX call mul from test_logging.py:N in fn (LoggingTests.test_trace_call_pre
def test_logs_out(self):
import tempfile
with tempfile.NamedTemporaryFile(delete=False) as tmp:
with tempfile.NamedTemporaryFile(delete=True) as tmp:
file_path = _as_posix_path(tmp.name)
"""
NamedTemporaryFile will include a file open operation.
@ -888,10 +888,6 @@ fn(torch.randn(5))
file_path, encoding="utf-8"
) as fd: # encoding file to UTF-8 for Windows.
lines = fd.read()
fd.close()
os.remove(
file_path
) # Delete temp file manually, due to setup NamedTemporaryFile as delete=False.
orig_maxDiff = unittest.TestCase.maxDiff
unittest.TestCase.maxDiff = None
try:

View File

@ -2,7 +2,6 @@
import copy
import pathlib
import tempfile
import unittest
@ -97,55 +96,55 @@ def run_with_nativert(ep):
MODEL_NAME = "forward"
# TODO: Can NamedTemporaryFile names collide?
with tempfile.NamedTemporaryFile(suffix=".pt2", delete=False) as f:
with tempfile.NamedTemporaryFile(suffix=".pt2") as f:
torch.export.pt2_archive._package.package_pt2(
f, exported_programs={MODEL_NAME: ep_infer}
)
filename = f.name
try:
ep_args, ep_kwargs = ep_infer.example_inputs
ep_args_copied, ep_kwargs_copied = (
copy.deepcopy(ep_args),
copy.deepcopy(ep_kwargs),
)
torch.manual_seed(0)
try:
flat_expected = pytree.tree_leaves(
ep_infer.module()(*ep_args_copied, **ep_kwargs_copied)
ep_args, ep_kwargs = ep_infer.example_inputs
ep_args_copied, ep_kwargs_copied = (
copy.deepcopy(ep_args),
copy.deepcopy(ep_kwargs),
)
except Exception as e:
raise unittest.case.SkipTest(str(e)) from e
torch.manual_seed(0)
try:
flat_expected = pytree.tree_leaves(
ep_infer.module()(*ep_args_copied, **ep_kwargs_copied)
)
except Exception as e:
raise unittest.case.SkipTest(str(e)) from e
model_runner = PyModelRunner(filename, MODEL_NAME)
torch.manual_seed(0)
if _is_supported_types((ep_args, ep_kwargs)):
results = model_runner.run(*ep_args, **ep_kwargs)
else:
results = model_runner.run_with_flat_inputs_and_outputs(
*pytree.tree_leaves((ep_args, ep_kwargs))
)
flat_results = pytree.tree_leaves(results)
assert len(flat_results) == len(flat_expected)
for result, expected in zip(flat_results, flat_expected):
assert type(result) is type(expected)
if isinstance(result, torch.Tensor) and isinstance(expected, torch.Tensor):
assert result.shape == expected.shape
assert result.dtype == expected.dtype
assert result.device == expected.device
torch.testing.assert_close(result, expected, equal_nan=True)
model_runner = PyModelRunner(filename, MODEL_NAME)
torch.manual_seed(0)
if _is_supported_types((ep_args, ep_kwargs)):
results = model_runner.run(*ep_args, **ep_kwargs)
else:
assert result == expected
except RuntimeError as e:
# Users need to register the pytree type on the C++ side, which
# cannot be tested in a Python unittest.
if "Unknown pytree node type" in str(e):
pass
else:
raise e
finally:
pathlib.Path(filename).unlink(missing_ok=True)
return ep
results = model_runner.run_with_flat_inputs_and_outputs(
*pytree.tree_leaves((ep_args, ep_kwargs))
)
flat_results = pytree.tree_leaves(results)
assert len(flat_results) == len(flat_expected)
for result, expected in zip(flat_results, flat_expected):
assert type(result) is type(expected)
if isinstance(result, torch.Tensor) and isinstance(
expected, torch.Tensor
):
assert result.shape == expected.shape
assert result.dtype == expected.dtype
assert result.device == expected.device
torch.testing.assert_close(result, expected, equal_nan=True)
else:
assert result == expected
except RuntimeError as e:
# Users need to register the pytree type on the C++ side, which
# cannot be tested in a Python unittest.
if "Unknown pytree node type" in str(e):
pass
else:
raise e
return ep
def mocked_nativert_export_strict(*args, **kwargs):
@ -287,7 +286,7 @@ class TestNativeRT(TestCase):
)
# package everything needed for the NativeRT to execute the AOTI delegate
with tempfile.NamedTemporaryFile(suffix=".pt2", delete=False) as f:
with tempfile.NamedTemporaryFile(suffix=".pt2") as f:
package_nativert_with_aoti_delegate(
f,
MODEL_NAME,
@ -298,50 +297,48 @@ class TestNativeRT(TestCase):
)
filename = f.name
try:
ep_args, ep_kwargs = aoti_delegate_ep.example_inputs
ep_args_copied, ep_kwargs_copied = (
copy.deepcopy(ep_args),
copy.deepcopy(ep_kwargs),
)
torch.manual_seed(0)
try:
flat_expected = pytree.tree_leaves(
aoti_delegate_ep.module()(*ep_args_copied, **ep_kwargs_copied)
ep_args, ep_kwargs = aoti_delegate_ep.example_inputs
ep_args_copied, ep_kwargs_copied = (
copy.deepcopy(ep_args),
copy.deepcopy(ep_kwargs),
)
except Exception as e:
raise unittest.case.SkipTest(str(e)) from e
torch.manual_seed(0)
try:
flat_expected = pytree.tree_leaves(
aoti_delegate_ep.module()(*ep_args_copied, **ep_kwargs_copied)
)
except Exception as e:
raise unittest.case.SkipTest(str(e)) from e
model_runner = PyModelRunner(filename, f"{MODEL_NAME}-{BACKEND_ID}")
torch.manual_seed(0)
if _is_supported_types((ep_args, ep_kwargs)):
results = model_runner.run(*ep_args, **ep_kwargs)
else:
results = model_runner.run_with_flat_inputs_and_outputs(
*pytree.tree_leaves((ep_args, ep_kwargs))
)
flat_results = pytree.tree_leaves(results)
assert len(flat_results) == len(flat_expected)
for result, expected in zip(flat_results, flat_expected):
assert type(result) is type(expected)
if isinstance(result, torch.Tensor) and isinstance(
expected, torch.Tensor
):
assert result.shape == expected.shape
assert result.dtype == expected.dtype
assert result.device == expected.device
torch.testing.assert_close(result, expected, equal_nan=True)
model_runner = PyModelRunner(filename, f"{MODEL_NAME}-{BACKEND_ID}")
torch.manual_seed(0)
if _is_supported_types((ep_args, ep_kwargs)):
results = model_runner.run(*ep_args, **ep_kwargs)
else:
assert result == expected
except RuntimeError as e:
# Users need to register the pytree type on the C++ side, which
# cannot be tested in a Python unittest.
if "Unknown pytree node type" in str(e):
pass
else:
raise e
finally:
pathlib.Path(filename).unlink(missing_ok=True)
results = model_runner.run_with_flat_inputs_and_outputs(
*pytree.tree_leaves((ep_args, ep_kwargs))
)
flat_results = pytree.tree_leaves(results)
assert len(flat_results) == len(flat_expected)
for result, expected in zip(flat_results, flat_expected):
assert type(result) is type(expected)
if isinstance(result, torch.Tensor) and isinstance(
expected, torch.Tensor
):
assert result.shape == expected.shape
assert result.dtype == expected.dtype
assert result.device == expected.device
torch.testing.assert_close(result, expected, equal_nan=True)
else:
assert result == expected
except RuntimeError as e:
# Users need to register the pytree type on the C++ side, which
# cannot be tested in a Python unittest.
if "Unknown pytree node type" in str(e):
pass
else:
raise e
if is_fbcode():

View File

@ -206,6 +206,10 @@ class TestPyCodeCache(TestCase):
.decode()
.strip()
)
# XPU has extra output lines, so take the last line; see https://github.com/intel/torch-xpu-ops/issues/2261
if torch.xpu.is_available():
wrapper_path = wrapper_path.splitlines()[-1]
hit = hit.splitlines()[-1]
self.assertEqual(hit, "1")
with open(wrapper_path) as f:

View File

@ -0,0 +1,154 @@
# Owner(s): ["module: inductor"]
import unittest
import torch
from torch import Tensor
from torch._inductor import config
from torch._inductor.codegen.cuda.cuda_env import is_datacenter_blackwell_arch
from torch._inductor.test_case import run_tests, TestCase as InductorTestCase
from torch._inductor.utils import ensure_cute_available
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
)
@unittest.skipIf(
not (ensure_cute_available() and is_datacenter_blackwell_arch()),
"CuTeDSL library or Blackwell device not available",
)
@instantiate_parametrized_tests
class TestCuTeDSLGroupedGemm(InductorTestCase):
def _get_inputs(
self,
group_size: int,
M_hint: int,
K: int,
N: int,
device: str,
dtype: torch.dtype,
alignment: int = 16,
) -> tuple[Tensor, Tensor, Tensor]:
# --- Random, tile-aligned M sizes ---
M_sizes = (
torch.randint(1, (M_hint // alignment) + 1, (group_size,), dtype=torch.int)
* alignment
)
M_total = torch.sum(M_sizes).item()
# --- Construct input tensors ---
A = torch.randn(int(M_total), K, dtype=dtype, device=device) * 0.1
B = torch.randn((group_size, K, N), dtype=dtype, device=device) * 0.01
# --- Build offsets (no leading zero, strictly increasing) ---
offsets = torch.cumsum(M_sizes, dim=0).to(dtype=torch.int32, device=device)
return (A, B, offsets)
@parametrize("group_size", (2, 8))
@parametrize("M_hint", (256, 1024))
@parametrize("K", (64, 128))
@parametrize("N", (128, 256))
def test_grouped_gemm_basic(self, group_size: int, M_hint: int, K: int, N: int):
device = "cuda"
dtype = torch.bfloat16
A, B, offsets = self._get_inputs(group_size, M_hint, K, N, device, dtype)
def grouped_gemm_fn(A_packed, B_batched, offs):
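# offs holds the inclusive end row of each group within the packed A tensor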
return torch._grouped_mm(A_packed, B_batched, offs=offs)
# Eager execution
c_eager = grouped_gemm_fn(A, B, offsets)
# Test with Cute backend
with config.patch(
{
"max_autotune": True,
"max_autotune_gemm_backends": "CUTEDSL",
"test_configs.autotune_choice_name_regex": "cutedsl",
"autotune_fallback_to_aten": False,
}
):
grouped_gemm_compiled = torch.compile(
grouped_gemm_fn, backend="inductor", dynamic=False
)
c_compiled = grouped_gemm_compiled(A, B, offsets)
self.assertEqual(c_eager.dtype, dtype)
self.assertEqual(c_compiled.dtype, dtype)
torch.testing.assert_close(c_eager, c_compiled)
@parametrize("layout_A", ("contiguous", "offset", "padded", "view"))
@parametrize("layout_B", ("contiguous", "broadcasted"))
def test_grouped_gemm_assorted_layouts(
self,
layout_A: str,
layout_B: str,
):
device = "cuda"
dtype = torch.bfloat16
G, K, N = 8, 64, 128
M_sizes = [128] * G
sum_M = sum(M_sizes)
offsets = torch.tensor(
[sum(M_sizes[: i + 1]) for i in range(G)], dtype=torch.int32, device=device
)
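# inclusive prefix sums of M_sizes: offsets[i] marks where group i ends in A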
A_base = torch.randn(sum_M, K, device=device, dtype=dtype)
A = A_base
if layout_A == "offset":
# allocate a bigger buffer than needed and use a nonzero storage offset
storage = torch.randn(sum_M * K + 512, device=device, dtype=dtype)
offset = 128 # skip first 128 elements
A = torch.as_strided(storage[offset:], (sum_M, K), (K, 1))
elif layout_A == "padded":
# simulate row pitch > K (row_stride = K + pad)
row_pitch = K + 8
storage = torch.randn(sum_M * row_pitch, device=device, dtype=dtype)
A = torch.as_strided(storage, (sum_M, K), (row_pitch, 1))
elif layout_A == "view":
A_storage = torch.randn(sum_M * K, device=device, dtype=dtype)
A = A_storage.view(sum_M, K)
assert A._base is not None
assert A.shape == (sum_M, K)
B = torch.randn((G, K, N), dtype=dtype, device=device) * 0.01
if layout_B == "broadcasted":
# Broadcast B across groups (zero stride along G)
B = B[0].expand(G, K, N)
assert B.stride(0) == 0
def grouped_gemm_fn(A_packed, B_batched, offs):
return torch._grouped_mm(A_packed, B_batched, offs=offs)
# --- eager ---
c_eager = grouped_gemm_fn(A, B, offsets)
# --- compiled (CUTE backend) ---
with config.patch(
{
"max_autotune": True,
"max_autotune_gemm_backends": "CUTEDSL",
"test_configs.autotune_choice_name_regex": "cutedsl",
"autotune_fallback_to_aten": False,
}
):
grouped_gemm_compiled = torch.compile(
grouped_gemm_fn, backend="inductor", dynamic=False
)
c_compiled = grouped_gemm_compiled(A, B, offsets)
self.assertEqual(c_eager.dtype, dtype)
self.assertEqual(c_compiled.dtype, dtype)
torch.testing.assert_close(c_eager, c_compiled)
if __name__ == "__main__":
run_tests()

View File

@ -3,7 +3,8 @@
import functools
import weakref
from collections import Counter
from typing import Callable, Optional
from collections.abc import Callable
from typing import Optional
import torch
from torch._inductor.fx_passes.memory_estimator import (
@ -28,7 +29,7 @@ def device_filter(device):
class FakeTensorMemoryProfilerMode(TorchDispatchMode):
def __init__(self, device_filter: Optional[Callable[torch.device, bool]] = None):
def __init__(self, device_filter: Optional[Callable[[torch.device], bool]] = None):
# maps storage id -> number of live references
self.storage_count: dict[int, int] = Counter()
# live fake tensors

View File

@ -482,8 +482,8 @@ class TestExecutionTrace(TestCase):
@unittest.skipIf(IS_WINDOWS, "torch.compile does not support WINDOWS")
@unittest.skipIf(
(not has_triton()) or (not TEST_CUDA and not TEST_XPU),
"need triton and device(CUDA or XPU) availability to run",
(not has_triton()) or (not TEST_CUDA),
"need triton and device CUDA availability to run",
)
@skipCPUIf(True, "skip CPU device for testing profiling triton")
def test_triton_fx_graph_with_et(self, device):

View File

@ -2005,6 +2005,10 @@ assert KinetoStepTracker.current_step() == initial_step + 2 * niters
report = json.load(f)
self._validate_basic_json(report["traceEvents"], with_cuda)
@unittest.skipIf(
torch.xpu.is_available(),
"XPU Trace event ends too late! Refer https://github.com/intel/torch-xpu-ops/issues/2263",
)
@unittest.skipIf(not kineto_available(), "Kineto is required")
@skipIfTorchDynamo("profiler gets ignored if dynamo activated")
def test_basic_chrome_trace(self):
@ -2158,7 +2162,10 @@ assert KinetoStepTracker.current_step() == initial_step + 2 * niters
@skipIfTorchDynamo("profiler gets ignored if dynamo activated")
def test_basic_profile(self):
# test a really basic profile to make sure no erroneous aten ops are run
x = torch.randn(4, device="cuda")
acc = torch.accelerator.current_accelerator()
self.assertIsNotNone(acc)
device = acc.type
x = torch.randn(4, device=device)
with torch.profiler.profile(with_stack=True) as p:
x *= 2
names = [e.name for e in p.events()]
@ -2225,6 +2232,7 @@ assert KinetoStepTracker.current_step() == initial_step + 2 * niters
@unittest.skipIf(
torch.cuda.is_available(), "CUDA complains about forking after init"
)
@unittest.skipIf(torch.xpu.is_available(), "XPU complains about forking after init")
@unittest.skipIf(IS_WINDOWS, "can't use os.fork() on Windows")
def test_forked_process(self):
# Induce a pid cache by running the profiler with payload

View File

@ -263,13 +263,7 @@ S390X_BLOCKLIST = [
XPU_BLOCKLIST = [
"test_autograd",
"profiler/test_cpp_thread",
"profiler/test_execution_trace",
"profiler/test_memory_profiler",
"profiler/test_profiler",
"profiler/test_profiler_tree",
"profiler/test_record_function",
"profiler/test_torch_tidy",
"test_openreg",
]

View File

@ -1,245 +1,236 @@
{
"EndToEndLSTM (__main__.RNNTest)": 207.89400227864584,
"MultiheadAttention (__main__.ModulesTest)": 141.1396687825521,
"test_AllenaiLongformerBase_repro_cpu_halide (__main__.HalideCpuTests)": 214.02366638183594,
"test__adaptive_avg_pool2d (__main__.CPUReproTests)": 77.26125049591064,
"test_adaptive_max_pool2d1_cpu_halide (__main__.HalideCpuTests)": 116.37000020345052,
"test_after_aot_cpu_runtime_error (__main__.MinifierIsolateTests)": 69.25722334120009,
"test_after_aot_gpu_runtime_error (__main__.MinifierIsolateTests)": 65.84466807047527,
"test_alexnet_prefix_cpu_halide (__main__.HalideCpuTests)": 178.41399637858072,
"test_aot_autograd_disable_functionalization_symbolic_exhaustive_linalg_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 63.55014337812151,
"test_aot_autograd_disable_functionalization_symbolic_exhaustive_nn_functional_max_pool1d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 122.18047623407273,
"test_aot_autograd_disable_functionalization_symbolic_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 192.6405719575428,
"test_aot_autograd_disable_functionalization_symbolic_exhaustive_nn_functional_max_pool3d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 111.27904801141648,
"test_aot_autograd_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 60.906999588012695,
"test_aot_autograd_symbolic_exhaustive_linalg_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 62.244998931884766,
"test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool1d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 150.04100036621094,
"test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 191.85050201416016,
"test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool3d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 111.9276631673177,
"test_aot_autograd_symbolic_exhaustive_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 67.31450271606445,
"test_aot_autograd_symbolic_module_exhaustive_nn_TransformerDecoderLayer_cpu_float32 (__main__.TestEagerFusionModuleInfoCPU)": 125.24066416422527,
"test_associative_scan_partial_grad_combine_mode_generic_compile_mode_compile_dynamic_shape_reverse_False_cpu (__main__.AssociativeScanTests)": 86.47783279418945,
"test_associative_scan_partial_grad_combine_mode_generic_compile_mode_compile_dynamic_shape_reverse_True_cpu (__main__.AssociativeScanTests)": 100.46250025431316,
"test_avg_pool3d_backward2_cpu (__main__.CpuTests)": 1031.0534973144531,
"test_avg_pool3d_backward2_cuda (__main__.GPUTests)": 239.67400105794272,
"test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 495.0447726779514,
"test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 490.18524169921875,
"test_avg_pool3d_backward2_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 144.06477737426758,
"test_avg_pool3d_backward2_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 342.20416259765625,
"test_avg_pool3d_backward_cpu_halide (__main__.HalideCpuTests)": 62.01366678873698,
"test_backward_nn_functional_multi_head_attention_forward_cpu_float32 (__main__.TestCompositeComplianceCPU)": 71.07200050354004,
"test_backward_nn_functional_multi_head_attention_forward_cuda_float32 (__main__.TestCompositeComplianceCUDA)": 73.9221674601237,
"test_basic_cpu (__main__.EfficientConvBNEvalCpuTests)": 226.0122528076172,
"test_basic_cuda (__main__.EfficientConvBNEvalGpuTests)": 144.97249857584634,
"test_checkpointing_without_reentrant_input_requires_grad_False (__main__.TestAutogradWithCompiledAutograd)": 303.20537185668945,
"test_checkpointing_without_reentrant_input_requires_grad_True (__main__.TestAutogradWithCompiledAutograd)": 386.0518798828125,
"test_collect_callgrind (__main__.TestBenchmarkUtils)": 291.2442270914714,
"test_comprehensive_diff_cuda_complex128 (__main__.TestDecompCUDA)": 95.87866719563802,
"test_comprehensive_diff_cuda_complex64 (__main__.TestDecompCUDA)": 98.38716634114583,
"test_comprehensive_diff_cuda_float32 (__main__.TestDecompCUDA)": 69.08016649881999,
"test_comprehensive_diff_cuda_float64 (__main__.TestDecompCUDA)": 69.88233311971028,
"test_comprehensive_grid_sampler_2d_cpu_bfloat16 (__main__.TestDecompCPU)": 104.17599995930989,
"test_comprehensive_grid_sampler_2d_cpu_float16 (__main__.TestDecompCPU)": 97.41800308227539,
"test_comprehensive_grid_sampler_2d_cpu_float32 (__main__.TestDecompCPU)": 474.6719970703125,
"test_comprehensive_grid_sampler_2d_cpu_float64 (__main__.TestDecompCPU)": 440.4375,
"test_comprehensive_grid_sampler_2d_cuda_bfloat16 (__main__.TestDecompCUDA)": 293.3983332316081,
"test_comprehensive_grid_sampler_2d_cuda_float16 (__main__.TestDecompCUDA)": 238.7328338623047,
"test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestDecompCUDA)": 1218.4906717936199,
"test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 68.73516782124837,
"test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestDecompCUDA)": 1156.0123494466145,
"test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 72.13916714986165,
"test_comprehensive_linalg_lu_solve_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 74.90450032552083,
"test_comprehensive_linalg_lu_solve_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 70.42100016276042,
"test_comprehensive_linalg_solve_triangular_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 72.98883310953777,
"test_comprehensive_linalg_solve_triangular_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 73.34433364868164,
"test_comprehensive_linalg_svd_cuda_complex128 (__main__.TestDecompCUDA)": 61.38016573588053,
"test_comprehensive_linalg_svd_cuda_complex64 (__main__.TestDecompCUDA)": 67.52783330281575,
"test_comprehensive_masked_norm_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 111.06333287556966,
"test_comprehensive_masked_norm_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 110.19833374023438,
"test_comprehensive_masked_norm_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 113.10083134969075,
"test_comprehensive_nn_functional_conv_transpose3d_cuda_complex128 (__main__.TestDecompCUDA)": 63.23766644795736,
"test_comprehensive_nn_functional_conv_transpose3d_cuda_complex64 (__main__.TestDecompCUDA)": 70.18666712443034,
"test_comprehensive_nn_functional_gaussian_nll_loss_cpu_float32 (__main__.TestDecompCPU)": 62.61399841308594,
"test_comprehensive_nn_functional_gaussian_nll_loss_cpu_float64 (__main__.TestDecompCPU)": 67.7816670735677,
"test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestDecompCUDA)": 121.6183344523112,
"test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float64 (__main__.TestDecompCUDA)": 107.30266698201497,
"test_comprehensive_nn_functional_grid_sample_cpu_float32 (__main__.TestDecompCPU)": 130.8143310546875,
"test_comprehensive_nn_functional_grid_sample_cpu_float64 (__main__.TestDecompCPU)": 127.27633412679036,
"test_comprehensive_nn_functional_grid_sample_cuda_float32 (__main__.TestDecompCUDA)": 303.55183664957684,
"test_comprehensive_nn_functional_grid_sample_cuda_float64 (__main__.TestDecompCUDA)": 234.41216532389322,
"test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 (__main__.TestDecompCUDA)": 85.3436673482259,
"test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 80.9688326517741,
"test_comprehensive_nn_functional_interpolate_bicubic_cuda_float64 (__main__.TestDecompCUDA)": 82.55149968465169,
"test_comprehensive_nn_functional_interpolate_bicubic_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 82.37966791788737,
"test_comprehensive_nn_functional_interpolate_trilinear_cuda_float32 (__main__.TestDecompCUDA)": 129.88233184814453,
"test_comprehensive_nn_functional_interpolate_trilinear_cuda_float64 (__main__.TestDecompCUDA)": 129.4015007019043,
"test_comprehensive_nn_functional_max_pool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 1282.3826497395833,
"test_comprehensive_nn_functional_max_pool2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 1270.64599609375,
"test_comprehensive_nn_functional_max_pool2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 1297.9046630859375,
"test_comprehensive_nn_functional_max_pool3d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 545.2034962972006,
"test_comprehensive_nn_functional_max_pool3d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 572.5616760253906,
"test_comprehensive_nn_functional_max_unpool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 64.40316645304362,
"test_comprehensive_nn_functional_max_unpool2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 64.68383344014485,
"test_comprehensive_nn_functional_max_unpool2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 61.48333422342936,
"test_comprehensive_ormqr_cpu_complex64 (__main__.TestDecompCPU)": 61.959999084472656,
"test_comprehensive_ormqr_cuda_complex128 (__main__.TestDecompCUDA)": 105.79100036621094,
"test_comprehensive_ormqr_cuda_complex64 (__main__.TestDecompCUDA)": 122.34666570027669,
"test_comprehensive_ormqr_cuda_float32 (__main__.TestDecompCUDA)": 68.7205015818278,
"test_comprehensive_ormqr_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 74.2183329264323,
"test_comprehensive_ormqr_cuda_float64 (__main__.TestDecompCUDA)": 66.86883227030437,
"test_comprehensive_svd_cuda_complex128 (__main__.TestDecompCUDA)": 77.48183314005534,
"test_comprehensive_svd_cuda_complex64 (__main__.TestDecompCUDA)": 79.1564998626709,
"test_constructor_autograd_SparseBSC_cuda (__main__.TestSparseAnyCUDA)": 160.41250228881836,
"test_constructor_autograd_SparseBSR_cuda (__main__.TestSparseAnyCUDA)": 79.10633341471355,
"test_constructor_autograd_SparseCSC_cuda (__main__.TestSparseAnyCUDA)": 60.106833140055336,
"test_conv1d_basic (__main__.TestXNNPACKConv1dTransformPass)": 221.3586196899414,
"test_conv1d_with_relu_fc (__main__.TestXNNPACKConv1dTransformPass)": 504.3203754425049,
"test_conv2d_binary_broadcast_shapes_cpu (__main__.TestPatternMatcherGenericCPU)": 78.03233337402344,
"test_conv3d_binary_broadcast_shapes_cpu (__main__.TestPatternMatcherGenericCPU)": 152.302001953125,
"test_conv3d_cuda (__main__.AOTInductorTestABICompatibleGpu)": 152.99433390299478,
"test_conv_bn_fuse_cpu (__main__.CpuTests)": 96.25399971008301,
"test_conv_bn_fuse_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 75.70275068283081,
"test_conv_transpose_with_output_size_and_no_batch_dim_ConvTranspose3d_cuda (__main__.TestConvolutionNNDeviceTypeCUDA)": 139.14399747674665,
"test_conv_unary_fusion_nnc (__main__.TestMkldnnFusion)": 72.7847490310669,
"test_correctness_AdamW_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 91.59966786702473,
"test_correctness_Adam_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 87.57833353678386,
"test_count_nonzero_all (__main__.TestBool)": 664.9986343383789,
"test_cp_flex_attention_document_mask (__main__.CPFlexAttentionTest)": 78.31500244140625,
"test_ddp_uneven_inputs (__main__.TestDistBackendWithSpawn)": 385.24249792099,
"test_dispatch_symbolic_meta_outplace_all_strides_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestMetaCUDA)": 84.70466740926106,
"test_dtensor_op_db_nn_functional_gaussian_nll_loss_cpu_float32 (__main__.TestLocalDTensorOpsCPU)": 685.0679931640625,
"test_dtensor_op_db_nn_functional_gaussian_nll_loss_cpu_float32 (__main__.TestMultiThreadedDTensorOpsCPU)": 86.26266733805339,
"test_eig_check_magma_cuda_float32 (__main__.TestLinalgCUDA)": 292.93699645996094,
"test_error_detection_and_propagation (__main__.NcclErrorHandlingTest)": 66.84199905395508,
"test_fail_arithmetic_ops.py (__main__.TestTyping)": 69.56212568283081,
"test_fail_creation_ops.py (__main__.TestTyping)": 69.80560022989908,
"test_fn_fwgrad_bwgrad_cumprod_cuda_complex128 (__main__.TestFwdGradientsCUDA)": 73.36666552225749,
"test_fn_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 90.40366744995117,
"test_fuse_large_params_cpu (__main__.CpuTests)": 132.73199844360352,
"test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 150.16662406921387,
"test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 159.28499794006348,
"test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 165.19283294677734,
"test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 151.12366739908853,
"test_grad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 84.61699930826823,
"test_gradgrad_nn_LSTM_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 110.00600179036458,
"test_gradgrad_nn_LSTM_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 122.3759994506836,
"test_gradgrad_nn_TransformerDecoderLayer_cuda_float64 (__main__.TestModuleCUDA)": 190.89249674479166,
"test_gradgrad_nn_TransformerEncoder_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 149.6598358154297,
"test_gradgrad_nn_TransformerEncoder_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 146.07766723632812,
"test_gradgrad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 532.8139902750651,
"test_graph_partition_refcount_cuda (__main__.GPUTests)": 69.78400001525878,
"test_graph_partition_refcount_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 267.04988850487604,
"test_graph_partition_refcount_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 273.54955800374347,
"test_grid_sampler_2d_cpu_halide (__main__.HalideCpuTests)": 195.84733072916666,
"test_indirect_device_assert (__main__.TritonCodeGenTests)": 326.0143330891927,
"test_inductor_no_recursionerror_on_for_loops_dynamic_shapes (__main__.DynamicShapesReproTests)": 66.96037435531616,
"test_inplace_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 77.44933319091797,
"test_inputs_overlapping_with_mutation_stress_dynamic_shapes (__main__.DynamicShapesAotAutogradFallbackTests)": 126.81488884819879,
"test_jit_cuda_archflags (__main__.TestCppExtensionJIT)": 118.70199839274089,
"test_linalg_solve_triangular_large_cuda_complex128 (__main__.TestLinalgCUDA)": 129.20266723632812,
"test_linalg_solve_triangular_large_cuda_complex64 (__main__.TestLinalgCUDA)": 97.18800099690755,
"test_linear_binary_cpp_wrapper (__main__.TestCppWrapper)": 130.3183339436849,
"test_linear_binary_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 140.43233235677084,
"test_list_clearing_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 293.122774971856,
"test_lobpcg_ortho_cuda_float64 (__main__.TestLinalgCUDA)": 63.835832277933754,
"test_longformer_chunk_dynamic_shapes (__main__.DynamicShapesReproTests)": 106.77049922943115,
"test_lstm_cpu (__main__.TestMkldnnCPU)": 100.89649963378906,
"test_many_overlapping_inputs_does_not_explode_guards_dynamic_shapes (__main__.DynamicShapesReproTests)": 140.07424926757812,
"test_max_autotune_addmm_max_autotune_gemm_backends_CK_x_shape2 (__main__.TestCKBackend)": 72.90299733479817,
"test_max_autotune_addmm_search_space_EXHAUSTIVE_dynamic_True (__main__.TestMaxAutotuneSubproc)": 82.62433369954427,
"test_max_autotune_precompile_matmul_max_autotune_gemm_backends_CKTILE_autotune_in_subproc_False_use_aoti_False (__main__.TestCKBackend)": 87.51499938964844,
"test_max_autotune_precompile_matmul_max_autotune_gemm_backends_CKTILE_autotune_in_subproc_True_use_aoti_True (__main__.TestCKBackend)": 71.22416591644287,
"test_max_pool2d2_cpu_halide (__main__.HalideCpuTests)": 424.50966389973956,
"test_max_pool2d3_cpu_halide (__main__.HalideCpuTests)": 134.14600626627603,
"test_max_pool2d5_cpu_halide (__main__.HalideCpuTests)": 358.88099161783856,
"test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 63.58866712782118,
"test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 62.68674945831299,
"test_memory_format_operators_cuda (__main__.TestTorchDeviceTypeCUDA)": 65.85794713936355,
"test_ordered_distribute_all_combination (__main__.DistributeWithDeviceOrderTest)": 103.6923344930013,
"test_ordered_redistribute_with_partial (__main__.DistributeWithDeviceOrderTest)": 187.6953328450521,
"test_ordered_redistribute_with_partial (__main__.DistributeWithDeviceOrderTestWithLocalTensor)": 370.27442932128906,
"test_proper_exit (__main__.TestDataLoader)": 227.83111148410373,
"test_proper_exit (__main__.TestDataLoaderPersistentWorkers)": 227.1901126437717,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 105.52099990844727,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 106.50249862670898,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True (__main__.TestPatternMatcher)": 92.52400207519531,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 111.75499725341797,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 107.40500259399414,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False (__main__.TestPatternMatcher)": 83.80450057983398,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 107.46599833170573,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 96.65650177001953,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True (__main__.TestPatternMatcher)": 83.4114990234375,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 107.47100067138672,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 108.55533345540364,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False (__main__.TestPatternMatcher)": 89.23666381835938,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 105.13900375366211,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 100.14550018310547,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 107.33649826049805,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 102.08150100708008,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 97.59600067138672,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 104.82933553059895,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 114.43099721272786,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 110.40333302815755,
"test_quick_core_backward__unsafe_masked_index_cpu_float64 (__main__.TestDecompCPU)": 567.2765197753906,
"test_quick_core_backward__unsafe_masked_index_cuda_float64 (__main__.TestDecompCUDA)": 1032.5083312988281,
"test_quick_core_backward__unsafe_masked_index_put_accumulate_cpu_float64 (__main__.TestDecompCPU)": 852.7170003255209,
"test_quick_core_backward__unsafe_masked_index_put_accumulate_cuda_float64 (__main__.TestDecompCUDA)": 1361.954854329427,
"test_quick_core_backward_nn_functional_max_unpool3d_grad_cpu_float64 (__main__.TestDecompCPU)": 77.385498046875,
"test_quick_core_backward_nn_functional_max_unpool3d_grad_cuda_float64 (__main__.TestDecompCUDA)": 265.0193354288737,
"test_quick_core_backward_roll_cpu_float64 (__main__.TestDecompCPU)": 115.31749725341797,
"test_quick_core_backward_roll_cuda_float64 (__main__.TestDecompCUDA)": 245.27666727701822,
"test_quick_core_backward_select_scatter_cpu_float64 (__main__.TestDecompCPU)": 71.75300216674805,
"test_quick_core_backward_select_scatter_cuda_float64 (__main__.TestDecompCUDA)": 141.8895009358724,
"test_quick_core_backward_split_cuda_float64 (__main__.TestDecompCUDA)": 71.15749994913737,
"test_quick_core_backward_split_with_sizes_copy_cpu_float64 (__main__.TestDecompCPU)": 90.59066772460938,
"test_quick_core_backward_split_with_sizes_copy_cuda_float64 (__main__.TestDecompCUDA)": 173.73916625976562,
"test_quick_core_backward_std_cuda_float64 (__main__.TestDecompCUDA)": 110.65066655476888,
"test_register_spills_cuda (__main__.BenchmarkFusionCudaTest)": 99.21799850463867,
"test_replicatepad_64bit_indexing_cuda_float16 (__main__.TestNNDeviceTypeCUDA)": 90.86299896240234,
"test_rosenbrock_sparse_with_lrsched_False_SGD_cuda_float64 (__main__.TestOptimRenewedCUDA)": 66.57050196329753,
"test_rosenbrock_sparse_with_lrsched_True_SGD_cuda_float64 (__main__.TestOptimRenewedCUDA)": 69.65149958928426,
"test_runtime_checks_large_cpu (__main__.AOTInductorTestABICompatibleCpu)": 78.13350168863933,
"test_runtime_checks_large_cpu_with_stack_allocation (__main__.AOTInductorTestABICompatibleCpuWithStackAllocation)": 76.85255601671007,
"test_runtime_checks_large_cuda (__main__.AOTInductorTestABICompatibleGpu)": 333.04866282145184,
"test_save_load_large_string_attribute (__main__.TestSaveLoad)": 146.96599833170572,
"test_sdpa_kernel_ctx_manager2_dynamic_shapes (__main__.DynamicShapesCtxManagerTests)": 160.4881100124783,
"test_shuffler_iterdatapipe (__main__.IntegrationTestDataLoaderDataPipe)": 124.10055626763238,
"test_slow_tasks (__main__.TestFunctionalAutogradBenchmark)": 117.38410907321506,
"test_sort_dynamic_shape_with_check_cuda (__main__.TestInductorDynamicCUDA)": 710.2327779134115,
"test_sort_stable_cpu (__main__.CpuTritonTests)": 1324.4399820963542,
"test_sort_stable_cuda (__main__.GPUTests)": 76.83109970092774,
"test_split_cumsum_cpu (__main__.CpuTritonTests)": 88.58433532714844,
"test_svd_lowrank_cuda_complex128 (__main__.TestLinalgCUDA)": 160.1271684964498,
"test_tensor_split (__main__.TestVmapOperators)": 79.18955569393519,
"test_terminate_handler_on_crash (__main__.TestTorch)": 111.30388899644215,
"test_terminate_signal (__main__.ForkTest)": 132.3458870516883,
"test_terminate_signal (__main__.ParallelForkServerShouldWorkTest)": 132.2043343567186,
"test_terminate_signal (__main__.SpawnTest)": 136.1005539894104,
"test_torchvision_smoke (__main__.TestTensorBoardPytorchGraph)": 76.20899939537048,
"test_train_parity_multi_group_unshard_async_op (__main__.TestFullyShard1DTrainingCore)": 63.82099969046457,
"test_triton_bsr_scatter_mm_blocksize_64_cuda_bfloat16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 61.925000508626304,
"test_triton_bsr_scatter_mm_blocksize_64_cuda_float16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 60.89849980672201,
"test_triton_bsr_scatter_mm_blocksize_64_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 66.88233375549316,
"test_triton_bsr_softmax_cuda_bfloat16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 144.9854990641276,
"test_triton_bsr_softmax_cuda_float16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 144.4044977823893,
"test_triton_bsr_softmax_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 108.19166437784831,
"test_unary_ops (__main__.TestTEFuserDynamic)": 96.32655514611139,
"test_unary_ops (__main__.TestTEFuserStatic)": 105.33362591266632,
"test_upsample_bicubic2d_cpu_halide (__main__.HalideCpuTests)": 97.8336664835612,
"test_variant_consistency_jit_nn_functional_max_pool2d_cpu_float32 (__main__.TestJitCPU)": 82.86566925048828,
"test_variant_consistency_jit_nn_functional_max_pool2d_cuda_float32 (__main__.TestJitCUDA)": 68.26500002543132,
"test_views1_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 97.1120007832845,
"test_vmapjvpvjp_linalg_lstsq_grad_oriented_cpu_float32 (__main__.TestOperatorsCPU)": 88.24766794840495,
"test_vmapjvpvjp_linalg_lstsq_grad_oriented_cuda_float32 (__main__.TestOperatorsCUDA)": 65.41266759236653,
"test_vmapjvpvjp_linalg_lu_solve_cuda_float32 (__main__.TestOperatorsCUDA)": 74.75533294677734,
"test_vmapjvpvjp_linalg_multi_dot_cuda_float32 (__main__.TestOperatorsCUDA)": 73.52500089009602,
"test_vmapjvpvjp_linalg_svd_cuda_float32 (__main__.TestOperatorsCUDA)": 73.85466639200847,
"test_vmapjvpvjp_max_pool2d_with_indices_backward_cuda_float32 (__main__.TestOperatorsCUDA)": 98.39650090535481,
"test_vmapjvpvjp_nn_functional_conv2d_cpu_float32 (__main__.TestOperatorsCPU)": 61.39695285615467,
"test_vmapjvpvjp_nn_functional_max_pool2d_cuda_float32 (__main__.TestOperatorsCUDA)": 77.88249842325847,
"test_vmapjvpvjp_svd_cuda_float32 (__main__.TestOperatorsCUDA)": 73.0695006052653,
"test_vmapjvpvjp_unbind_cuda_float32 (__main__.TestOperatorsCUDA)": 81.86250114440918,
"test_vmapvjpvjp_meshgrid_list_of_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 98.63116455078125,
"test_vmapvjpvjp_meshgrid_variadic_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 94.85683314005534,
"test_vmapvjpvjp_nn_functional_bilinear_cuda_float32 (__main__.TestOperatorsCUDA)": 173.00183614095053
"EndToEndLSTM (__main__.RNNTest)": 190.48799641927084,
"MultiheadAttention (__main__.ModulesTest)": 141.2663370768229,
"test__adaptive_avg_pool2d (__main__.CPUReproTests)": 82.87333234151204,
"test_after_aot_cpu_runtime_error (__main__.MinifierIsolateTests)": 70.6538565499442,
"test_aot_autograd_disable_functionalization_symbolic_exhaustive_nn_functional_max_pool1d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 123.34033711751302,
"test_aot_autograd_disable_functionalization_symbolic_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 171.25450134277344,
"test_aot_autograd_disable_functionalization_symbolic_exhaustive_nn_functional_max_pool3d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 119.71899922688802,
"test_aot_autograd_disable_functionalization_symbolic_exhaustive_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 69.35733322870163,
"test_aot_autograd_symbolic_exhaustive_linalg_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 63.64533233642578,
"test_aot_autograd_symbolic_exhaustive_masked_norm_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 63.672952016194664,
"test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool1d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 138.04000091552734,
"test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 172.1344985961914,
"test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool3d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 114.02050018310547,
"test_aot_autograd_symbolic_exhaustive_ormqr_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 67.25642830984933,
"test_aot_autograd_symbolic_exhaustive_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 65.3350003560384,
"test_aot_autograd_symbolic_module_exhaustive_nn_TransformerDecoderLayer_cpu_float32 (__main__.TestEagerFusionModuleInfoCPU)": 120.95249938964844,
"test_associative_scan_partial_grad_combine_mode_generic_compile_mode_compile_dynamic_shape_reverse_False_cpu (__main__.AssociativeScanTests)": 86.97774887084961,
"test_associative_scan_partial_grad_combine_mode_generic_compile_mode_compile_dynamic_shape_reverse_True_cpu (__main__.AssociativeScanTests)": 100.90774917602539,
"test_avg_pool3d_backward2_cpu (__main__.CpuTests)": 1144.3935089111328,
"test_avg_pool3d_backward2_cuda (__main__.GPUTests)": 222.58500061035156,
"test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 501.10033162434894,
"test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 517.1875050862631,
"test_avg_pool3d_backward2_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 113.88125228881836,
"test_avg_pool3d_backward2_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 235.77350616455078,
"test_backward_nn_functional_multi_head_attention_forward_cpu_float32 (__main__.TestCompositeComplianceCPU)": 74.6155014038086,
"test_backward_nn_functional_multi_head_attention_forward_cuda_float32 (__main__.TestCompositeComplianceCUDA)": 66.63325119018555,
"test_basic_cpu (__main__.EfficientConvBNEvalCpuTests)": 216.2968317667643,
"test_basic_cuda (__main__.EfficientConvBNEvalGpuTests)": 153.0915012359619,
"test_cat_2k_args (__main__.TestTEFuserDynamic)": 108.80471753561869,
"test_cat_2k_args (__main__.TestTEFuserStatic)": 102.20949847949669,
"test_checkpointing_without_reentrant_input_requires_grad_False (__main__.TestAutogradWithCompiledAutograd)": 311.7026621500651,
"test_checkpointing_without_reentrant_input_requires_grad_True (__main__.TestAutogradWithCompiledAutograd)": 395.0001729329427,
"test_collect_callgrind (__main__.TestBenchmarkUtils)": 348.6218566894531,
"test_comprehensive_diff_cuda_complex128 (__main__.TestDecompCUDA)": 98.71574974060059,
"test_comprehensive_diff_cuda_complex64 (__main__.TestDecompCUDA)": 97.68499946594238,
"test_comprehensive_diff_cuda_float32 (__main__.TestDecompCUDA)": 65.0557508468628,
"test_comprehensive_diff_cuda_float64 (__main__.TestDecompCUDA)": 65.86899948120117,
"test_comprehensive_gradient_cuda_complex64 (__main__.TestDecompCUDA)": 97.15880012512207,
"test_comprehensive_grid_sampler_2d_cpu_bfloat16 (__main__.TestDecompCPU)": 103.20700073242188,
"test_comprehensive_grid_sampler_2d_cpu_float16 (__main__.TestDecompCPU)": 102.74033610026042,
"test_comprehensive_grid_sampler_2d_cpu_float32 (__main__.TestDecompCPU)": 460.4286702473958,
"test_comprehensive_grid_sampler_2d_cpu_float64 (__main__.TestDecompCPU)": 435.62066650390625,
"test_comprehensive_grid_sampler_2d_cuda_bfloat16 (__main__.TestDecompCUDA)": 287.3090057373047,
"test_comprehensive_grid_sampler_2d_cuda_float16 (__main__.TestDecompCUDA)": 265.1860008239746,
"test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestDecompCUDA)": 1235.7365112304688,
"test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 68.20825004577637,
"test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestDecompCUDA)": 1281.2615051269531,
"test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 71.90750026702881,
"test_comprehensive_linalg_householder_product_cuda_complex64 (__main__.TestDecompCUDA)": 79.04633331298828,
"test_comprehensive_linalg_lu_factor_ex_cuda_complex128 (__main__.TestDecompCUDA)": 68.10879821777344,
"test_comprehensive_linalg_lu_solve_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 71.43025207519531,
"test_comprehensive_linalg_lu_solve_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 68.94575023651123,
"test_comprehensive_linalg_solve_triangular_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 72.93649864196777,
"test_comprehensive_linalg_solve_triangular_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 72.46275043487549,
"test_comprehensive_linalg_svd_cuda_complex128 (__main__.TestDecompCUDA)": 64.10650062561035,
"test_comprehensive_linalg_svd_cuda_complex64 (__main__.TestDecompCUDA)": 67.03124904632568,
"test_comprehensive_linalg_svd_cuda_float64 (__main__.TestDecompCUDA)": 64.32800025939942,
"test_comprehensive_linalg_vector_norm_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 96.41353665865384,
"test_comprehensive_linalg_vector_norm_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 100.17661388103778,
"test_comprehensive_masked_norm_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 110.95025062561035,
"test_comprehensive_masked_norm_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 108.06550025939941,
"test_comprehensive_masked_norm_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 104.24150085449219,
"test_comprehensive_nn_functional_conv_transpose3d_cuda_complex128 (__main__.TestDecompCUDA)": 63.453749656677246,
"test_comprehensive_nn_functional_conv_transpose3d_cuda_complex64 (__main__.TestDecompCUDA)": 61.739999771118164,
"test_comprehensive_nn_functional_gaussian_nll_loss_cpu_float32 (__main__.TestDecompCPU)": 69.96549987792969,
"test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestDecompCUDA)": 113.65749931335449,
"test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float64 (__main__.TestDecompCUDA)": 106.57500076293945,
"test_comprehensive_nn_functional_grid_sample_cpu_float32 (__main__.TestDecompCPU)": 117.54049682617188,
"test_comprehensive_nn_functional_grid_sample_cpu_float64 (__main__.TestDecompCPU)": 116.19766489664714,
"test_comprehensive_nn_functional_grid_sample_cuda_float32 (__main__.TestDecompCUDA)": 272.48475646972656,
"test_comprehensive_nn_functional_grid_sample_cuda_float64 (__main__.TestDecompCUDA)": 248.12175369262695,
"test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 (__main__.TestDecompCUDA)": 79.66900062561035,
"test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 81.52649879455566,
"test_comprehensive_nn_functional_interpolate_bicubic_cuda_float64 (__main__.TestDecompCUDA)": 79.29400062561035,
"test_comprehensive_nn_functional_interpolate_bicubic_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 82.40349960327148,
"test_comprehensive_nn_functional_interpolate_trilinear_cuda_float32 (__main__.TestDecompCUDA)": 128.42924880981445,
"test_comprehensive_nn_functional_interpolate_trilinear_cuda_float64 (__main__.TestDecompCUDA)": 125.03675079345703,
"test_comprehensive_nn_functional_max_pool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 1264.9732360839844,
"test_comprehensive_nn_functional_max_pool2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 1250.7332458496094,
"test_comprehensive_nn_functional_max_pool2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 1255.0684814453125,
"test_comprehensive_nn_functional_max_pool3d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 574.4627532958984,
"test_comprehensive_nn_functional_max_pool3d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 581.7282485961914,
"test_comprehensive_nn_functional_max_unpool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 65.052001953125,
"test_comprehensive_nn_functional_max_unpool2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 61.19200134277344,
"test_comprehensive_nn_functional_max_unpool2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 63.16874885559082,
"test_comprehensive_ormqr_cpu_complex64 (__main__.TestDecompCPU)": 62.39250183105469,
"test_comprehensive_ormqr_cuda_complex128 (__main__.TestDecompCUDA)": 113.32574844360352,
"test_comprehensive_ormqr_cuda_complex64 (__main__.TestDecompCUDA)": 113.91499900817871,
"test_comprehensive_ormqr_cuda_float32 (__main__.TestDecompCUDA)": 74.42549800872803,
"test_comprehensive_ormqr_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 76.1560001373291,
"test_comprehensive_ormqr_cuda_float64 (__main__.TestDecompCUDA)": 66.76750087738037,
"test_comprehensive_svd_cuda_complex128 (__main__.TestDecompCUDA)": 70.69724941253662,
"test_comprehensive_svd_cuda_complex64 (__main__.TestDecompCUDA)": 69.87625026702881,
"test_constructor_autograd_SparseBSC_cuda (__main__.TestSparseAnyCUDA)": 80.2542495727539,
"test_constructor_autograd_SparseBSR_cuda (__main__.TestSparseAnyCUDA)": 69.0419979095459,
"test_conv1d_basic (__main__.TestXNNPACKConv1dTransformPass)": 117.03342655726841,
"test_conv1d_with_relu_fc (__main__.TestXNNPACKConv1dTransformPass)": 289.50213841029574,
"test_conv2d_binary_broadcast_shapes_cpu (__main__.TestPatternMatcherGenericCPU)": 67.38800048828125,
"test_conv3d_binary_broadcast_shapes_cpu (__main__.TestPatternMatcherGenericCPU)": 145.27399444580078,
"test_conv3d_binary_dynamic_shapes_cpu (__main__.TestDynamicPatternMatcherGenericCPU)": 66.9245999654134,
"test_conv3d_cuda (__main__.AOTInductorTestABICompatibleGpu)": 151.91099548339844,
"test_conv_bn_fuse_cpu (__main__.CpuTests)": 92.79549789428711,
"test_conv_bn_fuse_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 64.60149955749512,
"test_conv_transpose_with_output_size_and_no_batch_dim_ConvTranspose3d_cuda (__main__.TestConvolutionNNDeviceTypeCUDA)": 69.27724676392972,
"test_conv_unary_fusion_nnc (__main__.TestMkldnnFusion)": 76.24971498761859,
"test_correctness_AdamW_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 81.93449974060059,
"test_correctness_Adam_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 78.87700080871582,
"test_count_nonzero_all (__main__.TestBool)": 631.2585144042969,
"test_diff_hyperparams_sharding_strategy_str_full_shard (__main__.TestFSDPUseOrigParamsMultipleParamGroups)": 61.042999267578125,
"test_dispatch_symbolic_meta_outplace_all_strides_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestMetaCUDA)": 84.49850082397461,
"test_dtensor_op_db_nn_functional_poisson_nll_loss_cpu_float32 (__main__.TestLocalDTensorOpsCPU)": 93.03299713134766,
"test_eager_sequence_nr_dynamic_shapes (__main__.DynamicShapesAotAutogradFallbackTests)": 228.46711820714614,
"test_eig_check_magma_cuda_float32 (__main__.TestLinalgCUDA)": 286.29998779296875,
"test_fail_arithmetic_ops.py (__main__.TestTyping)": 68.43842806134906,
"test_fail_random.py (__main__.TestTyping)": 74.83523060725285,
"test_fn_fwgrad_bwgrad_cumprod_cuda_complex128 (__main__.TestFwdGradientsCUDA)": 72.84900093078613,
"test_fn_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 75.86675071716309,
"test_fuse_large_params_cpu (__main__.CpuTests)": 151.4199981689453,
"test_fuse_large_params_cuda (__main__.GPUTests)": 60.351999282836914,
"test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 158.3622828892299,
"test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 149.6796646118164,
"test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 139.97800064086914,
"test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 114.8385009765625,
"test_grad_nn_Transformer_cpu_float64 (__main__.TestModuleCPU)": 84.69736822027909,
"test_grad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 84.62700080871582,
"test_gradgrad_nn_LSTM_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 89.197998046875,
"test_gradgrad_nn_LSTM_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 96.46900177001953,
"test_gradgrad_nn_TransformerDecoderLayer_cuda_float64 (__main__.TestModuleCUDA)": 187.83824920654297,
"test_gradgrad_nn_TransformerEncoder_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 110.49449920654297,
"test_gradgrad_nn_TransformerEncoder_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 124.90424919128418,
"test_gradgrad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 518.4157485961914,
"test_indirect_device_assert (__main__.TritonCodeGenTests)": 304.6440022786458,
"test_inductor_dynamic_shapes_broadcasting_dynamic_shapes (__main__.DynamicShapesReproTests)": 143.82052836698645,
"test_inductor_no_recursionerror_on_for_loops_dynamic_shapes (__main__.DynamicShapesReproTests)": 77.4985705784389,
"test_inplace_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 76.06225109100342,
"test_inputs_overlapping_with_mutation_stress_dynamic_shapes (__main__.DynamicShapesAotAutogradFallbackTests)": 138.9222858973912,
"test_jit_cuda_archflags (__main__.TestCppExtensionJIT)": 120.62233225504558,
"test_linalg_solve_triangular_large_cuda_complex128 (__main__.TestLinalgCUDA)": 148.1219940185547,
"test_linalg_solve_triangular_large_cuda_complex64 (__main__.TestLinalgCUDA)": 109.34200286865234,
"test_linear_binary_cpp_wrapper (__main__.TestCppWrapper)": 119.36233266194661,
"test_linear_binary_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 127.95700073242188,
"test_list_clearing_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 61.64850175380707,
"test_longformer_chunk_dynamic_shapes (__main__.DynamicShapesReproTests)": 105.3174296787807,
"test_low_memory_max_pool_dilation_1_dim_3_cpu_halide (__main__.HalideCpuTests)": 585.9210001627604,
"test_low_memory_max_pool_dilation_2_dim_3_cpu_halide (__main__.HalideCpuTests)": 504.3250020345052,
"test_lstm_cpu (__main__.TestMkldnnCPU)": 86.21566645304362,
"test_many_overlapping_inputs_does_not_explode_guards_dynamic_shapes (__main__.DynamicShapesReproTests)": 129.277715410505,
"test_max_autotune_addmm_max_autotune_gemm_backends_CK_x_shape2 (__main__.TestCKBackend)": 64.24800109863281,
"test_max_autotune_precompile_matmul_max_autotune_gemm_backends_CKTILE_autotune_in_subproc_False_use_aoti_False (__main__.TestCKBackend)": 77.23899841308594,
"test_max_autotune_precompile_matmul_max_autotune_gemm_backends_CKTILE_autotune_in_subproc_False_use_aoti_True (__main__.TestCKBackend)": 65.15649795532227,
"test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 62.579833984375,
"test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 64.6555004119873,
"test_pattern_matcher_multi_user_cpu (__main__.CpuTritonTests)": 142.21566772460938,
"test_proper_exit (__main__.TestDataLoader)": 267.74214717320035,
"test_proper_exit (__main__.TestDataLoaderPersistentWorkers)": 266.6539971487863,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 101.97100067138672,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 97.3346659342448,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True (__main__.TestPatternMatcher)": 81.50300216674805,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 104.61333465576172,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 99.41133371988933,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False (__main__.TestPatternMatcher)": 73.37100219726562,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 95.30900065104167,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 96.61750030517578,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True (__main__.TestPatternMatcher)": 79.33600234985352,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 101.2393315633138,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 103.18400192260742,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False (__main__.TestPatternMatcher)": 75.4114990234375,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 96.52833302815755,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 99.72700119018555,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 100.61966705322266,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 102.2750015258789,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 95.17449951171875,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 97.96749877929688,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 106.44049835205078,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 101.7173334757487,
"test_quick_core_backward__unsafe_masked_index_cpu_float64 (__main__.TestDecompCPU)": 531.5236612955729,
"test_quick_core_backward__unsafe_masked_index_cuda_float64 (__main__.TestDecompCUDA)": 1077.4210205078125,
"test_quick_core_backward__unsafe_masked_index_put_accumulate_cpu_float64 (__main__.TestDecompCPU)": 812.0880126953125,
"test_quick_core_backward__unsafe_masked_index_put_accumulate_cuda_float64 (__main__.TestDecompCUDA)": 1347.9365234375,
"test_quick_core_backward_nn_functional_max_unpool3d_grad_cpu_float64 (__main__.TestDecompCPU)": 88.93533070882161,
"test_quick_core_backward_nn_functional_max_unpool3d_grad_cuda_float64 (__main__.TestDecompCUDA)": 269.01949310302734,
"test_quick_core_backward_roll_cpu_float64 (__main__.TestDecompCPU)": 131.99799601236978,
"test_quick_core_backward_roll_cuda_float64 (__main__.TestDecompCUDA)": 232.36275100708008,
"test_quick_core_backward_select_scatter_cpu_float64 (__main__.TestDecompCPU)": 69.80400085449219,
"test_quick_core_backward_select_scatter_cuda_float64 (__main__.TestDecompCUDA)": 134.3415012359619,
"test_quick_core_backward_split_cuda_float64 (__main__.TestDecompCUDA)": 67.51749992370605,
"test_quick_core_backward_split_with_sizes_copy_cpu_float64 (__main__.TestDecompCPU)": 91.21066792805989,
"test_quick_core_backward_split_with_sizes_copy_cuda_float64 (__main__.TestDecompCUDA)": 170.97775268554688,
"test_quick_core_backward_std_cpu_float64 (__main__.TestDecompCPU)": 61.608266321818036,
"test_quick_core_backward_std_cuda_float64 (__main__.TestDecompCUDA)": 110.62575149536133,
"test_register_spills_cuda (__main__.BenchmarkFusionGpuTest)": 63.59499969482422,
"test_replicatepad_64bit_indexing_cuda_float16 (__main__.TestNNDeviceTypeCUDA)": 88.68299865722656,
"test_rnn_decomp_module_nn_LSTM_train_mode_cuda_float32 (__main__.TestDecompCUDA)": 91.50320053100586,
"test_runtime_checks_large_cpu (__main__.AOTInductorTestABICompatibleCpu)": 66.10774898529053,
"test_runtime_checks_large_cpu_with_stack_allocation (__main__.AOTInductorTestABICompatibleCpuWithStackAllocation)": 66.20533180236816,
"test_runtime_checks_large_cuda (__main__.AOTInductorTestABICompatibleGpu)": 243.1092529296875,
"test_save_load_large_string_attribute (__main__.TestSaveLoad)": 105.01200103759766,
"test_sdpa_kernel_ctx_manager2_dynamic_shapes (__main__.DynamicShapesCtxManagerTests)": 107.93685695103237,
"test_shuffler_iterdatapipe (__main__.IntegrationTestDataLoaderDataPipe)": 142.38899993896484,
"test_slow_tasks (__main__.TestFunctionalAutogradBenchmark)": 119.90166600545247,
"test_sort_bool_cpu (__main__.CpuTritonTests)": 346.2856750488281,
"test_sort_dynamic_shape_with_check_cuda (__main__.TestInductorDynamicCUDA)": 423.09974098205566,
"test_sort_stable_cuda (__main__.GPUTests)": 117.61659927368164,
"test_sort_transpose_cpu (__main__.CpuTritonTests)": 378.31200154622394,
"test_svd_lowrank_cuda_complex128 (__main__.TestLinalgCUDA)": 222.822007894516,
"test_terminate_handler_on_crash (__main__.TestTorch)": 143.31728431156702,
"test_terminate_signal (__main__.ForkTest)": 168.20485967184817,
"test_terminate_signal (__main__.ParallelForkServerShouldWorkTest)": 168.19242484867573,
"test_terminate_signal (__main__.SpawnTest)": 172.16428443363733,
"test_thnn_conv_strided_padded_dilated (__main__.TestConvolutionNN)": 93.30639710426331,
"test_train_parity_multi_group (__main__.TestFullyShard1DTrainingCore)": 163.89743041992188,
"test_train_parity_with_activation_checkpointing (__main__.TestFullyShard1DTrainingCompose)": 60.47671399797712,
"test_triton_bsr_scatter_mm_blocksize_64_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 63.39550018310547,
"test_triton_bsr_softmax_cuda_bfloat16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 173.53924942016602,
"test_triton_bsr_softmax_cuda_float16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 175.3212537765503,
"test_triton_bsr_softmax_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 122.20649909973145,
"test_variant_consistency_jit_nn_functional_max_pool2d_cpu_float32 (__main__.TestJitCPU)": 99.9885025024414,
"test_variant_consistency_jit_nn_functional_max_pool2d_cuda_float32 (__main__.TestJitCUDA)": 71.64024829864502,
"test_view_ops (__main__.TestViewOpsWithLocalTensor)": 73.45887422561646,
"test_vmapjvpvjp_linalg_lstsq_grad_oriented_cpu_float32 (__main__.TestOperatorsCPU)": 95.75249862670898,
"test_vmapjvpvjp_linalg_lstsq_grad_oriented_cuda_float32 (__main__.TestOperatorsCUDA)": 61.858001708984375,
"test_vmapjvpvjp_linalg_lu_solve_cpu_float32 (__main__.TestOperatorsCPU)": 65.11023766653878,
"test_vmapjvpvjp_linalg_lu_solve_cuda_float32 (__main__.TestOperatorsCUDA)": 66.35274982452393,
"test_vmapjvpvjp_linalg_svd_cuda_float32 (__main__.TestOperatorsCUDA)": 61.196499824523926,
"test_vmapjvpvjp_max_pool2d_with_indices_backward_cpu_float32 (__main__.TestOperatorsCPU)": 73.75380906604585,
"test_vmapjvpvjp_max_pool2d_with_indices_backward_cuda_float32 (__main__.TestOperatorsCUDA)": 73.64649868011475,
"test_vmapjvpvjp_nn_functional_max_pool2d_cpu_float32 (__main__.TestOperatorsCPU)": 75.09799966358003,
"test_vmapjvpvjp_nn_functional_max_pool2d_cuda_float32 (__main__.TestOperatorsCUDA)": 70.51450157165527,
"test_vmapjvpvjp_unbind_cpu_float32 (__main__.TestOperatorsCPU)": 66.21433276221866,
"test_vmapjvpvjp_unbind_cuda_float32 (__main__.TestOperatorsCUDA)": 73.20024871826172,
"test_vmapvjpvjp_linalg_lstsq_cuda_float32 (__main__.TestOperatorsCUDA)": 88.1349983215332,
"test_vmapvjpvjp_meshgrid_list_of_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 76.89924907684326,
"test_vmapvjpvjp_meshgrid_variadic_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 77.32975196838379,
"test_vmapvjpvjp_nn_functional_bilinear_cuda_float32 (__main__.TestOperatorsCUDA)": 120.09600067138672
}

View File

@ -239,6 +239,12 @@ class TestAccelerator(TestCase):
self.assertEqual(torch.accelerator.max_memory_allocated(), prev_max_allocated)
self.assertEqual(torch.accelerator.max_memory_reserved(), prev_max_reserved)
@unittest.skipIf(TEST_MPS, "MPS doesn't support torch.accelerator memory API!")
def test_get_memory_info(self):
free_bytes, total_bytes = torch.accelerator.get_memory_info()
self.assertGreaterEqual(free_bytes, 0)
self.assertGreaterEqual(total_bytes, 0)
if __name__ == "__main__":
run_tests()

View File

@ -313,15 +313,17 @@ class SerializationMixin:
def test_serialization_gzip(self):
# Test serialization with gzip file
b = self._test_serialization_data()
f1 = tempfile.NamedTemporaryFile(delete=False)
f2 = tempfile.NamedTemporaryFile(delete=False)
torch.save(b, f1)
with open(f1.name, 'rb') as f_in, gzip.open(f2.name, 'wb') as f_out:
shutil.copyfileobj(f_in, f_out)
with tempfile.NamedTemporaryFile() as f1, tempfile.NamedTemporaryFile(delete=False) as f2:
torch.save(b, f1)
f1.seek(0)
with gzip.open(f2.name, 'wb') as f_out:
shutil.copyfileobj(f1, f_out)
with gzip.open(f2.name, 'rb') as f:
c = torch.load(f)
self._test_serialization_assert(b, c)
with gzip.open(f2.name, 'rb') as f:
c = torch.load(f)
self._test_serialization_assert(b, c)
f2.close()
os.unlink(f2.name)
@unittest.skipIf(
not TEST_DILL or HAS_DILL_AT_LEAST_0_3_1,
@ -382,19 +384,19 @@ class SerializationMixin:
def test_serialization_offset_gzip(self):
a = torch.randn(5, 5)
i = 41
f1 = tempfile.NamedTemporaryFile(delete=False)
f2 = tempfile.NamedTemporaryFile(delete=False)
with open(f1.name, 'wb') as f:
pickle.dump(i, f)
torch.save(a, f)
with open(f1.name, 'rb') as f_in, gzip.open(f2.name, 'wb') as f_out:
shutil.copyfileobj(f_in, f_out)
with tempfile.NamedTemporaryFile() as f1:
pickle.dump(i, f1)
torch.save(a, f1)
f1.seek(0)
with gzip.open(f2.name, 'wb') as f_out:
shutil.copyfileobj(f1, f_out)
with gzip.open(f2.name, 'rb') as f:
j = pickle.load(f)
b = torch.load(f)
self.assertTrue(torch.equal(a, b))
self.assertEqual(i, j)
with gzip.open(f2.name, 'rb') as f:
j = pickle.load(f)
b = torch.load(f)
self.assertTrue(torch.equal(a, b))
self.assertEqual(i, j)
def _test_serialization_sparse(self, weights_only):
def _test_serialization(conversion):

View File

@ -3728,7 +3728,6 @@ class TestSparse(TestSparseBase):
@coalescedonoff
@dtypes(*floating_and_complex_types())
@dtypesIfMPS(*all_mps_types())
@expectedFailureMPS
@dtypesIfCUDA(*floating_types_and(*[torch.half] if SM53OrLater and not TEST_WITH_ROCM else [],
*[torch.bfloat16] if SM80OrLater and not TEST_WITH_ROCM else [],
torch.complex64,
@ -3825,9 +3824,9 @@ class TestSparse(TestSparseBase):
def different_dtypes():
a, i_a, v_a = self._gen_sparse(2, 10, [2, 2], dtype, device, coalesced)
b, i_b, v_b = self._gen_sparse(2, 10, [2, 2], dtype, device, coalesced)
r2 = torch.sparse.mm(a.to(torch.float64), a.to(torch.float32))
r2 = torch.sparse.mm(a.to(torch.float32), a.to(torch.float16))
self.assertRaisesRegex(RuntimeError, 'mat1 dtype Double does not match mat2 dtype Float', different_dtypes)
self.assertRaisesRegex(RuntimeError, 'mat1 dtype Float does not match mat2 dtype Half', different_dtypes)
def test_backward_noncontiguous():
# Sparse.mm backward used to be wrong with non-contiguous grads,

View File

@ -206,7 +206,8 @@ if __name__ == "__main__":
test_multi_process(model, input)
print(torch.xpu.device_count())
"""
rc = check_output(test_script)
# XPU output has extra lines, so take the last line; see https://github.com/intel/torch-xpu-ops/issues/2261
rc = check_output(test_script).splitlines()[-1]
self.assertEqual(rc, str(torch.xpu.device_count()))
def test_streams(self):

View File

@ -3,9 +3,9 @@
# mypy: allow-untyped-defs
# ruff: noqa: F401,PYI054
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from types import EllipsisType
from typing import Any, Callable, Literal, overload, TypeVar
from typing import Any, Literal, overload, TypeVar
import torch
from torch import (

View File

@ -2491,6 +2491,7 @@ def _accelerator_emptyCache() -> None: ...
def _accelerator_getDeviceStats(device_index: _int) -> dict[str, Any]: ...
def _accelerator_resetAccumulatedStats(device_index: _int) -> None: ...
def _accelerator_resetPeakStats(device_index: _int) -> None: ...
def _accelerator_getMemoryInfo(device_index: _int) -> tuple[_int, _int]: ...
def _accelerator_setAllocatorSettings(env: str) -> None: ...
# Defined in torch/csrc/jit/python/python_tracer.cpp

View File

@ -22,8 +22,8 @@ import inspect
import sys
import warnings
from collections.abc import Callable, Sequence, Sized
from contextlib import ExitStack
from typing import Any, ContextManager, Optional, TYPE_CHECKING, Union
from contextlib import AbstractContextManager, ExitStack
from typing import Any, Optional, TYPE_CHECKING, Union
import torch._C
from torch._guards import Guard
@ -163,7 +163,7 @@ class ContextWrappingVariable(VariableTracker):
class GenericContextWrappingVariable(UserDefinedObjectVariable):
# Some methods in ContextWrappingVariable assumes the arguments are
# python constants. Which might not always be the case here.
def __init__(self, cm_obj: ContextManager[Any], **kwargs: Any) -> None:
def __init__(self, cm_obj: AbstractContextManager[Any], **kwargs: Any) -> None:
assert cm_obj is not None
super().__init__(
value=cm_obj,

View File

@ -18,7 +18,8 @@ import collections
import inspect
import operator
import sys
from typing import Any, Optional, Sequence, TYPE_CHECKING
from collections.abc import Sequence
from typing import Any, Optional, TYPE_CHECKING
import torch
import torch.fx

View File

@ -2,7 +2,7 @@
# fmt: off
# This file was generated by AutoHeuristic. Do not modify it manually!
# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/mixed_mm/
from typing import List, Optional, Tuple
from typing import Optional
from torch._inductor.autoheuristic.autoheuristic_utils import (
AHContext,

View File

@ -550,6 +550,10 @@ max_autotune_flex_search_space: Literal["DEFAULT", "EXHAUSTIVE"] = os.environ.ge
"TORCHINDUCTOR_MAX_AUTOTUNE_FLEX_SEARCH_SPACE", "DEFAULT"
).upper() # type: ignore[assignment]
cutedsl_enable_autotuning: bool = (
os.environ.get("CUTEDSL_ENABLE_AUTOTUNING", "0") == "1"
)
# DEPRECATED. This setting is ignored.
autotune_fallback_to_aten = False

View File

@ -1,6 +1,8 @@
# mypy: allow-untyped-defs
import logging
from collections.abc import Sequence
from functools import partial
from pathlib import Path
from typing import Any
import torch
@ -12,6 +14,7 @@ from torch.fx.experimental.symbolic_shapes import has_free_unbacked_symbols
from .. import config
from ..codegen.wrapper import PythonWrapperCodegen
from ..ir import _IntLike, Layout, TensorBox
from ..utils import load_template
log = logging.getLogger(__name__)
@ -254,3 +257,7 @@ def is_batch_stride_largest_or_zero(mat1, mat2, layout) -> bool:
return False
return True
_KERNEL_TEMPLATE_DIR = Path(__file__).parent / "templates"
load_kernel_template = partial(load_template, template_dir=_KERNEL_TEMPLATE_DIR)

View File

@ -1,11 +1,13 @@
# mypy: allow-untyped-defs
import logging
from dataclasses import dataclass
from dataclasses import asdict, dataclass
from typing import Any, Optional
import torch
from torch._dynamo.utils import counters
from torch._inductor.codegen.cutedsl.cutedsl_template import CuteDSLTemplate
from torch._inductor.runtime.triton_compat import tl
from torch._inductor.template_heuristics.cutedsl import get_groupgemm_configs
from torch._inductor.virtualized import V
from torch.utils._triton import has_triton
@ -22,11 +24,13 @@ from ..utils import (
get_num_sms,
has_free_symbols,
use_aten_gemm_kernels,
use_blackwell_cutedsl_grouped_mm,
use_triton_template,
)
from .mm_common import (
_is_static_problem,
check_supported_striding,
load_kernel_template,
persistent_grouped_mm_grid,
)
@ -513,6 +517,11 @@ triton_scaled_grouped_mm_template = TritonTemplate(
source=triton_grouped_mm_source,
)
cutedsl_grouped_mm_template = CuteDSLTemplate(
name="grouped_gemm_cutedsl",
source=load_kernel_template("cutedsl_mm_grouped"),
)
def grouped_mm_args(
mat1: TensorBox,
@ -714,43 +723,44 @@ def _tuned_grouped_mm_common(
# Checking only for the equality of corresponding dims of
# multiplicands here, relying on meta function checks for
# everything else.
if len(m1_size) == 2:
if len(m2_size) == 2:
m, k1 = m1_size
k2, _ = m2_size
# pyrefly: ignore [missing-attribute]
g = offs.get_size()[0]
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = True, True
else:
# pyrefly: ignore [missing-attribute]
g1 = offs.layout.size[0]
m, k1 = m1_size
g2, k2, _ = m2_size
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = True, False
else:
if len(m2_size) == 2:
# pyrefly: ignore [missing-attribute]
g1 = offs.layout.size[0]
g2, m, k1 = m1_size
k2, _ = m2_size
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = False, True
else:
g1, m, k1 = m1_size
g2, k2, _ = m2_size
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = False, False
if (
is_nonzero
and use_triton_template(layout)
and can_use_triton_kernel(mat_a, mat_b, offs, bias, scale_result)
):
scaled = scale_a is not None
if len(m1_size) == 2:
if len(m2_size) == 2:
m, k1 = m1_size
k2, _ = m2_size
# pyrefly: ignore [missing-attribute]
g = offs.get_size()[0]
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = True, True
else:
# pyrefly: ignore [missing-attribute]
g1 = offs.layout.size[0]
m, k1 = m1_size
g2, k2, _ = m2_size
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = True, False
else:
if len(m2_size) == 2:
# pyrefly: ignore [missing-attribute]
g1 = offs.layout.size[0]
g2, m, k1 = m1_size
k2, _ = m2_size
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = False, True
else:
g1, m, k1 = m1_size
g2, k2, _ = m2_size
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = False, False
a_is_k_major = mat_a.get_stride()[-1] == 1
b_is_k_major = mat_b.get_stride()[-2] == 1
@ -788,6 +798,22 @@ def _tuned_grouped_mm_common(
**config.kwargs,
)
if use_blackwell_cutedsl_grouped_mm(
mat_a, mat_b, layout, a_is_2d, b_is_2d, offs, bias, scale_result
):
for config in get_groupgemm_configs():
kwargs = dict(
ACC_DTYPE="cutlass.Float32",
)
cutedsl_grouped_mm_template.maybe_append_choice(
choices,
input_nodes=input_nodes,
layout=layout,
**kwargs,
**asdict(config),
)
input_gen_fns = {
4: lambda x: create_offsets(
x, m1_size, m2_size, offs.get_size() if offs is not None else None

View File

@ -0,0 +1,333 @@
import functools
from torch._inductor.runtime.runtime_utils import ceildiv
from cutlass.utils import TensorMapUpdateMode
{{gen_defines()}}
# ---- Import GroupedGemm implementation, copied on PyTorch build from Cutlass repository: cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py ----
from torch._inductor.kernel.vendored_templates.cutedsl_grouped_gemm import (
GroupedGemmKernel,
)
# Note about caching:
# Each instantiated CuTeDSL grouped GEMM kernel file generated by Inductor
# maintains its own local caching system. At this stage, all compile-time
# constexprs (e.g., TILE_M, TILE_N, CLUSTER_M/N, USE_2_CTA) and the kernel
# name itself ({{kernel_name}}) are permanently baked into the file, so they
# do not need to be included in any cache key.
#
# The caching mechanism is split into two levels:
#
# 1. prep_cache
# Caches the compiled executor for build_group_ptrs_from_bases(). This
# kernel depends only on the tensor shapes, strides, and dtypes of A/B/C,
# and can therefore be safely reused across runs with different group
# partitioning (`offs`).
#
# 2. gemm_cache
# Caches the compiled Grouped GEMM executor. Its key extends the prep
# cache key with hardware- and grid-specific parameters:
# (prep_cache_key, max_active_clusters, total_num_clusters).
# This is necessary because different `offs` tensors can change the
# per-group problem sizes and thus alter `total_num_clusters`, which in
# turn changes the grid shape and persistent scheduler configuration.
# Kernels compiled for one grid cannot be safely reused for another.
#
#
# Additionally, note the @lru_cache decorator on get_hardware_info(). Empirically,
# hw.get_max_active_clusters() triggers significant MLIR recompilation overhead,
# despite depending only on the GPU type. We cache this function to mitigate
# redundant recompiles even when shape/stride/dtype cache misses force kernel
# regeneration. A follow-up study will investigate the root cause.
prep_cache = {}
gemm_cache = {}
@functools.lru_cache
def get_hardware_info():
hw = cutlass.utils.HardwareInfo()
sm_count = hw.get_max_active_clusters(1)
max_active_clusters = hw.get_max_active_clusters(CLUSTER_M * CLUSTER_N)
return (sm_count, max_active_clusters)
def get_prep_cache_key(input_a, input_b, output):
"""
Returns a tuple key for caching the preprocessing kernel executor based on kernel name,
shapes, strides, and dtypes of input/output tensors.
"""
return (
tuple(input_a.shape),
tuple(input_a.stride()),
input_a.dtype,
tuple(input_b.shape),
tuple(input_b.stride()),
input_b.dtype,
tuple(output.shape),
tuple(output.stride()),
output.dtype,
)
def get_gemm_cache_key(prep_cache_key, max_active_clusters, total_num_clusters):
"""
Returns a tuple key for caching the gemm kernel executor by extending the
prep cache key with hardware- and grid-specific parameters.
"""
return (
prep_cache_key,
max_active_clusters,
total_num_clusters,
)
@cute.kernel
def build_group_ptrs_from_bases_kernel(
base_A_u64: cutlass.Int64, # device addr of input_a (bytes)
base_B_u64: cutlass.Int64, # device addr of input_b (bytes)
base_C_u64: cutlass.Int64, # device addr of Output (bytes)
offs: cute.Tensor, # [G], cutlass.Int32/64 cumulative
K: cutlass.Constexpr,
N: cutlass.Constexpr,
sizeof_element: cutlass.Int32, # bytes
# -------- STRIDES (in ELEMENTS) --------
stride_A_m_elems: cutlass.Constexpr, # A.stride(0)
stride_A_k_elems: cutlass.Constexpr, # A.stride(1)
stride_B0_elems: cutlass.Constexpr, # B.stride(0)
stride_Bk_elems: cutlass.Constexpr, # B.stride(1)
stride_Bn_elems: cutlass.Constexpr, # B.stride(2)
stride_C_m_elems: cutlass.Constexpr, # C.stride(0)
stride_C_n_elems: cutlass.Constexpr, # C.stride(1)
# -------- OUTPUTS --------
out_ptrs: cute.Tensor, # [G,3] cutlass.Int64: (A_ptr, B_ptr, C_ptr)
out_problem: cute.Tensor, # [G,4] cutlass.Int32: (m_g, n, k, 1)
out_strides_abc: cute.Tensor, # [G,3,2] cutlass.Int32 [[A_m,A_k],[B_n,B_k],[C_m,C_n]]
):
tidx, _, _ = cute.arch.thread_idx()
g = tidx
m_beg_i32 = 0
if g > 0:
m_beg_i32 = offs[g - 1]
m_end_i32 = offs[g]
m_g_i32 = m_end_i32 - m_beg_i32
a_byte_off = (
cutlass.Int64(m_beg_i32) * stride_A_m_elems * cutlass.Int64(sizeof_element)
)
c_byte_off = (
cutlass.Int64(m_beg_i32) * stride_C_m_elems * cutlass.Int64(sizeof_element)
)
b_byte_off = cutlass.Int64(g) * stride_B0_elems * cutlass.Int64(sizeof_element)
# ---- pointers ----
out_ptrs[g, 0] = base_A_u64 + a_byte_off
out_ptrs[g, 1] = base_B_u64 + b_byte_off
out_ptrs[g, 2] = base_C_u64 + c_byte_off
# ---- (m, n, k, 1) ----
out_problem[g, 0] = m_g_i32
out_problem[g, 1] = N
out_problem[g, 2] = K
out_problem[g, 3] = cutlass.Int32(1)
# ---- strides ----
out_strides_abc[g, 0, 0] = cutlass.Int32(stride_A_m_elems)
out_strides_abc[g, 0, 1] = cutlass.Int32(stride_A_k_elems)
out_strides_abc[g, 1, 0] = cutlass.Int32(stride_Bn_elems)
out_strides_abc[g, 1, 1] = cutlass.Int32(stride_Bk_elems)
out_strides_abc[g, 2, 0] = cutlass.Int32(stride_C_m_elems)
out_strides_abc[g, 2, 1] = cutlass.Int32(stride_C_n_elems)
@cute.jit
def launch_build_group_ptrs_from_bases(
base_A_u64: cutlass.Int64,
base_B_u64: cutlass.Int64,
base_C_u64: cutlass.Int64,
offs: cute.Tensor,
G: cutlass.Constexpr,
K: cutlass.Constexpr,
N: cutlass.Constexpr,
sizeof_element: cutlass.Constexpr,
stride_A_m_elems: cutlass.Constexpr,
stride_A_k_elems: cutlass.Constexpr,
stride_B0_elems: cutlass.Constexpr,
stride_Bk_elems: cutlass.Constexpr,
stride_Bn_elems: cutlass.Constexpr,
stride_C_m_elems: cutlass.Constexpr,
stride_C_n_elems: cutlass.Constexpr,
out_ptrs: cute.Tensor, # [G,3] cutlass.Int64
out_problem: cute.Tensor, # [G,4] cutlass.Int32
out_strides_abc: cute.Tensor, # [3,2] cutlass.Int32
stream: cuda.CUstream,
):
build_group_ptrs_from_bases_kernel(
base_A_u64,
base_B_u64,
base_C_u64,
offs,
K,
N,
sizeof_element,
stride_A_m_elems,
stride_A_k_elems,
stride_B0_elems,
stride_Bk_elems,
stride_Bn_elems,
stride_C_m_elems,
stride_C_n_elems,
out_ptrs,
out_problem,
out_strides_abc,
).launch(grid=(1, 1, 1), block=(G, 1, 1), stream=stream)
{{def_kernel("input_a", "input_b", "input_a_offs")}}
stream = cuda.CUstream(stream)
input_b = input_b.transpose(1, 2)
sumM, K = input_a.shape
G, N, Kb = input_b.shape
dev = input_a.device
base_A_u64 = int(input_a.data_ptr())
base_B_u64 = int(input_b.data_ptr())
base_C_u64 = int({{get_output()}}.data_ptr())
ptrs_t = torch.empty((G, 3), device=dev, dtype=torch.int64)
probs_t = torch.empty((G, 4), device=dev, dtype=torch.int32)
strides_t = torch.empty((G, 3, 2), device=dev, dtype=torch.int32)
ptrs = from_dlpack(ptrs_t)
probs = from_dlpack(probs_t)
strides = from_dlpack(strides_t)
prep_cache_key = get_prep_cache_key(input_a, input_b, {{get_output()}})
prep_executor = prep_cache.get(prep_cache_key)
if prep_executor is None:
sizeof_element = int(input_a.element_size())
sA_m, sA_k = map(int, input_a.stride())
sB_0, sB_n, sB_k = map(int, input_b.stride())
sC_m, sC_n = map(int, {{get_output()}}.stride())
prep_executor = cute.compile(
launch_build_group_ptrs_from_bases,
base_A_u64=base_A_u64,
base_B_u64=base_B_u64,
base_C_u64=base_C_u64,
offs=from_dlpack(input_a_offs),
G=int(G),
K=int(K),
N=int(N),
sizeof_element=sizeof_element,
stride_A_m_elems=sA_m,
stride_A_k_elems=sA_k,
stride_B0_elems=sB_0,
stride_Bk_elems=sB_k,
stride_Bn_elems=sB_n,
stride_C_m_elems=sC_m,
stride_C_n_elems=sC_n,
out_ptrs=ptrs,
out_problem=probs,
out_strides_abc=strides,
stream=stream,
)
prep_cache[prep_cache_key] = prep_executor
prep_executor(
base_A_u64=base_A_u64,
base_B_u64=base_B_u64,
base_C_u64=base_C_u64,
offs=from_dlpack(input_a_offs),
out_ptrs=ptrs,
out_problem=probs,
out_strides_abc=strides,
stream=stream,
)
# --- Tensormap workspace per SM ---
num_tensormap_buffers, max_active_clusters = get_hardware_info()
tensormap_shape = (
num_tensormap_buffers,
GroupedGemmKernel.num_tensormaps,
GroupedGemmKernel.bytes_per_tensormap // 8,
)
tensormap_workspace_t = torch.empty(tensormap_shape, device=dev, dtype=torch.int64)
tensormap_workspace = from_dlpack(tensormap_workspace_t)
# --- Total clusters ---
def compute_total_num_clusters(
problem_sizes_mnkl,
cluster_tile_shape_mn,
):
total_num_clusters = 0
for m, n, _, _ in problem_sizes_mnkl:
num_clusters_mn = tuple(
ceildiv(x, y) for x, y in zip((m, n), cluster_tile_shape_mn)
)
total_num_clusters += functools.reduce(lambda x, y: x * y, num_clusters_mn)
return total_num_clusters
# Compute cluster tile shape
def compute_cluster_tile_shape(
mma_tiler_mn,
cluster_shape_mn,
use_2cta_instrs,
):
cta_tile_shape_mn = list(mma_tiler_mn)
if use_2cta_instrs:
cta_tile_shape_mn[0] = cta_tile_shape_mn[0] // 2
return tuple(x * y for x, y in zip(cta_tile_shape_mn, cluster_shape_mn))
cluster_tile_shape_mn = compute_cluster_tile_shape(
(TILE_M, TILE_N), (CLUSTER_M, CLUSTER_N), bool(USE_2_CTA)
)
total_num_clusters = int(compute_total_num_clusters(probs_t, cluster_tile_shape_mn))
gemm_cache_key = get_gemm_cache_key(
prep_cache_key, max_active_clusters, total_num_clusters
)
gemm_executor = gemm_cache.get(gemm_cache_key)
if gemm_executor is None:
grouped_gemm = GroupedGemmKernel(
acc_dtype=ACC_DTYPE,
use_2cta_instrs=USE_2_CTA,
mma_tiler_mn=(TILE_M, TILE_N),
cluster_shape_mn=(CLUSTER_M, CLUSTER_N),
tensormap_update_mode=TENSORMAP_UPDATE_MODE,
)
gemm_executor = cute.compile(
grouped_gemm,
from_dlpack(input_a.unsqueeze(-1), assumed_align=16),
from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16),
from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16),
G,
probs,
strides,
ptrs,
total_num_clusters,
tensormap_workspace,
max_active_clusters,
stream,
)
gemm_cache[gemm_cache_key] = gemm_executor
gemm_executor(
from_dlpack(input_a.unsqueeze(-1), assumed_align=16),
from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16),
from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16),
probs,
strides,
ptrs,
tensormap_workspace,
stream,
)

View File

@ -0,0 +1,141 @@
from dataclasses import dataclass
from enum import auto, Enum
from itertools import product
import torch._inductor.config as config
class TensorMapUpdateMode(Enum):
"""Enum mirroring cutlass.utils.TensorMapUpdateMode to decouple this file from a cutlass dependency."""
SMEM = auto()
GMEM = auto()
@dataclass(frozen=True)
class CuTeGemmConfig:
TILE_M: int = 128
TILE_N: int = 192
CLUSTER_M: int = 2
CLUSTER_N: int = 1
USE_2_CTA: bool = False
TENSORMAP_UPDATE_MODE: TensorMapUpdateMode = TensorMapUpdateMode.SMEM
def get_exhaustive_groupgemm_configs() -> list[CuTeGemmConfig]:
"""
Returns the exhaustive configuration set for the Blackwell CuTeDSL Grouped GEMM kernel.
For information regarding valid config sets, see:
https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/grouped_gemm.py
"""
# Tile_n is always the same regardless of 2cta
tile_n_vals = [32, 64, 96, 128, 160, 192, 224, 256]
# Valid clusters
clusters_no_2cta = [
(1, 1),
(1, 2),
(1, 4),
(1, 8),
(1, 16),
(2, 1),
(2, 2),
(2, 4),
(2, 8),
(4, 1),
(4, 2),
(4, 4),
(8, 1),
(8, 2),
(16, 1),
]
clusters_2cta = [
(2, 1),
(2, 2),
(2, 4),
(2, 8),
(4, 1),
(4, 2),
(4, 4),
(8, 1),
(8, 2),
(16, 1),
]
configs: list[CuTeGemmConfig] = []
for use_2cta, cluster_set, tile_m_range in [
(False, clusters_no_2cta, [64, 128]),
(True, clusters_2cta, [128, 256]),
]:
for tensormap_update_mode, tile_m, tile_n, (cluster_m, cluster_n) in product(
[TensorMapUpdateMode.SMEM, TensorMapUpdateMode.GMEM],
tile_m_range,
tile_n_vals,
cluster_set,
):
configs.append(
CuTeGemmConfig(
tile_m,
tile_n,
cluster_m,
cluster_n,
USE_2_CTA=use_2cta,
TENSORMAP_UPDATE_MODE=tensormap_update_mode,
)
)
return configs
def get_default_groupgemm_configs() -> list[CuTeGemmConfig]:
"""
Returns the default configuration set for the Blackwell CuTeDSL Grouped GEMM kernel.
"""
config_tuples = [
(128, 256, 2, 1, False, TensorMapUpdateMode.SMEM),
(256, 160, 2, 1, True, TensorMapUpdateMode.GMEM),
(256, 256, 2, 1, True, TensorMapUpdateMode.GMEM),
(64, 32, 1, 1, False, TensorMapUpdateMode.GMEM),
(64, 256, 1, 2, False, TensorMapUpdateMode.SMEM),
(128, 256, 1, 2, False, TensorMapUpdateMode.SMEM),
(256, 256, 2, 2, True, TensorMapUpdateMode.GMEM),
(128, 256, 1, 2, False, TensorMapUpdateMode.GMEM),
(64, 32, 1, 1, False, TensorMapUpdateMode.SMEM),
(256, 256, 2, 1, True, TensorMapUpdateMode.SMEM),
(128, 256, 1, 1, False, TensorMapUpdateMode.GMEM),
(256, 256, 8, 1, True, TensorMapUpdateMode.GMEM),
(64, 32, 1, 2, False, TensorMapUpdateMode.SMEM),
(256, 192, 2, 1, True, TensorMapUpdateMode.GMEM),
(256, 256, 2, 2, True, TensorMapUpdateMode.SMEM),
(128, 96, 1, 2, False, TensorMapUpdateMode.SMEM),
(64, 192, 1, 1, False, TensorMapUpdateMode.SMEM),
(64, 64, 1, 1, False, TensorMapUpdateMode.GMEM),
(64, 192, 1, 1, False, TensorMapUpdateMode.GMEM),
(128, 64, 1, 1, False, TensorMapUpdateMode.GMEM),
(64, 160, 1, 1, False, TensorMapUpdateMode.GMEM),
(64, 256, 1, 1, False, TensorMapUpdateMode.GMEM),
]
return [CuTeGemmConfig(*args) for args in config_tuples]
def get_groupgemm_configs() -> list[CuTeGemmConfig]:
"""
Returns the configuration set for the Blackwell CuTeDSL Grouped GEMM kernel.
Note: CuTeDSL autotuning is still experimental — enabling it may trigger kernel launch failures
or unstable results. By default, autotuning is disabled and we return only
a single baseline config.
"""
if (
config.cutedsl_enable_autotuning
and config.max_autotune_gemm_search_space == "EXHAUSTIVE"
):
return get_exhaustive_groupgemm_configs()
elif config.cutedsl_enable_autotuning:
return get_default_groupgemm_configs()
else:
return [get_default_groupgemm_configs()[0]]
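
For orientation, a minimal sketch (illustrative only, not part of the diff) of how the selection above behaves, assuming the `cutedsl_enable_autotuning` flag added earlier in this compare and the existing `max_autotune_gemm_search_space` knob:

import torch._inductor.config as inductor_config
from torch._inductor.template_heuristics.cutedsl import get_groupgemm_configs

# Default: autotuning disabled, only the single baseline config is returned.
inductor_config.cutedsl_enable_autotuning = False
baseline = get_groupgemm_configs()
assert len(baseline) == 1

# Enabling autotuning with the exhaustive search space yields the full sweep.
inductor_config.cutedsl_enable_autotuning = True
inductor_config.max_autotune_gemm_search_space = "EXHAUSTIVE"
sweep = get_groupgemm_configs()
assert len(sweep) > len(baseline)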

View File

@ -1911,6 +1911,84 @@ def use_triton_blackwell_tma_template(
return has_triton_tensor_descriptor_host_tma() and is_datacenter_blackwell_arch()
@functools.lru_cache(maxsize=1)
def ensure_cute_available() -> bool:
"""Check if CuTeDSL is importable; cache the result for reuse.
Call ensure_cute_available.cache_clear() after installing CuTeDSL
in the same interpreter to retry the import.
"""
try:
return importlib.util.find_spec("cutlass.cute") is not None
except ImportError:
return False
def use_blackwell_cutedsl_grouped_mm(
mat_a: Any,
mat_b: Any,
layout: Layout,
a_is_2d: bool,
b_is_2d: bool,
offs: Optional[Any],
bias: Optional[Any],
scale_result: Optional[Any],
) -> bool:
"""
Returns True if we can use the blackwell kernel for grouped mm.
Required conditions:
1. CuTeDSL backend is enabled
2. CuTeDSL is available
3. We are on a blackwell arch
4. The dtype is bf16
5. Max autotune or max autotune gemm is enabled
6. A, B, and the output are 16B aligned
7. We are not using dynamic shapes
8. A is 2d
9. B is 3d
10. Offsets are provided
11. Bias and Scale are not provided
"""
if not ensure_cute_available():
return False
if not _use_autotune_backend("CUTEDSL"):
return False
from .codegen.cuda.cuda_env import is_datacenter_blackwell_arch
if not is_gpu(layout.device.type):
return False
if not is_datacenter_blackwell_arch():
return False
layout_dtypes = [torch.bfloat16]
if not _use_template_for_gpu(layout, layout_dtypes):
return False
if not (config.max_autotune or config.max_autotune_gemm):
return False
# Checks for 16B ptr and stride alignment
if not can_use_tma(mat_a, mat_b, output_layout=layout):
return False
if any(is_dynamic(x) for x in [mat_a, mat_b]):
return False
if not a_is_2d or b_is_2d:
return False
if offs is None:
return False
if bias is not None or scale_result is not None:
return False
return True
def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool:
from .virtualized import V
@ -2651,7 +2729,6 @@ def pass_execution_and_save(
with tempfile.NamedTemporaryFile(
mode="w",
encoding="utf-8",
delete=False,
) as f:
before_io = io.StringIO()
after_io = io.StringIO()

View File

@ -10,6 +10,7 @@ import torch
from ._utils import _device_t, _get_device_index
from .memory import (
empty_cache,
get_memory_info,
max_memory_allocated,
max_memory_reserved,
memory_allocated,
@ -25,9 +26,10 @@ __all__ = [
"current_device_idx", # deprecated
"current_device_index",
"current_stream",
"empty_cache",
"device_count",
"device_index",
"empty_cache",
"get_memory_info",
"is_available",
"max_memory_allocated",
"max_memory_reserved",

View File

@ -8,6 +8,7 @@ from ._utils import _device_t, _get_device_index
__all__ = [
"empty_cache",
"get_memory_info",
"max_memory_allocated",
"max_memory_reserved",
"memory_allocated",
@ -87,6 +88,9 @@ def memory_stats(device_index: _device_t = None, /) -> OrderedDict[str, Any]:
If not given, use :func:`torch.accelerator.current_device_index` by default.
If a :class:`torch.device` or str is provided, its type must match the current
:ref:`accelerator<accelerators>` device type.
Returns:
OrderedDict[str, Any]: an ordered dictionary mapping statistic names to their values.
"""
if not torch._C._accelerator_isAllocatorInitialized():
return OrderedDict()
@ -117,6 +121,9 @@ def memory_allocated(device_index: _device_t = None, /) -> int:
If not given, use :func:`torch.accelerator.current_device_index` by default.
If a :class:`torch.device` or str is provided, its type must match the current
:ref:`accelerator<accelerators>` device type.
Returns:
int: the current memory occupied by live tensors (in bytes) within the current process.
"""
return memory_stats(device_index).get("allocated_bytes.all.current", 0)
@ -134,6 +141,9 @@ def max_memory_allocated(device_index: _device_t = None, /) -> int:
If not given, use :func:`torch.accelerator.current_device_index` by default.
If a :class:`torch.device` or str is provided, its type must match the current
:ref:`accelerator<accelerators>` device type.
Returns:
int: the peak memory occupied by live tensors (in bytes) within the current process.
"""
return memory_stats(device_index).get("allocated_bytes.all.peak", 0)
@ -147,6 +157,9 @@ def memory_reserved(device_index: _device_t = None, /) -> int:
If not given, use :func:`torch.accelerator.current_device_index` by default.
If a :class:`torch.device` or str is provided, its type must match the current
:ref:`accelerator<accelerators>` device type.
Returns:
int: the current memory reserved by PyTorch (in bytes) within the current process.
"""
return memory_stats(device_index).get("reserved_bytes.all.current", 0)
@ -164,6 +177,9 @@ def max_memory_reserved(device_index: _device_t = None, /) -> int:
If not given, use :func:`torch.accelerator.current_device_index` by default.
If a :class:`torch.device` or str is provided, its type must match the current
:ref:`accelerator<accelerators>` device type.
Returns:
int: the peak memory reserved by PyTorch (in bytes) within the current process.
"""
return memory_stats(device_index).get("reserved_bytes.all.peak", 0)
@ -200,3 +216,21 @@ def reset_peak_memory_stats(device_index: _device_t = None, /) -> None:
"""
device_index = _get_device_index(device_index, optional=True)
return torch._C._accelerator_resetPeakStats(device_index)
def get_memory_info(device_index: _device_t = None, /) -> tuple[int, int]:
r"""Return the current device memory information for a given device index.
Args:
device_index (:class:`torch.device`, str, int, optional): the index of the device to target.
If not given, use :func:`torch.accelerator.current_device_index` by default.
If a :class:`torch.device` or str is provided, its type must match the current
:ref:`accelerator<accelerators>` device type.
Returns:
tuple[int, int]: a tuple of two integers (free_memory, total_memory) in bytes.
The first value is the free memory on the device (available across all processes and applications),
The second value is the device's total hardware memory capacity.
"""
device_index = _get_device_index(device_index, optional=True)
return torch._C._accelerator_getMemoryInfo(device_index)
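
A minimal usage sketch for the new Python API above (assuming an accelerator backend such as CUDA or XPU is available at runtime):

import torch

if torch.accelerator.is_available():
    # Free memory is reported device-wide (shared across all processes);
    # total is the device's hardware capacity. Both values are in bytes.
    free_bytes, total_bytes = torch.accelerator.get_memory_info()
    print(f"free {free_bytes / 2**30:.2f} GiB of {total_bytes / 2**30:.2f} GiB")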

View File

@ -195,10 +195,12 @@ def get_new_attr_name_with_prefix(prefix: str) -> Callable:
def collect_producer_nodes(node: Node) -> Optional[list[Node]]:
r"""Starting from a target node, trace back until we hit input or
getattr node. This is used to extract the chain of operators
starting from getattr to the target node, for example
def forward(self, x):
observed = self.observer(self.weight)
return F.linear(x, observed)
starting from getattr to the target node, for example::
def forward(self, x):
observed = self.observer(self.weight)
return F.linear(x, observed)
collect_producer_nodes(observed) will either return a list of nodes that
produces the observed node or None if we can't extract a self-contained
graph without free variables (inputs of the forward function).

View File

@ -138,6 +138,13 @@ void initModule(PyObject* module) {
at::accelerator::resetPeakStats(device_index);
});
m.def("_accelerator_getMemoryInfo", [](c10::DeviceIndex device_index) {
const auto device_type = at::accelerator::getAccelerator(true).value();
torch::utils::maybe_initialize_device(device_type);
py::gil_scoped_release no_gil;
return at::accelerator::getMemoryInfo(device_index);
});
m.def("_accelerator_setAllocatorSettings", [](std::string env) {
c10::CachingAllocator::setAllocatorSettings(env);
});

View File

@ -122,29 +122,47 @@ struct TORCH_API ExecutionTraceObserver { // NOLINT
ID get_tensor_storage_ID(const c10::Storage& t_storage) {
const std::lock_guard<std::recursive_mutex> lock(gMutex);
const void* raw_data_ptr = t_storage.data();
auto iter = data_ptr_to_weak_storage_ptr.find(raw_data_ptr);
if (iter == data_ptr_to_weak_storage_ptr.end()) {
const void* raw_data_ptr = nullptr;
bool should_track_liveness = false;
// FakeTensor/FunctionalTensor may clear the Storage handle entirely or use
// a nullptr data pointer. Treat both cases as a shared cache key but avoid
// touching the weak-ref table so they can reuse the same ID without
// tripping the liveness check.
if (t_storage.unsafeGetStorageImpl()) {
raw_data_ptr = t_storage.data();
should_track_liveness = raw_data_ptr != nullptr;
}
auto id_iter = data_ptr_to_storage_id.find(raw_data_ptr);
if (!should_track_liveness) {
if (id_iter != data_ptr_to_storage_id.end()) {
return id_iter->second;
}
ID id = storage_id_++;
data_ptr_to_storage_id.emplace(raw_data_ptr, id);
return id;
}
auto weak_iter = data_ptr_to_weak_storage_ptr.find(raw_data_ptr);
if (weak_iter == data_ptr_to_weak_storage_ptr.end()) {
ID id = storage_id_++;
data_ptr_to_storage_id.insert_or_assign(raw_data_ptr, id);
data_ptr_to_weak_storage_ptr.emplace(
raw_data_ptr, t_storage.getWeakStorageImpl());
return id;
} else {
// check if the storage is still alive
if (iter->second.expired()) {
ID id = storage_id_++;
// std::unordered_map does not change if the key is already in the map.
// So we need to remove the key and insert the key with the new value.
data_ptr_to_storage_id.erase(raw_data_ptr);
data_ptr_to_storage_id[raw_data_ptr] = id;
data_ptr_to_weak_storage_ptr.insert_or_assign(
raw_data_ptr, t_storage.getWeakStorageImpl());
return id;
} else {
return data_ptr_to_storage_id[raw_data_ptr];
}
}
if (weak_iter->second.expired()) {
ID id = storage_id_++;
data_ptr_to_storage_id.insert_or_assign(raw_data_ptr, id);
data_ptr_to_weak_storage_ptr.insert_or_assign(
raw_data_ptr, t_storage.getWeakStorageImpl());
return id;
}
id_iter = data_ptr_to_storage_id.find(raw_data_ptr);
TORCH_INTERNAL_ASSERT(id_iter != data_ptr_to_storage_id.end());
return id_iter->second;
}
// Observer run state.

View File

@ -386,23 +386,8 @@ static void bindGetDeviceProperties(PyObject* module) {
static void initXpuMethodBindings(PyObject* module) {
auto m = py::handle(module).cast<py::module>();
m.def("_xpu_getMemoryInfo", [](c10::DeviceIndex device_index) {
#if SYCL_COMPILER_VERSION >= 20250000
auto total = at::xpu::getDeviceProperties(device_index)->global_mem_size;
auto& device = c10::xpu::get_raw_device(device_index);
TORCH_CHECK(
device.has(sycl::aspect::ext_intel_free_memory),
"The device (",
at::xpu::getDeviceProperties(device_index)->name,
") doesn't support querying the available free memory. ",
"You can file an issue at https://github.com/pytorch/pytorch/issues ",
"to help us prioritize its implementation.");
auto free = device.get_info<sycl::ext::intel::info::device::free_memory>();
return std::make_tuple(free, total);
#else
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"torch.xpu.mem_get_info requires PyTorch to be built with SYCL compiler version 2025.0.0 or newer.");
#endif
py::gil_scoped_release no_gil;
return at::getDeviceAllocator(at::kXPU)->getMemoryInfo(device_index);
});
m.def(
"_xpu_getStreamFromExternal",

View File

@ -1,4 +1,5 @@
import itertools
import math
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, cast, NamedTuple, Optional
@ -7,6 +8,7 @@ import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.placement_types import (
_StridedShard,
MaskPartial,
Partial,
Placement,
Replicate,
@ -127,6 +129,185 @@ class DTensorSpec:
)
return default_shard_order
@staticmethod
def _convert_shard_order_to_StridedShard(
shard_order: ShardOrder, placements: tuple[Placement, ...], mesh: DeviceMesh
) -> tuple[Placement, ...]:
"""
Convert ShardOrder to placements with _StridedShard.
This function converts a ShardOrder specification into a tuple of Placement objects,
using _StridedShard when a tensor dimension is sharded across multiple mesh dimensions
in a non-default order. The split_factor of each _StridedShard is determined by the
product of mesh dimension sizes that appear earlier in the shard order but later in
the placement tuple.
Args:
shard_order: ShardOrder specification indicating which tensor dimensions are
sharded on which mesh dimensions and in what execution order.
placements: Tuple of Placement objects that does not contain _StridedShard.
mesh: DeviceMesh containing the size information for each mesh dimension.
Returns:
Updated tuple of Placement objects with Shard or _StridedShard placements.
Algorithm:
For each ShardOrderEntry in shard_order:
- For each mesh dimension in the entry's mesh_dims (in order):
- Calculate split_factor as the product of mesh sizes for all mesh dimensions
that appear:
1. Earlier in the shard order (lower index in mesh_dims), and
2. Later in the placement tuple (higher mesh dimension index)
- If split_factor == 1: use normal Shard
- Otherwise: use _StridedShard with the calculated split_factor
Example:
>>> # xdoctest: +SKIP("Requires DeviceMesh")
>>> # Tensor dimension 0 sharded on mesh dims [2, 0, 1] in that order
>>> # mesh = DeviceMesh([4, 3, 2]) # sizes: mesh[0]=4, mesh[1]=3, mesh[2]=2
>>> shard_order = (ShardOrderEntry(tensor_dim=0, mesh_dims=(2, 0, 1)),)
>>> placements = (Shard(0), Shard(0), Shard(0))
>>> # For mesh_dim=2 (index 0 in mesh_dims): no earlier dims, split_factor=1
>>> # -> placements[2] = Shard(0)
>>> # For mesh_dim=0 (index 1 in mesh_dims): mesh_dim=2 is earlier and has index 2>0
>>> # -> split_factor = mesh.size(2) = 2
>>> # -> placements[0] = _StridedShard(0, split_factor=2)
>>> # For mesh_dim=1 (index 2 in mesh_dims): mesh_dim=2 is earlier and has index 2>1
>>> # -> split_factor = mesh.size(2) = 2
>>> # -> placements[1] = _StridedShard(0, split_factor=2)
>>> # Result: (_StridedShard(0, sf=2), _StridedShard(0, sf=2), Shard(0))
"""
placements_list = list(placements)
for entry in shard_order:
tensor_dim = entry.tensor_dim
mesh_dims = entry.mesh_dims
for idx in range(len(mesh_dims)):
# TODO(zpcore): split_factor from `view` and `shard order`
# should be able to be multiplied into one. Need to loosen the
# condition here.
mesh_dim = mesh_dims[idx]
if type(placements[mesh_dim]) is not Shard:
raise ValueError(
f"Only Shard placement can be converted to _StridedShard, "
f"found {placements[mesh_dim]} in {placements=}."
)
split_factor = math.prod(
mesh.size(i) for i in mesh_dims[:idx] if i > mesh_dim
)
if split_factor == 1:
# use normal Shard
placements_list[mesh_dim] = Shard(tensor_dim)
else:
placements_list[mesh_dim] = _StridedShard(
tensor_dim, split_factor=split_factor
)
return tuple(placements_list)
@staticmethod
def _maybe_convert_StridedShard_to_shard_order(
placements: tuple[Placement, ...], mesh: DeviceMesh
) -> Optional[ShardOrder]:
"""
Try to convert _StridedShard placements to ShardOrder.
This is the inverse of `_convert_shard_order_to_StridedShard`. It reconstructs the shard
order by examining the split_factor of each _StridedShard and determining its position
in the execution order. If the _StridedShard configuration cannot be represented as a
valid ShardOrder (i.e., there's no shard order that produces the observed split_factors),
this function returns None.
Args:
placements: Tuple of Placement objects that may contain _StridedShard.
mesh: DeviceMesh containing the size information for each mesh dimension.
Returns:
ShardOrder if conversion is possible, None otherwise. For placements without
_StridedShard, returns the default shard order.
Algorithm:
1. If no _StridedShard in placements, return default shard order
2. Create an empty list for each tensor dimension to represent mesh dim ordering
3. Iterate through placements in reverse order (right to left):
- For each Shard/_StridedShard on a tensor dimension:
- Extract its split_factor (1 for Shard, split_factor for _StridedShard)
- Find the position in mesh_dims_order where accumulated_sf equals split_factor
- accumulated_sf is the product of mesh sizes of mesh dimensions that appear
earlier in mesh_dims_order (lower indices)
- Insert mesh_dim at the found position
4. If no valid position found for any split_factor, return None (unable to convert)
5. Construct ShardOrderEntry for each tensor dimension from mesh_dims_order
Example:
>>> # xdoctest: +SKIP("Requires DeviceMesh")
>>> # mesh = DeviceMesh([4, 3, 2]) # sizes: mesh[0]=4, mesh[1]=3, mesh[2]=2
>>> # placements = (_StridedShard(0, sf=2), _StridedShard(0, sf=2), Shard(0))
>>> # Process tensor_dim=0 from right to left:
>>> # - mesh_dim=2: Shard(0) with sf=1
>>> # Try position 0: accumulated_sf=1, matches! Insert at position 0
>>> # Current mesh_dims_order order: [2]
>>> # - mesh_dim=1: _StridedShard(0, sf=2) with sf=2
>>> # Try position 0: accumulated_sf=1, no match
>>> # Try position 1: accumulated_sf=1*mesh.size(2)=2, matches! Insert at position 1
>>> # Current mesh_dims_order order: [2, 1]
>>> # - mesh_dim=0: _StridedShard(0, sf=2) with sf=2
>>> # Try position 0: accumulated_sf=1, no match
>>> # Try position 1: accumulated_sf=1*mesh.size(2)=2, matches! Insert at position 1
>>> # Final mesh_dims_order order: [2, 0, 1]
>>> # Result: ShardOrder((ShardOrderEntry(tensor_dim=0, mesh_dims=(2, 0, 1)),))
>>> # This means: first shard on mesh_dim=2, then mesh_dim=0, then mesh_dim=1
Note:
This function validates that _StridedShard can be represented as a ShardOrder.
Not all _StridedShard configurations are valid - the split_factor must match
the product of mesh sizes in some execution order.
"""
if not any(isinstance(p, _StridedShard) for p in placements):
return DTensorSpec.compute_default_shard_order(placements)
max_tensor_dim = (
max([i.dim for i in placements if isinstance(i, Shard | _StridedShard)]) + 1
)
shard_order = []
tensor_dim_to_mesh_dims_order: list[list[int]] = [
[] for i in range(max_tensor_dim)
]
for mesh_dim in reversed(range(len(placements))):
cur_placement = placements[mesh_dim]
# _StridedShard may not be a subclass of Shard in the future, so we write it this way:
if isinstance(cur_placement, Shard | _StridedShard):
tensor_dim = cur_placement.dim
mesh_dims_order = tensor_dim_to_mesh_dims_order[tensor_dim]
cur_sf = 1
if isinstance(cur_placement, _StridedShard):
cur_sf = cur_placement.split_factor
accumulated_sf = 1
find_order = False
for i in range(len(mesh_dims_order) + 1):
if accumulated_sf == cur_sf:
mesh_dims_order.insert(i, mesh_dim)
find_order = True
break
if i < len(mesh_dims_order):
accumulated_sf *= mesh.size(mesh_dims_order[i])
if not find_order:
# _StridedShard is not convertible to ShardOrder
return None
else:
if not isinstance(cur_placement, Replicate | Partial | MaskPartial):
raise ValueError(
f"Unsupported placement type {type(cur_placement)} encountered in "
f"{placements}; expected Replicate, Partial, or MaskPartial."
)
for tensor_dim in range(max_tensor_dim):
if len(tensor_dim_to_mesh_dims_order[tensor_dim]) > 0:
shard_order.append(
ShardOrderEntry(
tensor_dim=tensor_dim,
mesh_dims=tuple(tensor_dim_to_mesh_dims_order[tensor_dim]),
)
)
return tuple(shard_order)
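
A round-trip sketch tying the two helpers above together, reusing the docstring example; `mesh` is assumed to be a DeviceMesh with per-dim sizes (4, 3, 2):

from torch.distributed.tensor._dtensor_spec import DTensorSpec, ShardOrderEntry
from torch.distributed.tensor.placement_types import Shard

# Tensor dim 0 sharded on mesh dim 2 first, then mesh dim 0, then mesh dim 1.
order = (ShardOrderEntry(tensor_dim=0, mesh_dims=(2, 0, 1)),)
strided = DTensorSpec._convert_shard_order_to_StridedShard(
    order, (Shard(0), Shard(0), Shard(0)), mesh
)
# -> (_StridedShard(0, split_factor=2), _StridedShard(0, split_factor=2), Shard(0))
assert DTensorSpec._maybe_convert_StridedShard_to_shard_order(strided, mesh) == order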
def _verify_shard_order(self, shard_order: ShardOrder) -> None:
"""Verify that the shard_order is valid and matches the placements."""
total_shard = 0

View File

@ -11,7 +11,10 @@ import torch.distributed.tensor._api as dtensor
aten = torch.ops.aten
def _requires_data_exchange(padding):
def _requires_data_exchange(padding, dim_map) -> bool:
# Data exchange is not needed if the input is only sharded across the batch dim
if all(x == -1 for x in dim_map[1:]):
return False
# TODO: whether data exchange is required is currently determined by padding
return padding[-1] != 0
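
As a rough illustration of the new `dim_map` argument (in DTensor, `dim_map[i]` is the mesh dim that tensor dim i is sharded on, with -1 meaning not sharded), the early return above only fires when every non-batch dim is unsharded:

# Hypothetical NCHW activation sharded only along the batch dim (mesh dim 0):
dim_map = [0, -1, -1, -1]
assert all(x == -1 for x in dim_map[1:])   # -> no halo exchange needed
# If the width dim were also sharded, e.g. dim_map = [0, -1, -1, 1],
# the padding-based check would decide instead.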
@ -107,6 +110,7 @@ def tp_convolution(
op_call: torch._ops.OpOverload,
local_tensor_args: tuple[object, ...],
local_tensor_kwargs: dict[str, object],
dim_map: list[int],
) -> object:
assert op_call == aten.convolution.default
assert len(local_tensor_args) == 9
@ -120,7 +124,7 @@ def tp_convolution(
assert _is_supported(in_tensor.shape, weight.shape, stride, padding, dilation)
assert isinstance(padding, list)
if not _requires_data_exchange(padding):
if not _requires_data_exchange(padding, dim_map):
local_results = op_call(*local_tensor_args, **local_tensor_kwargs)
return local_results
else:
@ -160,6 +164,7 @@ def tp_convolution_backward(
op_call: torch._ops.OpOverload,
local_tensor_args: tuple[object, ...],
local_tensor_kwargs: dict[str, object],
dim_map: list[int],
) -> object:
assert op_call == aten.convolution_backward.default
assert len(local_tensor_args) == 11
@ -174,7 +179,7 @@ def tp_convolution_backward(
assert _is_supported(in_tensor.shape, weight.shape, stride, padding, dilation)
assert isinstance(padding, list)
if not _requires_data_exchange(padding):
if not _requires_data_exchange(padding, dim_map):
local_results = op_call(*local_tensor_args, **local_tensor_kwargs)
return local_results
else:
@ -239,15 +244,18 @@ def convolution_handler(
dtensor.DTensor._op_dispatcher.sharding_propagator.propagate(op_info)
output_sharding = op_info.output_sharding
assert output_sharding is not None, "output sharding should not be None"
output_spec = output_sharding.output_spec
assert isinstance(output_spec, dtensor.DTensorSpec)
# local propagation
local_results = tp_convolution(
op_call, tuple(op_info.local_args), op_info.local_kwargs
op_call,
tuple(op_info.local_args),
op_info.local_kwargs,
output_spec.dim_map,
)
return dtensor.DTensor._op_dispatcher.wrap(
local_results, output_sharding.output_spec
)
return dtensor.DTensor._op_dispatcher.wrap(local_results, output_spec)
def convolution_backward_handler(
@ -270,10 +278,14 @@ def convolution_backward_handler(
dtensor.DTensor._op_dispatcher.sharding_propagator.propagate(op_info)
output_sharding = op_info.output_sharding
assert output_sharding is not None, "output sharding should not be None"
assert isinstance(op_info.flat_args_schema[0], dtensor.DTensorSpec)
# local propagation
local_results = tp_convolution_backward(
op_call, tuple(op_info.local_args), op_info.local_kwargs
op_call,
tuple(op_info.local_args),
op_info.local_kwargs,
op_info.flat_args_schema[0].dim_map,
)
return dtensor.DTensor._op_dispatcher.wrap(

View File

@ -1,6 +1,7 @@
# mypy: allow-untyped-defs
import math
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from collections.abc import Callable
from typing import Any, Optional, Union
number = Union[int, float]

View File

@ -1,9 +1,8 @@
# ${generated_comment}
# mypy: allow-untyped-defs
from collections.abc import Sequence
from typing import Any, Callable, Literal, overload
from typing_extensions import TypeAlias
from collections.abc import Callable, Sequence
from typing import Any, Literal, overload, TypeAlias
from torch import Tensor
from torch.types import _dtype, _int, _size

View File

@ -6,7 +6,7 @@
from __future__ import annotations
from typing import Optional, Sequence, TYPE_CHECKING
from typing import Optional, Sequence, TYPE_CHECKING # noqa: UP035
from onnxscript.onnx_opset import ( # type: ignore[attr-defined]
opset20 as op20,

View File

@ -271,13 +271,11 @@ class _KinetoProfile:
"Profiler must be initialized before exporting chrome trace"
)
if path.endswith(".gz"):
with tempfile.NamedTemporaryFile("w+b", suffix=".json", delete=False) as fp:
fp.close()
with tempfile.NamedTemporaryFile("w+b", suffix=".json") as fp:
retvalue = self.profiler.export_chrome_trace(fp.name)
with open(fp.name, "rb") as fin:
with gzip.open(path, "wb") as fout:
fout.writelines(fin)
os.remove(fp.name)
fp.seek(0)
with gzip.open(path, "wb") as fout:
fout.writelines(fp)
return retvalue
else:
return self.profiler.export_chrome_trace(path)
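The pattern used by the new code path, sketched as a standalone helper (export_fn stands in for self.profiler.export_chrome_trace and is an assumption of this sketch): write the JSON trace into a context-managed temporary file, rewind the handle, and stream its bytes into the gzip output, so no leftover file has to be removed.
import gzip
import tempfile
def export_gzipped_trace_sketch(export_fn, path: str):
    # Assumes a Unix-like platform, where a NamedTemporaryFile can be reopened
    # by name while the handle is still open (per the tempfile docs this is
    # not guaranteed on Windows).
    with tempfile.NamedTemporaryFile("w+b", suffix=".json") as fp:
        retvalue = export_fn(fp.name)  # the profiler writes the JSON trace here
        fp.seek(0)                     # rewind before re-reading the contents
        with gzip.open(path, "wb") as fout:
            fout.writelines(fp)        # compress into the requested .gz file
    return retvalue                    # temp file is deleted on context exit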
@ -448,15 +446,14 @@ class _KinetoProfile:
if path.endswith(".html"):
self.mem_tl.export_memory_timeline_html(path, device)
elif path.endswith(".gz"):
fp = tempfile.NamedTemporaryFile("w+t", suffix=".json", delete=False)
fp.close()
if path.endswith("raw.json.gz"):
self.mem_tl.export_memory_timeline_raw(fp.name, device)
else:
self.mem_tl.export_memory_timeline(fp.name, device)
with open(fp.name) as fin, gzip.open(path, "wt") as fout:
fout.writelines(fin)
os.remove(fp.name)
with tempfile.NamedTemporaryFile("w+t", suffix=".json") as fp:
fp.close()
if path.endswith("raw.json.gz"):
self.mem_tl.export_memory_timeline_raw(fp.name, device)
else:
self.mem_tl.export_memory_timeline(fp.name, device)
with open(fp.name) as fin, gzip.open(path, "wt") as fout:
fout.writelines(fin)
else:
self.mem_tl.export_memory_timeline(path, device)
@ -946,7 +943,7 @@ class ExecutionTraceObserver(_ITraceObserver):
"""
if os.environ.get("ENABLE_PYTORCH_EXECUTION_TRACE", "0") == "1":
try:
fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False)
fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False) # noqa:SIM115
except Exception as e:
warn(
f"Execution trace will not be recorded. Exception on creating default temporary file: {e}",

View File

@ -20320,6 +20320,7 @@ op_db: list[OpInfo] = [
torch.float32: 1e-4}),),
dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
supports_sparse=True,
supports_sparse_csr=True,
supports_sparse_csc=True,
supports_sparse_bsr=True,

View File

@ -3,6 +3,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import contextlib
import copy
import functools
import itertools
import sys
@ -32,6 +33,8 @@ from torch.distributed.tensor import (
Replicate,
Shard,
)
from torch.distributed.tensor._dtensor_spec import ShardOrderEntry
from torch.distributed.tensor._redistribute import redistribute_local_tensor
from torch.distributed.tensor.parallel import (
ColwiseParallel,
parallelize_module,
@ -818,3 +821,125 @@ def map_local_for_rank(rank, func):
def reduce_local_int(val, func):
return func(val.node._local_ints)
def _convert_shard_order_dict_to_ShardOrder(shard_order):
"""Convert shard_order dict to ShardOrder"""
return tuple(
ShardOrderEntry(tensor_dim=tensor_dim, mesh_dims=tuple(mesh_dims))
for tensor_dim, mesh_dims in shard_order.items()
)
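For example, with a hypothetical input mapping tensor dims to an ordered tuple of mesh dims:
# Hypothetical: tensor dim 1 is sharded over mesh dims 0 then 2, tensor dim 0 over mesh dim 1.
shard_order = _convert_shard_order_dict_to_ShardOrder({1: (0, 2), 0: (1,)})
# -> (ShardOrderEntry(tensor_dim=1, mesh_dims=(0, 2)),
#     ShardOrderEntry(tensor_dim=0, mesh_dims=(1,)))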
# TODO(zpcore): remove once the native redistribute supports shard_order arg
def redistribute(
dtensor_input,
device_mesh,
placements,
shard_order,
use_graph_based_transform=True,
):
"""
Wrapper function to support shard_order for redistribution.
This is a simpler version of Redistribute that only considers the forward pass.
"""
if placements is None:
placements = shard_order_to_placement(shard_order, device_mesh)
placements = tuple(placements)
old_spec = dtensor_input._spec
new_spec = copy.deepcopy(old_spec)
new_spec.placements = placements
if shard_order is not None:
new_spec.shard_order = shard_order
else:
new_spec.shard_order = ()
if old_spec == new_spec:
return dtensor_input
dtensor_input = DTensor.from_local(
redistribute_local_tensor(
dtensor_input.to_local(),
old_spec,
new_spec,
use_graph_based_transform=use_graph_based_transform,
),
device_mesh,
)
dtensor_input._spec = copy.deepcopy(new_spec)
return dtensor_input # returns DTensor
# TODO(zpcore): remove once the native distribute_tensor supports
# shard_order arg
def patched_distribute_tensor(
input_tensor,
device_mesh,
placements,
shard_order,
use_graph_based_transform=True,
):
"""wrapper function to support shard_order for tensor distribution"""
if placements is None:
placements = shard_order_to_placement(shard_order, device_mesh)
placements = tuple(placements)
tensor_dt = distribute_tensor(input_tensor, device_mesh, placements)
# fix the shard order
return redistribute(
tensor_dt, device_mesh, placements, shard_order, use_graph_based_transform
)
# TODO(zpcore): remove once the native redistribute supports shard_order arg
def make_full_tensor(dtensor_input):
"""wrapper function to support DTensor.full_tensor"""
return redistribute(
dtensor_input, dtensor_input.device_mesh, placements=None, shard_order=()
).to_local()
def shard_order_to_placement(shard_order, mesh):
"""convert shard_order to placement with only Replicate() and Shard()"""
placements: list[Any] = [Replicate() for _ in range(mesh.ndim)]
if shard_order is not None:
for entry in shard_order:
tensor_dim = entry.tensor_dim
mesh_dims = entry.mesh_dims
for mesh_dim in mesh_dims:
placements[mesh_dim] = Shard(tensor_dim)
return tuple(placements)
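A small sketch of what this returns for a hypothetical 3-d mesh where tensor dim 0 is sharded over mesh dims 0 and 2 (the mesh and entry values here are placeholders):
placements = shard_order_to_placement(
    (ShardOrderEntry(tensor_dim=0, mesh_dims=(0, 2)),),  # hypothetical shard order
    mesh,  # assumed: a DeviceMesh with mesh.ndim == 3
)
# -> (Shard(0), Replicate(), Shard(0)): mesh dims 0 and 2 shard tensor dim 0 and
#    mesh dim 1 stays Replicate(); the relative order of mesh dims 0 and 2 is not
#    expressible in plain placements, which is what ShardOrder captures.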
def generate_shard_orders(mesh, tensor_rank):
# Generate all possible sharding placements of a tensor with rank
# `tensor_rank` over the mesh.
def _split_list(lst: list, N: int):
def compositions(n: int, k: int):
# yields lists of length k, positive ints summing to n
for cuts in itertools.combinations(range(1, n), k - 1):
# add 0 and n as sentinels, then take consecutive differences
yield [b - a for a, b in itertools.pairwise((0, *cuts, n))]
length = len(lst)
for comp in compositions(length, N):
result = []
start = 0
for size in comp:
result.append(lst[start : start + size])
start += size
yield result
all_mesh = list(range(mesh.ndim))
all_device_order = list(itertools.permutations(all_mesh))
for device_order in all_device_order:
# split the device order into segments, and assign each segment to a tensor dim
for num_split in range(1, mesh.ndim + 1):
for splitted_list in _split_list(list(range(mesh.ndim)), num_split):
for tensor_dims in itertools.combinations(
range(tensor_rank), len(splitted_list)
):
shard_order = {}
assert len(tensor_dims) == len(splitted_list)
for tensor_dim, mesh_dims in zip(tensor_dims, splitted_list):
shard_order[tensor_dim] = device_order[
mesh_dims[0] : mesh_dims[-1] + 1
]
yield _convert_shard_order_dict_to_ShardOrder(shard_order)
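To make the enumeration concrete, here is a standalone restatement of the nested _split_list helper (a minimal sketch with the same assumed semantics):
import itertools
def _split_list_sketch(lst: list, n: int):
    # Yield every way to cut lst into n contiguous, non-empty segments.
    for cuts in itertools.combinations(range(1, len(lst)), n - 1):
        bounds = (0, *cuts, len(lst))
        yield [lst[a:b] for a, b in itertools.pairwise(bounds)]
print(list(_split_list_sketch([0, 1, 2], 2)))
# [[[0], [1, 2]], [[0, 1], [2]]]
# generate_shard_orders assigns each such segment of a permuted mesh-dim list to
# one tensor dim, which enumerates the distinct shard orders.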

View File

@ -215,19 +215,16 @@ def get_profiling_event(event_name, profiler, dedup_gpu_user_annotation=False):
def get_profiler_nccl_meta(prof):
"""Torch profiler includes nccl metadata in an inserted operator called "record_param_comms"
We will need to test metadata obtained from profiler here"""
tf = tempfile.NamedTemporaryFile(mode="w+t", suffix=".json", delete=False)
tf.close()
trace_file = tf.name
with tempfile.NamedTemporaryFile(mode="w+t", suffix=".json") as tf:
tf.close()
trace_file = tf.name
prof.export_chrome_trace(trace_file)
with open(trace_file) as f:
events = json.load(f)["traceEvents"]
print(f"Trace saved to {trace_file}")
prof.export_chrome_trace(trace_file)
with open(trace_file) as f:
events = json.load(f)["traceEvents"]
print(f"Trace saved to {trace_file}")
# Comment to debug
os.remove(trace_file)
return [e for e in events if e.get("name") == "record_param_comms"]
return [e for e in events if e.get("name") == "record_param_comms"]
# Base error message substring on unfinished reductions.

View File

@ -1,5 +1,5 @@
import sys
from typing import Callable, Optional
from typing import Callable, Optional # noqa: UP035
from torch.utils._config_module import install_config_module

View File

@ -11,7 +11,7 @@ import unittest
from collections.abc import Callable
from dataclasses import dataclass
from types import FunctionType, ModuleType
from typing import Any, Generic, NoReturn, Optional, TYPE_CHECKING, TypeVar, Union
from typing import Any, Generic, NoReturn, Optional, TYPE_CHECKING, TypeVar
from typing_extensions import deprecated
from unittest import mock
@ -23,7 +23,7 @@ CONFIG_TYPES = (int, float, bool, type(None), str, list, set, tuple, dict)
# Duplicated, because mypy needs these types statically
T = TypeVar("T", bound=Union[int, float, bool, None, str, list, set, tuple, dict])
T = TypeVar("T", bound=int | float | bool | None | str | list | set | tuple | dict)
_UNSET_SENTINEL = object()
@ -69,12 +69,12 @@ class _Config(Generic[T]):
default behaviour. I.e. user overrides take preference.
"""
default: Union[T, object]
justknob: Optional[str] = None
env_name_default: Optional[list[str]] = None
env_name_force: Optional[list[str]] = None
value_type: Optional[type] = None
alias: Optional[str] = None
default: T | object
justknob: str | None = None
env_name_default: list[str] | None = None
env_name_force: list[str] | None = None
value_type: type | None = None
alias: str | None = None
def __post_init__(self) -> None:
self.env_name_default = _Config.string_or_list_of_string_to_list(
@ -98,8 +98,8 @@ class _Config(Generic[T]):
@staticmethod
def string_or_list_of_string_to_list(
val: Optional[Union[str, list[str]]],
) -> Optional[list[str]]:
val: str | list[str] | None,
) -> list[str] | None:
if val is None:
return None
if isinstance(val, str):
@ -116,23 +116,23 @@ class _Config(Generic[T]):
if TYPE_CHECKING:
def Config(
default: Union[T, object] = _UNSET_SENTINEL,
justknob: Optional[str] = None,
env_name_default: Optional[Union[str, list[str]]] = None,
env_name_force: Optional[Union[str, list[str]]] = None,
value_type: Optional[type] = None,
alias: Optional[str] = None,
default: T | object = _UNSET_SENTINEL,
justknob: str | None = None,
env_name_default: str | list[str] | None = None,
env_name_force: str | list[str] | None = None,
value_type: type | None = None,
alias: str | None = None,
) -> T: ...
else:
def Config(
default: Union[T, object] = _UNSET_SENTINEL,
justknob: Optional[str] = None,
env_name_default: Optional[Union[str, list[str]]] = None,
env_name_force: Optional[Union[str, list[str]]] = None,
value_type: Optional[type] = None,
alias: Optional[str] = None,
default: T | object = _UNSET_SENTINEL,
justknob: str | None = None,
env_name_default: str | list[str] | None = None,
env_name_force: str | list[str] | None = None,
value_type: type | None = None,
alias: str | None = None,
) -> _Config[T]:
return _Config(
default=default,
@ -144,7 +144,7 @@ else:
)
def _read_env_variable(name: str) -> Optional[Union[bool, str]]:
def _read_env_variable(name: str) -> bool | str | None:
value = os.environ.get(name)
if value == "1":
return True
@ -165,8 +165,8 @@ def install_config_module(module: ModuleType) -> None:
_bypass_keys = set({"_is_dirty", "_hash_digest", "__annotations__"})
def visit(
source: Union[ModuleType, type],
dest: Union[ModuleType, SubConfigProxy],
source: ModuleType | type,
dest: ModuleType | SubConfigProxy,
prefix: str,
) -> None:
"""Walk the module structure and move everything to module._config"""
@ -281,7 +281,7 @@ class _ConfigEntry:
# _UNSET_SENTINEL indicates the value is not set.
user_override: Any = _UNSET_SENTINEL
# The justknob to check for this config
justknob: Optional[str] = None
justknob: str | None = None
# environment variables are read at install time
env_value_force: Any = _UNSET_SENTINEL
env_value_default: Any = _UNSET_SENTINEL
@ -297,7 +297,7 @@ class _ConfigEntry:
# call so the final state is correct. It's just very unintuitive.
# upstream bug - python/cpython#126886
hide: bool = False
alias: Optional[str] = None
alias: str | None = None
def __init__(self, config: _Config) -> None:
self.default = config.default
@ -347,7 +347,7 @@ class ConfigModule(ModuleType):
_bypass_keys: set[str]
_compile_ignored_keys: set[str]
_is_dirty: bool
_hash_digest: Optional[bytes]
_hash_digest: bytes | None
def __init__(self) -> None:
raise NotImplementedError(
@ -411,7 +411,7 @@ class ConfigModule(ModuleType):
def _get_alias_module_and_name(
self, entry: _ConfigEntry
) -> Optional[tuple[ModuleType, str]]:
) -> tuple[ModuleType, str] | None:
alias = entry.alias
if alias is None:
return None
@ -465,8 +465,8 @@ class ConfigModule(ModuleType):
def _get_dict(
self,
ignored_keys: Optional[list[str]] = None,
ignored_prefixes: Optional[list[str]] = None,
ignored_keys: list[str] | None = None,
ignored_prefixes: list[str] | None = None,
skip_default: bool = False,
) -> dict[str, Any]:
"""Export a dictionary of current configuration keys and values.
@ -542,7 +542,7 @@ class ConfigModule(ModuleType):
if module_name:
imports.add(module_name)
def list_of_callables_to_string(v: Union[list, set]) -> list[str]:
def list_of_callables_to_string(v: list | set) -> list[str]:
return [f"{get_module_name(item, True)}{item.__name__}" for item in v]
def importable_callable(v: Any) -> bool:
@ -615,7 +615,7 @@ class ConfigModule(ModuleType):
def shallow_copy_dict(self) -> dict[str, Any]:
return self.get_config_copy()
def load_config(self, maybe_pickled_config: Union[bytes, dict[str, Any]]) -> None:
def load_config(self, maybe_pickled_config: bytes | dict[str, Any]) -> None:
"""Restore from a prior call to save_config() or shallow_copy_dict()"""
if not isinstance(maybe_pickled_config, dict):
config = pickle.loads(maybe_pickled_config)
@ -637,7 +637,7 @@ class ConfigModule(ModuleType):
def patch(
self,
arg1: Optional[Union[str, dict[str, Any]]] = None,
arg1: str | dict[str, Any] | None = None,
arg2: Any = None,
**kwargs: dict[str, Any],
) -> "ContextDecorator":
@ -816,7 +816,7 @@ def patch_object(obj: object, name: str, value: object) -> object:
return mock.patch.object(obj, name, value)
def get_tristate_env(name: str, default: Any = None) -> Optional[bool]:
def get_tristate_env(name: str, default: Any = None) -> bool | None:
value = os.environ.get(name)
if value == "1":
return True
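The hunks in this file, like several of the files below, are a mechanical migration from typing.Optional/typing.Union to PEP 604 union syntax; a minimal sketch of the equivalence, independent of the PR:
from typing import Optional, Union
# Pre-PEP 604 spellings
a: Optional[str] = None    # same as Union[str, None]
b: Union[int, float] = 0
# PEP 604 spellings (Python 3.10+), as used after the refactor
a2: str | None = None
b2: int | float = 0
# On Python 3.10+ the two spellings compare equal at runtime:
assert (str | None) == Optional[str]
assert (int | float) == Union[int, float]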

View File

@ -34,7 +34,6 @@ import hashlib
import os.path
import struct
from collections import defaultdict
from typing import Optional
import torch
import torch._prims as prims
@ -193,9 +192,9 @@ class ContentStoreWriter:
class ContentStoreReader:
def __init__(self, loc: str, *, cache=True) -> None:
self.loc = loc
self.storage_cache: Optional[
dict[Optional[torch.device], dict[str, StorageWeakRef]]
] = None
self.storage_cache: (
dict[torch.device | None, dict[str, StorageWeakRef]] | None
) = None
if cache:
self.storage_cache = defaultdict(dict)
@ -207,7 +206,7 @@ class ContentStoreReader:
if self.storage_cache is not None
else None
)
s: Optional[torch.UntypedStorage]
s: torch.UntypedStorage | None
if ws is not None:
s = torch.UntypedStorage._new_with_weak_ptr(ws.cdata)
if s is not None:

View File

@ -1,10 +1,9 @@
from collections.abc import Sequence
from pathlib import Path
from re import match as _match
from typing import Optional, Union
def read_file(fname: Union[Path, str]) -> list[str]:
def read_file(fname: Path | str) -> list[str]:
with open(fname, encoding="utf-8") as f:
return f.readlines()
@ -36,7 +35,7 @@ def _embed_headers(
def embed_headers(
fname: str, include_dirs: Optional[Union[Sequence[str], Sequence[Path], str]] = None
fname: str, include_dirs: Sequence[str] | Sequence[Path] | str | None = None
) -> str:
if include_dirs is None:
base_dir = Path(__file__).parent.parent.parent

View File

@ -15,7 +15,7 @@ collection support for PyTorch APIs.
import functools
import types
from collections.abc import Callable, Iterable, Mapping
from typing import Any, Optional, overload, TypeAlias, TypeVar, Union
from typing import Any, overload, TypeAlias, TypeVar, Union
from typing_extensions import deprecated, Self, TypeIs
import torch.utils._pytree as python_pytree
@ -128,10 +128,10 @@ def register_pytree_node(
flatten_fn: FlattenFunc,
unflatten_fn: UnflattenFunc,
*,
serialized_type_name: Optional[str] = None,
to_dumpable_context: Optional[ToDumpableContextFn] = None,
from_dumpable_context: Optional[FromDumpableContextFn] = None,
flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None,
serialized_type_name: str | None = None,
to_dumpable_context: ToDumpableContextFn | None = None,
from_dumpable_context: FromDumpableContextFn | None = None,
flatten_with_keys_fn: FlattenWithKeysFunc | None = None,
) -> None:
"""Register a container-like type as pytree node.
@ -196,9 +196,9 @@ def _register_pytree_node(
flatten_fn: FlattenFunc,
unflatten_fn: UnflattenFunc,
*,
serialized_type_name: Optional[str] = None,
to_dumpable_context: Optional[ToDumpableContextFn] = None,
from_dumpable_context: Optional[FromDumpableContextFn] = None,
serialized_type_name: str | None = None,
to_dumpable_context: ToDumpableContextFn | None = None,
from_dumpable_context: FromDumpableContextFn | None = None,
) -> None:
"""Register a container-like type as pytree node for the C++ pytree only.
@ -247,9 +247,9 @@ def _private_register_pytree_node(
flatten_fn: FlattenFunc,
unflatten_fn: UnflattenFunc,
*,
serialized_type_name: Optional[str] = None,
to_dumpable_context: Optional[ToDumpableContextFn] = None,
from_dumpable_context: Optional[FromDumpableContextFn] = None,
serialized_type_name: str | None = None,
to_dumpable_context: ToDumpableContextFn | None = None,
from_dumpable_context: FromDumpableContextFn | None = None,
) -> None:
"""This is an internal function that is used to register a pytree node type
for the C++ pytree only. End-users should use :func:`register_pytree_node`
@ -281,7 +281,7 @@ def treespec_tuple(iterable: Iterable[TreeSpec] = (), /) -> TreeSpec:
def treespec_dict(
mapping: Union[Mapping[Any, TreeSpec], Iterable[tuple[Any, TreeSpec]]] = (),
mapping: Mapping[Any, TreeSpec] | Iterable[tuple[Any, TreeSpec]] = (),
/,
**kwargs: TreeSpec,
) -> TreeSpec:
@ -296,7 +296,7 @@ def treespec_dict(
def tree_is_leaf(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> bool:
"""Check if a pytree is a leaf.
@ -334,7 +334,7 @@ def tree_is_leaf(
def tree_flatten(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> tuple[list[Any], TreeSpec]:
"""Flatten a pytree.
@ -399,7 +399,7 @@ def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
def tree_iter(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> Iterable[Any]:
"""Get an iterator over the leaves of a pytree.
@ -434,7 +434,7 @@ def tree_iter(
def tree_leaves(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> list[Any]:
"""Get the leaves of a pytree.
@ -469,7 +469,7 @@ def tree_leaves(
def tree_structure(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> TreeSpec:
"""Get the treespec for a pytree.
@ -506,7 +506,7 @@ def tree_map(
func: Callable[..., Any],
tree: PyTree,
*rests: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree:
"""Map a multi-input function over pytree args to produce a new pytree.
@ -555,7 +555,7 @@ def tree_map_(
func: Callable[..., Any],
tree: PyTree,
*rests: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree:
"""Like :func:`tree_map`, but do an inplace call on each leaf and return the original tree.
@ -593,8 +593,8 @@ Type2 = tuple[type[T], type[S]]
Type3 = tuple[type[T], type[S], type[U]]
TypeAny = Union[type[Any], tuple[type[Any], ...], types.UnionType]
Fn2 = Callable[[Union[T, S]], R]
Fn3 = Callable[[Union[T, S, U]], R]
Fn2 = Callable[[T | S], R]
Fn3 = Callable[[T | S | U], R]
Fn = Callable[[T], R]
FnAny = Callable[[Any], R]
@ -629,7 +629,7 @@ def map_only(
def map_only(
type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]], /
type_or_types_or_pred: TypeAny | Callable[[Any], bool], /
) -> MapOnlyFn[FnAny[Any]]:
"""
Suppose you are writing a tree_map over tensors, leaving everything
@ -677,7 +677,7 @@ def tree_map_only(
/,
func: Fn[T, Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree: ...
@ -687,7 +687,7 @@ def tree_map_only(
/,
func: Fn2[T, S, Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree: ...
@ -697,7 +697,7 @@ def tree_map_only(
/,
func: Fn3[T, S, U, Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree: ...
@ -707,7 +707,7 @@ def tree_map_only(
/,
func: FnAny[Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree: ...
@ -717,16 +717,16 @@ def tree_map_only(
/,
func: FnAny[Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree: ...
def tree_map_only(
type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]],
type_or_types_or_pred: TypeAny | Callable[[Any], bool],
/,
func: FnAny[Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree:
return tree_map(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf)
@ -737,7 +737,7 @@ def tree_map_only_(
/,
func: Fn[T, Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree: ...
@ -747,7 +747,7 @@ def tree_map_only_(
/,
func: Fn2[T, S, Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree: ...
@ -757,7 +757,7 @@ def tree_map_only_(
/,
func: Fn3[T, S, U, Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree: ...
@ -767,7 +767,7 @@ def tree_map_only_(
/,
func: FnAny[Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree: ...
@ -777,16 +777,16 @@ def tree_map_only_(
/,
func: FnAny[Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree: ...
def tree_map_only_(
type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]],
type_or_types_or_pred: TypeAny | Callable[[Any], bool],
/,
func: FnAny[Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree:
return tree_map_(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf)
@ -794,7 +794,7 @@ def tree_map_only_(
def tree_all(
pred: Callable[[Any], bool],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> bool:
flat_args = tree_iter(tree, is_leaf=is_leaf)
return all(map(pred, flat_args))
@ -803,7 +803,7 @@ def tree_all(
def tree_any(
pred: Callable[[Any], bool],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> bool:
flat_args = tree_iter(tree, is_leaf=is_leaf)
return any(map(pred, flat_args))
@ -815,7 +815,7 @@ def tree_all_only(
/,
pred: Fn[T, bool],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> bool: ...
@ -825,7 +825,7 @@ def tree_all_only(
/,
pred: Fn2[T, S, bool],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> bool: ...
@ -835,7 +835,7 @@ def tree_all_only(
/,
pred: Fn3[T, S, U, bool],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> bool: ...
@ -844,7 +844,7 @@ def tree_all_only(
/,
pred: FnAny[bool],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> bool:
flat_args = tree_iter(tree, is_leaf=is_leaf)
return all(pred(x) for x in flat_args if isinstance(x, type_or_types))
@ -856,7 +856,7 @@ def tree_any_only(
/,
pred: Fn[T, bool],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> bool: ...
@ -866,7 +866,7 @@ def tree_any_only(
/,
pred: Fn2[T, S, bool],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> bool: ...
@ -876,7 +876,7 @@ def tree_any_only(
/,
pred: Fn3[T, S, U, bool],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> bool: ...
@ -885,7 +885,7 @@ def tree_any_only(
/,
pred: FnAny[bool],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> bool:
flat_args = tree_iter(tree, is_leaf=is_leaf)
return any(pred(x) for x in flat_args if isinstance(x, type_or_types))
@ -894,7 +894,7 @@ def tree_any_only(
def broadcast_prefix(
prefix_tree: PyTree,
full_tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> list[Any]:
"""Return a list of broadcasted leaves in ``prefix_tree`` to match the number of leaves in ``full_tree``.
@ -956,8 +956,8 @@ def broadcast_prefix(
def _broadcast_to_and_flatten(
tree: PyTree,
treespec: TreeSpec,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> Optional[list[Any]]:
is_leaf: Callable[[PyTree], bool] | None = None,
) -> list[Any] | None:
if not _is_pytreespec_instance(treespec):
raise AssertionError(
f"_broadcast_to_and_flatten: Expected `treespec` to be instance of PyTreeSpec but got {type(treespec)}"
@ -969,7 +969,7 @@ def _broadcast_to_and_flatten(
return None
def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str:
def treespec_dumps(treespec: TreeSpec, protocol: int | None = None) -> str:
"""Serialize a treespec to a JSON string."""
if not _is_pytreespec_instance(treespec):
raise TypeError(
@ -1024,7 +1024,7 @@ class LeafSpec(TreeSpec, metaclass=LeafSpecMeta): # type: ignore[misc,final]
def tree_flatten_with_path(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> tuple[list[tuple[KeyPath, Any]], TreeSpec]:
"""Flattens a pytree like :func:`tree_flatten`, but also returns each leaf's key path.
@ -1047,7 +1047,7 @@ def tree_flatten_with_path(
def tree_leaves_with_path(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> list[tuple[KeyPath, Any]]:
"""Gets the leaves of a pytree like ``tree_leaves`` and returns each leaf's key path.
@ -1070,7 +1070,7 @@ def tree_map_with_path(
func: Callable[..., Any],
tree: PyTree,
*rests: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree:
"""Like :func:`tree_map`, but the provided callable takes an additional key path argument.

View File

@ -4,7 +4,7 @@ import functools
import traceback
import weakref
from collections.abc import Callable
from typing import Any, Optional, TYPE_CHECKING
from typing import Any, TYPE_CHECKING
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
@ -140,7 +140,7 @@ def _get_stack_trace() -> str:
return "".join(summary.format())
def _maybe_get_autograd_trace() -> Optional[str]:
def _maybe_get_autograd_trace() -> str | None:
if torch._C._current_autograd_node() is not None:
tb = torch._C._current_autograd_node().metadata.get("traceback_") # type: ignore[attr-defined]
if tb:
@ -154,8 +154,8 @@ class _DebugCall:
def __init__(
self,
call_depth: int,
record: Optional[dict[str, Any]] = None,
log: Optional[dict[str, Any]] = None,
record: dict[str, Any] | None = None,
log: dict[str, Any] | None = None,
stack: bool = False,
) -> None:
self.call_depth = call_depth
@ -166,10 +166,10 @@ class _DebugCall:
# results from dispatch hooks
self.record = record
self.log = log
self.output_str: Optional[str] = None
self.output_str: str | None = None
def stringify_args(
self, attributes: list[str], tensor_memo: Optional[TensorIdTracker] = None
self, attributes: list[str], tensor_memo: TensorIdTracker | None = None
) -> None:
"""
To reduce memory consumption, this method stringifies args/kwargs, stores the result, and deletes original args/kwargs.
@ -182,7 +182,7 @@ class _DebugCall:
self,
output: Any,
attributes: list[str],
tensor_memo: Optional[TensorIdTracker] = None,
tensor_memo: TensorIdTracker | None = None,
) -> None:
"""Store stringified version of call output in self.output_str"""
if tree_all(lambda x: x is None, output):
@ -213,11 +213,11 @@ class _OpCall(_DebugCall):
self.args = args
self.kwargs = kwargs
self.args_str: Optional[str] = None
self.kwargs_str: Optional[str] = None
self.args_str: str | None = None
self.kwargs_str: str | None = None
def stringify_args(
self, attributes: list[str], tensor_memo: Optional[TensorIdTracker] = None
self, attributes: list[str], tensor_memo: TensorIdTracker | None = None
) -> None:
self.args_str = ", ".join(
_arg_to_str(arg, attributes, tensor_memo) for arg in self.args
@ -289,10 +289,10 @@ class _RedistributeCall(_DebugCall):
self.dst_placement = dst_placement
self.transform_info_str = transform_info_str
self.arg_str: Optional[str] = None
self.arg_str: str | None = None
def stringify_args(
self, attributes: list[str], tensor_memo: Optional[TensorIdTracker] = None
self, attributes: list[str], tensor_memo: TensorIdTracker | None = None
) -> None:
self.arg_str = f"{_arg_to_str(self.arg, attributes, tensor_memo)}"
del self.arg
@ -339,7 +339,7 @@ class _NNModuleCall(_DebugCall):
self.module_name = module_name
def stringify_args(
self, attributes: list[str], tensor_memo: Optional[TensorIdTracker] = None
self, attributes: list[str], tensor_memo: TensorIdTracker | None = None
) -> None:
pass # nothing to stringify
@ -418,7 +418,7 @@ class DebugMode(TorchDispatchMode):
# This flag currently has no effect on torch.compiled-regions.
self.record_nn_module = record_nn_module
self.module_tracker: Optional[ModTracker] = None
self.module_tracker: ModTracker | None = None
if self.record_nn_module:
self.module_tracker_setup()
@ -585,7 +585,7 @@ class DebugMode(TorchDispatchMode):
arg,
src_placement,
dst_placement,
transform_info_str: Optional[str] = None,
transform_info_str: str | None = None,
):
try:
self._record_call(
@ -615,8 +615,8 @@ class DebugMode(TorchDispatchMode):
@staticmethod
@contextlib.contextmanager
def dispatch_hooks(
record_hook: Optional[Callable] = None,
log_hook: Optional[Callable] = None,
record_hook: Callable | None = None,
log_hook: Callable | None = None,
):
"""
Allows installing post-hooks on arguments to intercepted __torch_dispatch__ calls;
@ -660,9 +660,7 @@ class DebugMode(TorchDispatchMode):
@staticmethod
@contextlib.contextmanager
def log_tensor_hashes(
hash_fn: Optional[Callable] = None, hash_inputs: bool = False
):
def log_tensor_hashes(hash_fn: Callable | None = None, hash_inputs: bool = False):
"""
Installs hook for tensor hash logging.
@ -696,7 +694,7 @@ class DebugMode(TorchDispatchMode):
yield
def get_active_debug_mode() -> Optional[DebugMode]:
def get_active_debug_mode() -> DebugMode | None:
debug_mode = None
for mode in _get_current_dispatch_mode_stack():
if isinstance(mode, DebugMode):

View File

@ -1,6 +1,5 @@
# mypy: allow-untyped-defs
import functools
from typing import Optional
import torch
from torch._C import _len_torch_function_stack
@ -8,7 +7,7 @@ from torch.overrides import _pop_mode, _push_mode, TorchFunctionMode
from torch.utils._contextlib import context_decorator
CURRENT_DEVICE: Optional[torch.device] = None
CURRENT_DEVICE: torch.device | None = None
@functools.lru_cache(1)

View File

@ -1,5 +1,4 @@
from types import TracebackType
from typing import Optional
from typing_extensions import Self
from filelock import FileLock as base_FileLock
@ -28,9 +27,9 @@ class FileLock(base_FileLock):
def __exit__(
self,
exc_type: Optional[type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
self.region_counter.__exit__()
with _WaitCounter("pytorch.filelock.exit").guard():

View File

@ -1,4 +1,4 @@
from typing import Optional, TypeAlias
from typing import TypeAlias
import torch
from torch import Tensor
@ -23,7 +23,7 @@ def _get_fused_kernels_supported_devices() -> list[str]:
]
TensorListList: TypeAlias = list[list[Optional[Tensor]]]
TensorListList: TypeAlias = list[list[Tensor | None]]
Indices: TypeAlias = list[int]
_foreach_supported_types = [torch.Tensor]

View File

@ -1,7 +1,6 @@
import functools
import importlib.util
from types import ModuleType
from typing import Optional
def _check_module_exists(name: str) -> bool:
@ -24,7 +23,7 @@ def dill_available() -> bool:
@functools.lru_cache
def import_dill() -> Optional[ModuleType]:
def import_dill() -> ModuleType | None:
if not dill_available():
return None

View File

@ -8,7 +8,7 @@ from collections.abc import (
Reversible,
Set as AbstractSet,
)
from typing import Any, cast, Optional, TypeVar
from typing import Any, cast, TypeVar
T = TypeVar("T", bound=Hashable)
@ -24,7 +24,7 @@ class OrderedSet(MutableSet[T], Reversible[T]):
__slots__ = ("_dict",)
def __init__(self, iterable: Optional[Iterable[T]] = None) -> None:
def __init__(self, iterable: Iterable[T] | None = None) -> None:
self._dict = dict.fromkeys(iterable, None) if iterable is not None else {}
@staticmethod

View File

@ -6,7 +6,7 @@ import functools
import warnings
from collections import deque
from dataclasses import dataclass
from typing import cast, Optional, overload, Protocol, TYPE_CHECKING, Union
from typing import cast, overload, Protocol, TYPE_CHECKING
from typing_extensions import TypeIs
import torch
@ -207,7 +207,7 @@ class TorchDispatchMode:
return False
def _get_current_dispatch_mode() -> Optional[TorchDispatchMode]:
def _get_current_dispatch_mode() -> TorchDispatchMode | None:
"""
Return the top user mode on the stack (the next one that would be
executed) if there are any.
@ -308,7 +308,7 @@ def _push_mode(mode: TorchDispatchMode) -> None:
_set_mode_pre_dispatch(mode)
def _pop_mode(k: Optional[Union[DispatchKey, torch._C._TorchDispatchModeKey]] = None):
def _pop_mode(k: DispatchKey | torch._C._TorchDispatchModeKey | None = None):
if k == torch._C.DispatchKey.PreDispatch: # type: ignore[attr-defined]
from torch._ops import _pop_mode_from_pre_dispatch
@ -319,7 +319,7 @@ def _pop_mode(k: Optional[Union[DispatchKey, torch._C._TorchDispatchModeKey]] =
@contextlib.contextmanager
def _pop_mode_temporarily(k: Optional[DispatchKey] = None):
def _pop_mode_temporarily(k: DispatchKey | None = None):
old = _pop_mode(k)
try:
yield old
@ -429,18 +429,18 @@ class TensorWithFlatten(Protocol):
non_blocking: bool = False,
copy: bool = False,
*,
memory_format: Optional[torch.memory_format] = None,
memory_format: torch.memory_format | None = None,
) -> torch.Tensor: ...
@overload
def to(
self,
device: Optional[torch._prims_common.DeviceLikeType] = None,
dtype: Optional[torch.types._dtype] = None,
device: torch._prims_common.DeviceLikeType | None = None,
dtype: torch.types._dtype | None = None,
non_blocking: bool = False,
copy: bool = False,
*,
memory_format: Optional[torch.memory_format] = None,
memory_format: torch.memory_format | None = None,
) -> torch.Tensor: ...
@overload
@ -450,7 +450,7 @@ class TensorWithFlatten(Protocol):
non_blocking: bool = False,
copy: bool = False,
*,
memory_format: Optional[torch.memory_format] = None,
memory_format: torch.memory_format | None = None,
) -> torch.Tensor: ...
@ -610,7 +610,7 @@ def _correct_storage_aliasing(func, schema_info, args, outs) -> None:
alias_non_inplace_storage(args[arg_idx], outs[return_idx])
def _get_write_alias(x) -> Optional[str]:
def _get_write_alias(x) -> str | None:
alias_set = x.alias_set
if not alias_set or not x.is_write:
return None
@ -629,7 +629,7 @@ def _get_write_alias(x) -> Optional[str]:
class AliasInfo:
alias_set: set[str]
is_write: bool
name: Optional[str]
name: str | None
@dataclass
@ -642,7 +642,7 @@ class SchemaInfo:
# [_get_write_alias(x) for x in outs]. Guaranteed to contain no Nones; we coerce
# all-Nones result to empty list instead, and we don't support
# some-but-not-all-Nones.
outs_write_aliases: Optional[list[str]]
outs_write_aliases: list[str] | None
# List of (arg_idx, return_idx) where args[arg_idx].alias_set &
# outs[out_idx].alias_set is not empty, and not args[arg_idx].is_write.
@ -726,12 +726,12 @@ def get_alias_info(func) -> SchemaInfo:
if is_read_only_alias_match:
read_only_alias_match_indexes.append((arg_idx, return_idx))
outs_write_aliases_list: list[Optional[str]] = [
outs_write_aliases_list: list[str | None] = [
_get_write_alias(r) for r in out_schemas
]
non_nones = sum(x is not None for x in outs_write_aliases_list)
if non_nones == 0:
outs_write_aliases: Optional[list[str]] = None
outs_write_aliases: list[str] | None = None
elif non_nones != len(outs_write_aliases_list):
# simplifying assumption: we don't have **any** ops with return types like "-> (Tensor(a!), Tensor)"
raise RuntimeError("Unsupported schema: " + str(func._schema))
@ -751,7 +751,7 @@ def get_alias_info(func) -> SchemaInfo:
def autograd_would_have_decomposed(
func: torch._ops.OpOverload, flat_args: Sequence[Union[torch.Tensor, object]]
func: torch._ops.OpOverload, flat_args: Sequence[torch.Tensor | object]
) -> bool:
"""
Suppose that an operator has CompositeImplicitAutograd decomp registered.

View File

@ -33,7 +33,6 @@ from typing import (
Final,
Generic,
NoReturn,
Optional,
overload,
Protocol,
TypeAlias,
@ -109,7 +108,7 @@ class KeyEntry(Protocol):
class EnumEncoder(json.JSONEncoder):
def default(self, obj: object) -> Union[str, dict[str, Any]]:
def default(self, obj: object) -> str | dict[str, Any]:
if isinstance(obj, Enum):
return {
"__enum__": True,
@ -127,7 +126,7 @@ DumpableContext = Any # Any json dumpable text
ToDumpableContextFn = Callable[[Context], DumpableContext]
FromDumpableContextFn = Callable[[DumpableContext], Context]
ToStrFunc = Callable[["TreeSpec", list[str]], str]
MaybeFromStrFunc = Callable[[str], Optional[tuple[Any, Context, str]]]
MaybeFromStrFunc = Callable[[str], tuple[Any, Context, str] | None]
KeyPath = tuple[KeyEntry, ...]
FlattenWithKeysFunc = Callable[[PyTree], tuple[list[tuple[KeyEntry, Any]], Any]]
@ -145,7 +144,7 @@ class NodeDef(NamedTuple):
type: type[Any]
flatten_fn: FlattenFunc
unflatten_fn: UnflattenFunc
flatten_with_keys_fn: Optional[FlattenWithKeysFunc]
flatten_with_keys_fn: FlattenWithKeysFunc | None
_NODE_REGISTRY_LOCK = threading.RLock()
@ -162,8 +161,8 @@ SUPPORTED_NODES: dict[type[Any], NodeDef] = {}
class _SerializeNodeDef(NamedTuple):
typ: type[Any]
serialized_type_name: str
to_dumpable_context: Optional[ToDumpableContextFn]
from_dumpable_context: Optional[FromDumpableContextFn]
to_dumpable_context: ToDumpableContextFn | None
from_dumpable_context: FromDumpableContextFn | None
SUPPORTED_SERIALIZED_TYPES: dict[type[Any], _SerializeNodeDef] = {}
@ -199,10 +198,10 @@ def register_pytree_node(
flatten_fn: FlattenFunc,
unflatten_fn: UnflattenFunc,
*,
serialized_type_name: Optional[str] = None,
to_dumpable_context: Optional[ToDumpableContextFn] = None,
from_dumpable_context: Optional[FromDumpableContextFn] = None,
flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None,
serialized_type_name: str | None = None,
to_dumpable_context: ToDumpableContextFn | None = None,
from_dumpable_context: FromDumpableContextFn | None = None,
flatten_with_keys_fn: FlattenWithKeysFunc | None = None,
) -> None:
"""Register a container-like type as pytree node.
@ -273,9 +272,9 @@ def register_pytree_node(
def register_dataclass(
cls: type[Any],
*,
field_names: Optional[list[str]] = None,
drop_field_names: Optional[list[str]] = None,
serialized_type_name: Optional[str] = None,
field_names: list[str] | None = None,
drop_field_names: list[str] | None = None,
serialized_type_name: str | None = None,
) -> None:
"""
Registers a type that has the semantics of a ``dataclasses.dataclass`` type
@ -524,13 +523,13 @@ def _register_pytree_node(
cls: type[Any],
flatten_fn: FlattenFunc,
unflatten_fn: UnflattenFunc,
to_str_fn: Optional[ToStrFunc] = None, # deprecated
maybe_from_str_fn: Optional[MaybeFromStrFunc] = None, # deprecated
to_str_fn: ToStrFunc | None = None, # deprecated
maybe_from_str_fn: MaybeFromStrFunc | None = None, # deprecated
*,
serialized_type_name: Optional[str] = None,
to_dumpable_context: Optional[ToDumpableContextFn] = None,
from_dumpable_context: Optional[FromDumpableContextFn] = None,
flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None,
serialized_type_name: str | None = None,
to_dumpable_context: ToDumpableContextFn | None = None,
from_dumpable_context: FromDumpableContextFn | None = None,
flatten_with_keys_fn: FlattenWithKeysFunc | None = None,
) -> None:
"""Register a container-like type as pytree node for the Python pytree only.
@ -594,10 +593,10 @@ def _private_register_pytree_node(
flatten_fn: FlattenFunc,
unflatten_fn: UnflattenFunc,
*,
serialized_type_name: Optional[str] = None,
to_dumpable_context: Optional[ToDumpableContextFn] = None,
from_dumpable_context: Optional[FromDumpableContextFn] = None,
flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None,
serialized_type_name: str | None = None,
to_dumpable_context: ToDumpableContextFn | None = None,
from_dumpable_context: FromDumpableContextFn | None = None,
flatten_with_keys_fn: FlattenWithKeysFunc | None = None,
) -> None:
"""This is an internal function that is used to register a pytree node type
for the Python pytree only. End-users should use :func:`register_pytree_node`
@ -671,7 +670,7 @@ class GetAttrKey:
# Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py
def is_namedtuple(obj: Union[object, type]) -> bool:
def is_namedtuple(obj: object | type) -> bool:
"""Return whether the object is an instance of namedtuple or a subclass of namedtuple."""
cls = obj if isinstance(obj, type) else type(obj)
return is_namedtuple_class(cls)
@ -723,7 +722,7 @@ class structseq(tuple[_T_co, ...]):
# Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py
def is_structseq(obj: Union[object, type]) -> bool:
def is_structseq(obj: object | type) -> bool:
"""Return whether the object is an instance of PyStructSequence or a class of PyStructSequence."""
cls = obj if isinstance(obj, type) else type(obj)
return is_structseq_class(cls)
@ -1046,7 +1045,7 @@ def _get_node_type(tree: Any) -> Any:
# A leaf is defined as anything that is not a Node.
def tree_is_leaf(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> bool:
"""Check if a pytree is a leaf.
@ -1073,7 +1072,7 @@ def tree_is_leaf(
"Please use torch.utils._pytree.tree_is_leaf instead.",
category=FutureWarning,
)
def _is_leaf(tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None) -> bool:
def _is_leaf(tree: PyTree, is_leaf: Callable[[PyTree], bool] | None = None) -> bool:
return tree_is_leaf(tree, is_leaf=is_leaf)
@ -1353,7 +1352,7 @@ def treespec_tuple(iterable: Iterable[TreeSpec] = (), /) -> TreeSpec:
def treespec_dict(
mapping: Union[Mapping[Any, TreeSpec], Iterable[tuple[Any, TreeSpec]]] = (),
mapping: Mapping[Any, TreeSpec] | Iterable[tuple[Any, TreeSpec]] = (),
/,
**kwargs: TreeSpec,
) -> TreeSpec:
@ -1366,7 +1365,7 @@ def treespec_dict(
def tree_flatten(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> tuple[list[Any], TreeSpec]:
"""Flattens a pytree into a list of values and a TreeSpec that can be used
to reconstruct the pytree.
@ -1404,7 +1403,7 @@ def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
def tree_iter(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> Iterable[Any]:
"""Get an iterator over the leaves of a pytree."""
if tree_is_leaf(tree, is_leaf=is_leaf):
@ -1421,7 +1420,7 @@ def tree_iter(
def tree_leaves(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> list[Any]:
"""Get a list of leaves of a pytree."""
return list(tree_iter(tree, is_leaf=is_leaf))
@ -1429,7 +1428,7 @@ def tree_leaves(
def tree_structure(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> TreeSpec:
"""Get the TreeSpec for a pytree."""
return tree_flatten(tree, is_leaf=is_leaf)[1]
@ -1439,7 +1438,7 @@ def tree_map(
func: Callable[..., Any],
tree: PyTree,
*rests: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree:
"""Map a multi-input function over pytree args to produce a new pytree.
@ -1483,7 +1482,7 @@ def tree_map_(
func: Callable[..., Any],
tree: PyTree,
*rests: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree:
"""Like :func:`tree_map`, but do an inplace call on each leaf and return the original tree.
@ -1517,8 +1516,8 @@ Type2 = tuple[type[T], type[S]]
Type3 = tuple[type[T], type[S], type[U]]
TypeAny = Union[type[Any], tuple[type[Any], ...], types.UnionType]
Fn2 = Callable[[Union[T, S]], R]
Fn3 = Callable[[Union[T, S, U]], R]
Fn2 = Callable[[T | S], R]
Fn3 = Callable[[T | S | U], R]
Fn = Callable[[T], R]
FnAny = Callable[[Any], R]
@ -1553,7 +1552,7 @@ def map_only(
def map_only(
type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]], /
type_or_types_or_pred: TypeAny | Callable[[Any], bool], /
) -> MapOnlyFn[FnAny[Any]]:
"""
Suppose you are writing a tree_map over tensors, leaving everything
@ -1601,7 +1600,7 @@ def tree_map_only(
/,
func: Fn[T, Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree: ...
@ -1611,7 +1610,7 @@ def tree_map_only(
/,
func: Fn2[T, S, Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree: ...
@ -1621,7 +1620,7 @@ def tree_map_only(
/,
func: Fn3[T, S, U, Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree: ...
@ -1631,7 +1630,7 @@ def tree_map_only(
/,
func: FnAny[Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree: ...
@ -1641,16 +1640,16 @@ def tree_map_only(
/,
func: FnAny[Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree: ...
def tree_map_only(
type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]],
type_or_types_or_pred: TypeAny | Callable[[Any], bool],
/,
func: FnAny[Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree:
return tree_map(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf)
@ -1661,7 +1660,7 @@ def tree_map_only_(
/,
func: Fn[T, Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree: ...
@ -1671,7 +1670,7 @@ def tree_map_only_(
/,
func: Fn2[T, S, Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree: ...
@ -1681,7 +1680,7 @@ def tree_map_only_(
/,
func: Fn3[T, S, U, Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree: ...
@ -1691,7 +1690,7 @@ def tree_map_only_(
/,
func: FnAny[Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree: ...
@ -1701,16 +1700,16 @@ def tree_map_only_(
/,
func: FnAny[Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree: ...
def tree_map_only_(
type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]],
type_or_types_or_pred: TypeAny | Callable[[Any], bool],
/,
func: FnAny[Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree:
return tree_map_(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf)
@ -1718,7 +1717,7 @@ def tree_map_only_(
def tree_all(
pred: Callable[[Any], bool],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> bool:
flat_args = tree_iter(tree, is_leaf=is_leaf)
return all(map(pred, flat_args))
@ -1727,7 +1726,7 @@ def tree_all(
def tree_any(
pred: Callable[[Any], bool],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> bool:
flat_args = tree_iter(tree, is_leaf=is_leaf)
return any(map(pred, flat_args))
@ -1739,7 +1738,7 @@ def tree_all_only(
/,
pred: Fn[T, bool],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> bool: ...
@ -1749,7 +1748,7 @@ def tree_all_only(
/,
pred: Fn2[T, S, bool],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> bool: ...
@ -1759,7 +1758,7 @@ def tree_all_only(
/,
pred: Fn3[T, S, U, bool],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> bool: ...
@ -1768,7 +1767,7 @@ def tree_all_only(
/,
pred: FnAny[bool],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> bool:
flat_args = tree_iter(tree, is_leaf=is_leaf)
return all(pred(x) for x in flat_args if isinstance(x, type_or_types))
@ -1780,7 +1779,7 @@ def tree_any_only(
/,
pred: Fn[T, bool],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> bool: ...
@ -1790,7 +1789,7 @@ def tree_any_only(
/,
pred: Fn2[T, S, bool],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> bool: ...
@ -1800,7 +1799,7 @@ def tree_any_only(
/,
pred: Fn3[T, S, U, bool],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> bool: ...
@ -1809,7 +1808,7 @@ def tree_any_only(
/,
pred: FnAny[bool],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> bool:
flat_args = tree_iter(tree, is_leaf=is_leaf)
return any(pred(x) for x in flat_args if isinstance(x, type_or_types))
@ -1826,8 +1825,8 @@ def tree_any_only(
def _broadcast_to_and_flatten(
tree: PyTree,
treespec: TreeSpec,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> Optional[list[Any]]:
is_leaf: Callable[[PyTree], bool] | None = None,
) -> list[Any] | None:
if not isinstance(treespec, TreeSpec):
raise AssertionError("treespec must be a TreeSpec")
@ -1868,7 +1867,7 @@ class _TreeSpecSchema:
- children_spec: A list of children serialized specs.
"""
type: Optional[str]
type: str | None
context: DumpableContext
children_spec: list["_TreeSpecSchema"]
@ -1917,7 +1916,7 @@ def _treespec_to_json(treespec: TreeSpec) -> _TreeSpecSchema:
return _TreeSpecSchema(serialized_type_name, serialized_context, child_schemas)
def enum_object_hook(obj: dict[str, Any]) -> Union[Enum, dict[str, Any]]:
def enum_object_hook(obj: dict[str, Any]) -> Enum | dict[str, Any]:
if "__enum__" in obj:
modname, _, classname = obj["fqn"].partition(":")
mod = importlib.import_module(modname)
@ -1968,7 +1967,7 @@ def _json_to_treespec(json_schema: DumpableContext) -> TreeSpec:
_SUPPORTED_PROTOCOLS[1] = _ProtocolFn(_treespec_to_json, _json_to_treespec)
def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str:
def treespec_dumps(treespec: TreeSpec, protocol: int | None = None) -> str:
if not isinstance(treespec, TreeSpec):
raise TypeError(
f"treespec_dumps(treespec, protocol): Expected `treespec` to be instance of "
@ -2048,7 +2047,7 @@ def arg_tree_leaves(*args: PyTree, **kwargs: PyTree) -> list[Any]:
def tree_flatten_with_path(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> tuple[list[tuple[KeyPath, Any]], TreeSpec]:
"""Flattens a pytree like :func:`tree_flatten`, but also returns each leaf's key path.
@ -2072,7 +2071,7 @@ def tree_flatten_with_path(
def tree_leaves_with_path(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> list[tuple[KeyPath, Any]]:
"""Gets the leaves of a pytree like ``tree_leaves`` and returns each leaf's key path.
@ -2094,7 +2093,7 @@ def tree_leaves_with_path(
def _generate_key_paths(
key_path: KeyPath,
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> Iterable[tuple[KeyPath, Any]]:
if is_leaf and is_leaf(tree):
yield key_path, tree
@ -2124,7 +2123,7 @@ def tree_map_with_path(
func: Callable[..., Any],
tree: PyTree,
*rests: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree:
"""Like :func:`tree_map`, but the provided callable takes an additional key path argument.

View File

@ -8,7 +8,7 @@ import subprocess
import time
from collections.abc import Callable, Sequence
from threading import Lock
from typing import Any, Optional, TypeVar
from typing import Any, TypeVar
from typing_extensions import ParamSpec
@ -34,14 +34,14 @@ class StrobelightCLIProfilerError(Exception):
"""
def _pid_namespace_link(pid: Optional[int] = None) -> str:
def _pid_namespace_link(pid: int | None = None) -> str:
"""Returns the link to the process's namespace, example: pid:[4026531836]"""
PID_NAMESPACE_PATH = "/proc/{}/ns/pid"
pid = pid or os.getpid()
return os.readlink(PID_NAMESPACE_PATH.format(pid))
def _pid_namespace(pid: Optional[int] = None) -> int:
def _pid_namespace(pid: int | None = None) -> int:
"""Returns the process's namespace id"""
pid = pid or os.getpid()
link = _pid_namespace_link(pid)
@ -77,8 +77,8 @@ class StrobelightCLIFunctionProfiler:
run_user_name: str = "pytorch-strobelight-ondemand",
timeout_wait_for_running_sec: int = 60,
timeout_wait_for_finished_sec: int = 60,
recorded_env_variables: Optional[list[str]] = None,
sample_tags: Optional[list[str]] = None,
recorded_env_variables: list[str] | None = None,
sample_tags: list[str] | None = None,
stack_max_len: int = 127,
async_stack_max_len: int = 127,
) -> None:
@ -90,7 +90,7 @@ class StrobelightCLIFunctionProfiler:
self.timeout_wait_for_finished_sec = timeout_wait_for_finished_sec
# Results of the most recent run.
# Tracks the strobelight run id of the most recent run
self.current_run_id: Optional[int] = None
self.current_run_id: int | None = None
self.sample_tags = sample_tags
def _run_async(self) -> None:
@ -253,7 +253,7 @@ class StrobelightCLIFunctionProfiler:
def profile(
self, work_function: Callable[_P, _R], *args: _P.args, **kwargs: _P.kwargs
) -> Optional[_R]:
) -> _R | None:
self.current_run_id = None
if locked := StrobelightCLIFunctionProfiler._lock.acquire(False):
@ -295,16 +295,16 @@ class StrobelightCLIFunctionProfiler:
# @strobelight(profiler = StrobelightFunctionProfiler(stop_at_error=True,..))
# @strobelight(stop_at_error=True,...)
def strobelight(
profiler: Optional[StrobelightCLIFunctionProfiler] = None, **kwargs: Any
) -> Callable[[Callable[_P, _R]], Callable[_P, Optional[_R]]]:
profiler: StrobelightCLIFunctionProfiler | None = None, **kwargs: Any
) -> Callable[[Callable[_P, _R]], Callable[_P, _R | None]]:
if not profiler:
profiler = StrobelightCLIFunctionProfiler(**kwargs)
def strobelight_inner(
work_function: Callable[_P, _R],
) -> Callable[_P, Optional[_R]]:
) -> Callable[_P, _R | None]:
@functools.wraps(work_function)
def wrapper_function(*args: _P.args, **kwargs: _P.kwargs) -> Optional[_R]:
def wrapper_function(*args: _P.args, **kwargs: _P.kwargs) -> _R | None:
# pyrefly: ignore [bad-argument-type]
return profiler.profile(work_function, *args, **kwargs)

View File

@ -4,7 +4,7 @@ import math
import operator
import sys
from collections.abc import Callable
from typing import Optional, SupportsFloat, TYPE_CHECKING, TypeVar, Union
from typing import SupportsFloat, TYPE_CHECKING, TypeVar
from typing_extensions import TypeVarTuple, Unpack
import sympy
@ -102,11 +102,11 @@ def _is_symbols_binary_summation(expr: sympy.Expr) -> bool:
def _keep_float(
f: Callable[[Unpack[_Ts]], _T],
) -> Callable[[Unpack[_Ts]], Union[_T, sympy.Float]]:
) -> Callable[[Unpack[_Ts]], _T | sympy.Float]:
@functools.wraps(f)
def inner(*args: Unpack[_Ts]) -> Union[_T, sympy.Float]:
def inner(*args: Unpack[_Ts]) -> _T | sympy.Float:
# pyrefly: ignore [bad-argument-type]
r: Union[_T, sympy.Float] = f(*args)
r: _T | sympy.Float = f(*args)
if any(isinstance(a, sympy.Float) for a in args) and not isinstance(
r, sympy.Float
):
@ -117,7 +117,7 @@ def _keep_float(
return inner
def fuzzy_eq(x: Optional[bool], y: Optional[bool]) -> Optional[bool]:
def fuzzy_eq(x: bool | None, y: bool | None) -> bool | None:
if None in (x, y):
return None
return x == y
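`fuzzy_eq` is a small piece of three-valued logic: `None` stands for "unknown" and propagates, otherwise it is plain equality. A quick worked example of the function exactly as defined above (the import path is assumed from context):

```python
from torch.utils._sympy.functions import fuzzy_eq  # assumed module path

# Three-valued equality: None means "unknown" and wins over any comparison.
assert fuzzy_eq(True, True) is True
assert fuzzy_eq(True, False) is False
assert fuzzy_eq(None, False) is None
assert fuzzy_eq(True, None) is None
```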
@ -216,9 +216,7 @@ class FloorDiv(sympy.Function):
# Automatic evaluation.
# https://docs.sympy.org/latest/guides/custom-functions.html#best-practices-for-eval
@classmethod
def eval(
cls, base: sympy.Integer, divisor: sympy.Integer
) -> Union[sympy.Basic, None]:
def eval(cls, base: sympy.Integer, divisor: sympy.Integer) -> sympy.Basic | None:
# python test/test_dynamic_shapes.py -k TestDimConstraints.test_dim_constraints_solve_full
# Assert triggered by inequality solver
# assert base.is_integer, base
@ -324,7 +322,7 @@ class ModularIndexing(sympy.Function):
@classmethod
def eval(
cls, base: sympy.Integer, divisor: sympy.Integer, modulus: sympy.Integer
) -> Optional[sympy.Basic]:
) -> sympy.Basic | None:
if base == 0 or modulus == 1:
return sympy.S.Zero
if (
@ -373,7 +371,7 @@ class ModularIndexing(sympy.Function):
return None
def _eval_is_nonnegative(self) -> Optional[bool]:
def _eval_is_nonnegative(self) -> bool | None:
# pyrefly: ignore [missing-attribute]
p, q = self.args[:2]
return fuzzy_eq(p.is_nonnegative, q.is_nonnegative) # type: ignore[attr-defined]
@ -387,23 +385,21 @@ class Where(sympy.Function):
nargs: tuple[int, ...] = (3,)
precedence: int = 35 # lower precedence than add
def _eval_is_integer(self) -> Optional[bool]:
def _eval_is_integer(self) -> bool | None:
return True if self.args[1].is_integer and self.args[2].is_integer else None # type: ignore[attr-defined]
def _eval_is_nonnegative(self) -> Optional[bool]:
def _eval_is_nonnegative(self) -> bool | None:
return (
True
if self.args[1].is_nonnegative and self.args[2].is_nonnegative # type: ignore[attr-defined]
else None
)
def _eval_is_positive(self) -> Optional[bool]:
def _eval_is_positive(self) -> bool | None:
return True if self.args[1].is_positive and self.args[2].is_positive else None # type: ignore[attr-defined]
@classmethod
def eval(
cls, c: sympy.Basic, p: sympy.Basic, q: sympy.Basic
) -> Optional[sympy.Basic]:
def eval(cls, c: sympy.Basic, p: sympy.Basic, q: sympy.Basic) -> sympy.Basic | None:
if c == sympy.true:
return p
elif c == sympy.false:
@ -419,7 +415,7 @@ class PythonMod(sympy.Function):
is_integer: bool = True
@classmethod
def eval(cls, p: sympy.Expr, q: sympy.Expr) -> Optional[sympy.Expr]:
def eval(cls, p: sympy.Expr, q: sympy.Expr) -> sympy.Expr | None:
# python test/dynamo/test_export.py -k ExportTests.test_trivial_constraint
# Triggered by sympy.solvers.inequalities.reduce_inequalities
# assert p.is_integer, p
@ -465,10 +461,10 @@ class PythonMod(sympy.Function):
return None
# NB: args[1] for PythonMod
def _eval_is_nonnegative(self) -> Optional[bool]:
def _eval_is_nonnegative(self) -> bool | None:
return True if self.args[1].is_positive else None # type: ignore[attr-defined]
def _eval_is_nonpositive(self) -> Optional[bool]:
def _eval_is_nonpositive(self) -> bool | None:
return True if self.args[1].is_negative else None # type: ignore[attr-defined]
def _ccode(self, printer) -> str:
@ -664,7 +660,7 @@ class MinMaxBase(Expr, LatticeOp): # type: ignore[misc]
@classmethod
def _satisfy_unique_summations_symbols(
cls, args
) -> Optional[set[sympy.core.symbol.Symbol]]:
) -> set[sympy.core.symbol.Symbol] | None:
"""
One common case in some models is building expressions of the form
max(max(max(a+b...), c+d), e+f) which is simplified to max(a+b, c+d, e+f, ...).
@ -719,8 +715,8 @@ class MinMaxBase(Expr, LatticeOp): # type: ignore[misc]
@classmethod
def _unique_symbols(
cls, args, initial_set: Optional[set[sympy.core.symbol.Symbol]] = None
) -> Optional[set[sympy.core.symbol.Symbol]]:
cls, args, initial_set: set[sympy.core.symbol.Symbol] | None = None
) -> set[sympy.core.symbol.Symbol] | None:
"""
Return seen_symbols if all atoms in all args are all unique symbols,
else returns None. initial_set can be used to represent initial value for seen_symbols

View File

@ -10,7 +10,7 @@ of a full handler, see torch.utils._sympy.value_ranges.ValueRangeAnalysis.
import functools
import logging
from typing import Any, Union
from typing import Any
import sympy
from sympy.logic.boolalg import Boolean as SympyBoolean, BooleanAtom
@ -184,7 +184,7 @@ _nil = object()
def sympy_interp(
analysis,
env: dict[sympy.Symbol, Any],
expr: Union[sympy.Expr, SympyBoolean],
expr: sympy.Expr | SympyBoolean,
*,
index_dtype=torch.int64,
missing_handler=None,

View File

@ -1,5 +1,4 @@
import sys
from typing import Optional
import sympy
from sympy.printing.precedence import PRECEDENCE, precedence
@ -23,7 +22,7 @@ class ExprPrinter(StrPrinter):
def _print_Not(self, expr: sympy.Expr) -> str:
return f"not ({self._print(expr.args[0])})"
def _print_Add(self, expr: sympy.Expr, order: Optional[str] = None) -> str:
def _print_Add(self, expr: sympy.Expr, order: str | None = None) -> str:
return self.stringify(expr.args, " + ", precedence(expr))
def _print_Relational(self, expr: sympy.Expr) -> str:
@ -310,7 +309,7 @@ class PythonPrinter(ExprPrinter):
# Convert Piecewise(expr_cond_pairs) to nested ternary expressions
# Piecewise((e1, c1), (e2, c2), ..., (eN, cN))
# becomes: e1 if c1 else (e2 if c2 else (... else eN))
result: Optional[str] = None
result: str | None = None
for expr_i, cond_i in reversed(expr.args):
expr_str = self._print(expr_i)
if cond_i == True: # noqa: E712
@ -349,7 +348,7 @@ class CppPrinter(ExprPrinter):
# Convert Piecewise(expr_cond_pairs) to nested ternary operators
# Piecewise((e1, c1), (e2, c2), ..., (eN, cN))
# becomes: c1 ? e1 : (c2 ? e2 : (... : eN))
result: Optional[str] = None
result: str | None = None
for expr_i, cond_i in reversed(expr.args):
expr_str = self.parenthesize(expr_i, PRECEDENCE["Atom"] - 0.5)
if cond_i == True: # noqa: E712
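Both printers unroll a `sympy.Piecewise` into nested conditionals by walking the `(expr, cond)` pairs in reverse, exactly as the comments describe. A worked example of the stated rule (illustrative; the emitted strings may differ in parenthesization):

```python
import sympy

x = sympy.Symbol("x")
pw = sympy.Piecewise((x, x > 0), (-x, True))
# Per the rule above, the Python printer emits the equivalent of:
#     (x) if (x > 0) else (-x)
# and the C++ printer emits the equivalent of:
#     (x > 0) ? (x) : (-x)
```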

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import math
import operator
from typing import NoReturn, Union
from typing import NoReturn
import sympy
@ -359,7 +359,7 @@ class TensorReferenceAnalysis:
# function isn't traced correctly. Here for completeness.
@staticmethod
def constant(c, dtype):
d: Union[int, float, bool]
d: int | float | bool
if dtype is torch.int64:
d = int(c)
elif dtype is torch.double:

View File

@ -1,5 +1,4 @@
import logging
from typing import Optional
import sympy
@ -20,7 +19,7 @@ _MIRROR_REL_OP: dict[type[sympy.Basic], type[sympy.Rel]] = {
INEQUALITY_TYPES = (sympy.Gt, sympy.Ge, sympy.Lt, sympy.Le)
def mirror_rel_op(type: type) -> Optional[type[sympy.Rel]]:
def mirror_rel_op(type: type) -> type[sympy.Rel] | None:
return _MIRROR_REL_OP.get(type)
@ -43,7 +42,7 @@ def try_solve(
thing: sympy.Basic,
trials: int = 5,
floordiv_inequality: bool = True,
) -> Optional[tuple[sympy.Rel, sympy.Expr]]:
) -> tuple[sympy.Rel, sympy.Expr] | None:
mirror = mirror_rel_op(type(expr))
# Ignore unsupported expressions:

View File

@ -14,7 +14,6 @@ in this file and seeing what breaks.
from collections.abc import Iterable
from enum import auto, Enum
from typing import Union
import sympy
@ -88,7 +87,7 @@ def make_symbol(prefix: SymT, idx: int, **kwargs) -> sympy.Symbol:
# This type is a little wider than it should be, because free_symbols says
# that it contains Basic, rather than Symbol
def symbol_is_type(sym: sympy.Basic, prefix: Union[SymT, Iterable[SymT]]) -> bool:
def symbol_is_type(sym: sympy.Basic, prefix: SymT | Iterable[SymT]) -> bool:
if not isinstance(sym, sympy.Symbol):
raise AssertionError("expected sympy.Symbol")
name_str = sym.name.lower() # Match capitalized names like XBLOCK, RBLOCK
@ -98,5 +97,5 @@ def symbol_is_type(sym: sympy.Basic, prefix: Union[SymT, Iterable[SymT]]) -> boo
return name_str.startswith(tuple(prefix_str[p] for p in prefix))
def free_symbol_is_type(e: sympy.Expr, prefix: Union[SymT, Iterable[SymT]]) -> bool:
def free_symbol_is_type(e: sympy.Expr, prefix: SymT | Iterable[SymT]) -> bool:
return any(symbol_is_type(v, prefix) for v in e.free_symbols)
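The prefix check above classifies dynamic-shape symbols by the naming prefix attached to their `SymT` kind. A hedged sketch, assuming the module path `torch.utils._sympy.symbol` and a `SymT.SIZE` member for the usual `s0`-style size symbols (both assumptions, not confirmed by this diff):

```python
from torch.utils._sympy.symbol import SymT, free_symbol_is_type, make_symbol, symbol_is_type

s0 = make_symbol(SymT.SIZE, 0, integer=True)    # a symbol whose name carries the SIZE prefix
expr = 2 * s0 + 1

assert symbol_is_type(s0, SymT.SIZE)
assert free_symbol_is_type(expr, (SymT.SIZE,))  # a single SymT or an iterable of SymT both work
```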

View File

@ -10,7 +10,6 @@ import operator
from collections.abc import Callable
from typing import (
Generic,
Optional,
overload,
SupportsFloat,
TYPE_CHECKING,
@ -325,16 +324,16 @@ class ValueRanges(Generic[_T]):
@overload
@staticmethod
# work around the fact that bool and int overlap
def wrap(arg: Union[ExprIn, ExprVR]) -> ExprVR: # type: ignore[overload-overlap]
def wrap(arg: ExprIn | ExprVR) -> ExprVR: # type: ignore[overload-overlap]
...
@overload
@staticmethod
def wrap(arg: Union[BoolIn, BoolVR]) -> BoolVR: # type: ignore[misc]
def wrap(arg: BoolIn | BoolVR) -> BoolVR: # type: ignore[misc]
...
@staticmethod
def wrap(arg: Union[AllIn, AllVR]) -> AllVR:
def wrap(arg: AllIn | AllVR) -> AllVR:
if isinstance(arg, ValueRanges):
return arg
if isinstance(arg, float) and math.isnan(arg):
@ -343,29 +342,29 @@ class ValueRanges(Generic[_T]):
return ValueRanges(arg, arg) # type: ignore[arg-type]
@staticmethod
def increasing_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR:
def increasing_map(x: ExprIn | ExprVR, fn: ExprFn) -> ExprVR:
"""Increasing: x <= y => f(x) <= f(y)."""
x = ValueRanges.wrap(x)
return ValueRanges(fn(x.lower), fn(x.upper))
@overload
@staticmethod
def decreasing_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR: ...
def decreasing_map(x: ExprIn | ExprVR, fn: ExprFn) -> ExprVR: ...
@overload
@staticmethod
def decreasing_map(x: Union[BoolIn, BoolVR], fn: BoolFn) -> BoolVR: # type: ignore[misc]
def decreasing_map(x: BoolIn | BoolVR, fn: BoolFn) -> BoolVR: # type: ignore[misc]
...
@staticmethod
def decreasing_map(x: Union[AllIn, AllVR], fn: AllFn) -> AllVR:
def decreasing_map(x: AllIn | AllVR, fn: AllFn) -> AllVR:
"""Decreasing: x <= y => f(x) >= f(y)."""
x = ValueRanges.wrap(x)
# consistently either Expr or Bool, but we don't know it here
return ValueRanges(fn(x.upper), fn(x.lower)) # type: ignore[arg-type]
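These monotone-map helpers propagate a function over an interval by evaluating it only at the endpoints, swapping them for a decreasing function, as the docstrings state. A hedged sketch (the module path `torch.utils._sympy.value_ranges` is assumed from context):

```python
from torch.utils._sympy.value_ranges import ValueRanges

vr = ValueRanges(2, 5)                                 # the closed interval [2, 5]
up = ValueRanges.increasing_map(vr, lambda v: 2 * v)   # [4, 10]
down = ValueRanges.decreasing_map(vr, lambda v: -v)    # [-5, -2]: endpoints swap
```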
@staticmethod
def monotone_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR:
def monotone_map(x: ExprIn | ExprVR, fn: ExprFn) -> ExprVR:
"""It's increasing or decreasing."""
x = ValueRanges.wrap(x)
l = fn(x.lower)
@ -373,7 +372,7 @@ class ValueRanges(Generic[_T]):
return ValueRanges(min(l, u), max(l, u))
@staticmethod
def convex_min_zero_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR:
def convex_min_zero_map(x: ExprIn | ExprVR, fn: ExprFn) -> ExprVR:
"""Fn is convex and has a minimum at 0."""
x = ValueRanges.wrap(x)
if 0 in x:
@ -387,23 +386,23 @@ class ValueRanges(Generic[_T]):
@overload
@staticmethod
def coordinatewise_increasing_map(
x: Union[ExprIn, ExprVR],
y: Union[ExprIn, ExprVR],
x: ExprIn | ExprVR,
y: ExprIn | ExprVR,
fn: ExprFn2,
) -> ExprVR: ...
@overload
@staticmethod
def coordinatewise_increasing_map( # type: ignore[misc]
x: Union[BoolIn, BoolVR],
y: Union[BoolIn, BoolVR],
x: BoolIn | BoolVR,
y: BoolIn | BoolVR,
fn: BoolFn2,
) -> BoolVR: ...
@staticmethod
def coordinatewise_increasing_map(
x: Union[AllIn, AllVR],
y: Union[AllIn, AllVR],
x: AllIn | AllVR,
y: AllIn | AllVR,
fn: AllFn2,
) -> AllVR:
"""
@ -1037,7 +1036,7 @@ class SymPyValueRangeAnalysis:
def bound_sympy(
expr: sympy.Expr, ranges: Optional[dict[sympy.Symbol, ValueRanges]] = None
expr: sympy.Expr, ranges: dict[sympy.Symbol, ValueRanges] | None = None
) -> ValueRanges:
log.debug(
"bound_sympy(%s)%s",

View File

@ -1,5 +1,5 @@
from collections.abc import Callable
from typing import Generic, Optional, TypeVar
from typing import Generic, TypeVar
R = TypeVar("R")
@ -12,8 +12,8 @@ class Thunk(Generic[R]):
function once it is forced.
"""
f: Optional[Callable[[], R]]
r: Optional[R]
f: Callable[[], R] | None
r: R | None
__slots__ = ["f", "r"]
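`Thunk` stores a zero-argument callable `f` and a cached result `r`, running the callable only once when forced. A generic, self-contained sketch of the same pattern (not the torch implementation; the method name and details here are illustrative):

```python
from collections.abc import Callable
from typing import Generic, TypeVar

R = TypeVar("R")

class LazyThunk(Generic[R]):
    """Delay a computation, run it once on first force, then cache the result."""

    __slots__ = ["f", "r"]

    def __init__(self, f: Callable[[], R]) -> None:
        self.f: Callable[[], R] | None = f
        self.r: R | None = None

    def force(self) -> R:
        if self.f is not None:
            self.r = self.f()
            self.f = None  # drop the closure so it can be garbage collected
        return self.r  # type: ignore[return-value]
```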

View File

@ -5,7 +5,6 @@ import os.path
import tempfile
import traceback
from types import TracebackType
from typing import Optional
# This file contains utilities for ensuring dynamically compile()'d
@ -234,7 +233,7 @@ class CapturedTraceback:
import torch._C._profiler
# Directly populate tracebacks that already have cached summaries
rs: list[Optional[list[str]]] = []
rs: list[list[str] | None] = []
delayed_idxs = []
for i, tb in enumerate(tbs):
if tb.tb is None:

View File

@ -1,6 +1,6 @@
"""Miscellaneous utilities to aid with typing."""
from typing import Optional, TypeVar
from typing import TypeVar
# Helper to turn Optional[T] into T when we know None either isn't
@ -8,7 +8,7 @@ from typing import Optional, TypeVar
T = TypeVar("T")
def not_none(obj: Optional[T]) -> T:
def not_none(obj: T | None) -> T:
if obj is None:
raise TypeError("Invariant encountered: value was None when it should not be")
return obj
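`not_none` converts a checked `T | None` into a plain `T` for the type checker, raising instead of asserting. A short usage sketch (the module path `torch.utils._typing_utils` is assumed; `maybe_find` is a hypothetical producer):

```python
from torch.utils._typing_utils import not_none

def maybe_find(xs: list[int], target: int) -> int | None:
    return xs.index(target) if target in xs else None

idx: int = not_none(maybe_find([3, 1, 4], 4))  # narrows int | None to int, raises TypeError on None
```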

View File

@ -1,5 +1,4 @@
# mypy: allow-untyped-defs
from typing import Optional, Union
import torch
from torch._C import _get_privateuse1_backend_name, _rename_privateuse1_backend
@ -90,7 +89,7 @@ def _check_register_once(module, attr) -> None:
def _normalization_device(
custom_backend_name: str, device: Optional[Union[int, str, torch.device]] = None
custom_backend_name: str, device: int | str | torch.device | None = None
) -> int:
def _get_current_device_index():
_get_device_index = "current_device"
@ -137,7 +136,7 @@ def _generate_tensor_methods_for_privateuse1_backend(custom_backend_name: str) -
def wrap_tensor_to(
self: torch.Tensor,
device: Optional[Union[int, torch.device]] = None,
device: int | torch.device | None = None,
non_blocking=False,
**kwargs,
) -> torch.Tensor:
@ -188,7 +187,7 @@ def _generate_module_methods_for_privateuse1_backend(custom_backend_name: str) -
def wrap_module_to(
self: torch.nn.modules.module.T,
device: Optional[Union[int, torch.device]] = None,
device: int | torch.device | None = None,
) -> torch.nn.modules.module.T:
r"""Move all model parameters and buffers to the custom device.
@ -268,7 +267,7 @@ def _generate_packed_sequence_methods_for_privateuse1_backend(
def _generate_storage_methods_for_privateuse1_backend(
custom_backend_name: str, unsupported_dtype: Optional[list[torch.dtype]] = None
custom_backend_name: str, unsupported_dtype: list[torch.dtype] | None = None
) -> None:
# Attribute is registered in the _StorageBase class
# and UntypedStorage obtains through inheritance.
@ -355,7 +354,7 @@ def generate_methods_for_privateuse1_backend(
for_module: bool = True,
for_packed_sequence: bool = True,
for_storage: bool = False,
unsupported_dtype: Optional[list[torch.dtype]] = None,
unsupported_dtype: list[torch.dtype] | None = None,
) -> None:
r"""
Automatically generate attributes and methods for the custom backend after rename privateuse1 backend.
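These helpers back the public custom-backend flow: rename the `privateuse1` device key, then auto-generate the per-backend tensor/module/storage conveniences. A hedged sketch of that flow, assuming the documented `torch.utils` entry points and a hypothetical backend named `"foo"` whose out-of-tree device module is already loaded:

```python
import torch
from torch.utils import generate_methods_for_privateuse1_backend, rename_privateuse1_backend

# Only meaningful once a real device extension is registered; shown for shape only.
rename_privateuse1_backend("foo")
generate_methods_for_privateuse1_backend(
    for_module=True,          # parameters as they appear in this diff
    for_packed_sequence=True,
    for_storage=False,
    unsupported_dtype=None,
)
# Afterwards, conveniences such as tensor.foo() and tensor.is_foo become available.
```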

View File

@ -1,6 +1,6 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING
import numpy as np
import torch
@ -20,7 +20,7 @@ _POW_TWO_SIZES = tuple(2 ** i for i in range(
))
class UnaryOpSparseFuzzer(Fuzzer):
def __init__(self, seed: Optional[int], dtype: _dtype | None = None, cuda: bool = False) -> None:
def __init__(self, seed: int | None, dtype: _dtype | None = None, cuda: bool = False) -> None:
if dtype is None:
dtype = getattr(torch, 'float32', None)
super().__init__(

View File

@ -8,7 +8,7 @@ import shutil
import tempfile
import textwrap
import time
from typing import cast, Any, Optional
from typing import cast, Any
from collections.abc import Iterable, Iterator
import uuid
@ -34,10 +34,10 @@ class TaskSpec:
stmt: str
setup: str
global_setup: str = ""
label: Optional[str] = None
sub_label: Optional[str] = None
description: Optional[str] = None
env: Optional[str] = None
label: str | None = None
sub_label: str | None = None
description: str | None = None
env: str | None = None
num_threads: int = 1
@property
@ -82,7 +82,7 @@ class Measurement:
number_per_run: int
raw_times: list[float]
task_spec: TaskSpec
metadata: Optional[dict[Any, Any]] = None # Reserved for user payloads.
metadata: dict[Any, Any] | None = None # Reserved for user payloads.
def __post_init__(self) -> None:
self._sorted_times: tuple[float, ...] = ()
@ -297,7 +297,7 @@ def set_torch_threads(n: int) -> Iterator[None]:
torch.set_num_threads(prior_num_threads)
def _make_temp_dir(prefix: Optional[str] = None, gc_dev_shm: bool = False) -> str:
def _make_temp_dir(prefix: str | None = None, gc_dev_shm: bool = False) -> str:
"""Create a temporary directory. The caller is responsible for cleanup.
This function is conceptually similar to `tempfile.mkdtemp`, but with

Some files were not shown because too many files have changed in this diff.