mirror of https://github.com/pytorch/pytorch.git
synced 2025-11-14 14:15:07 +08:00
Compare commits
19 Commits
viable/str...documentat
| Author | SHA1 | Date |
|---|---|---|
|  | 27e0a198be |  |
|  | 256b61734f |  |
|  | 59307ca1bc |  |
|  | c28475db7c |  |
|  | 74aec83841 |  |
|  | 52e744d68a |  |
|  | 3cfbf98ea9 |  |
|  | 47db55258b |  |
|  | 50af6f3393 |  |
|  | e545ba2d34 |  |
|  | a058bbdd6f |  |
|  | 2c78080ec0 |  |
|  | fe6615e397 |  |
|  | abf31db2cc |  |
|  | a4c7856112 |  |
|  | afb014541b |  |
|  | b91a2ab892 |  |
|  | 14a845a4ec |  |
|  | 5135ace3a3 |  |
@@ -260,8 +260,8 @@ case "$tag" in
     HALIDE=yes
     TRITON=yes
     ;;
-  pytorch-linux-jammy-cuda13.0-py3.12-pallas)
-    CUDA_VERSION=13.0.0
+  pytorch-linux-jammy-cuda12.8-py3.12-pallas)
+    CUDA_VERSION=12.8.1
     ANACONDA_PYTHON_VERSION=3.12
     GCC_VERSION=11
     PALLAS=yes
@@ -8,9 +8,11 @@ from abc import ABC, abstractmethod


 try:
-    from typing import Any, Callable, Required, TypedDict  # Python 3.11+
+    from collections.abc import Callable  # Python 3.11+
+    from typing import Any, Required, TypedDict
 except ImportError:
-    from typing import Any, Callable, TypedDict
+    from collections.abc import Callable
+    from typing import Any, TypedDict

     from typing_extensions import Required  # Fallback for Python <3.11
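The hunk above, and the same refactor repeated across the CI scripts below, follows one pattern: `Callable` moves from `typing` to `collections.abc`, with a `typing_extensions` fallback kept for names that only landed in `typing` in Python 3.11. A minimal standalone sketch of that version-guarded import pattern (the `JobConfig` TypedDict and callback are illustrative, not from the diff):

```python
try:
    # Python 3.11+: Required is importable directly from typing
    from collections.abc import Callable
    from typing import Any, Required, TypedDict
except ImportError:
    # Older interpreters: typing_extensions backports Required
    from collections.abc import Callable
    from typing import Any, TypedDict

    from typing_extensions import Required


class JobConfig(TypedDict, total=False):
    name: Required[str]  # must always be present
    retries: int         # optional because total=False


on_done: Callable[[JobConfig], Any] = lambda cfg: print(cfg["name"])
on_done({"name": "build"})
```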
@@ -168,14 +168,16 @@ if [[ "$BUILD_ENVIRONMENT" == *xpu* ]]; then
  # shellcheck disable=SC1091
  source /opt/intel/oneapi/compiler/latest/env/vars.sh
  # shellcheck disable=SC1091
  source /opt/intel/oneapi/umf/latest/env/vars.sh
  # shellcheck disable=SC1091
  source /opt/intel/oneapi/ccl/latest/env/vars.sh
  # shellcheck disable=SC1091
  source /opt/intel/oneapi/mpi/latest/env/vars.sh
  # shellcheck disable=SC1091
  source /opt/intel/oneapi/pti/latest/env/vars.sh
  # Enable XCCL build
  export USE_XCCL=1
  export USE_MPI=0
  # XPU kineto feature dependencies are not fully ready, disable kineto build as temp WA
  export USE_KINETO=0
  export TORCH_XPU_ARCH_LIST=pvc
fi
@@ -208,6 +208,8 @@ if [[ "$BUILD_ENVIRONMENT" == *xpu* ]]; then
  source /opt/intel/oneapi/ccl/latest/env/vars.sh
  # shellcheck disable=SC1091
  source /opt/intel/oneapi/mpi/latest/env/vars.sh
  # shellcheck disable=SC1091
  source /opt/intel/oneapi/pti/latest/env/vars.sh
  # Check XPU status before testing
  timeout 30 xpu-smi discovery || true
fi
@@ -337,7 +339,7 @@ test_python() {

test_python_smoke() {
  # Smoke tests for H100/B200
-  time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
+  time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune inductor/test_cutedsl_grouped_mm $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
  assert_git_not_dirty
}
2 .github/ci_commit_pins/xla.txt vendored
@@ -1 +1 @@
-c8b09f5f77d6bf6fb7ed7a9aa83e5d8156b3a5e9
+e4d25697f9dc5eedaf8f0a5bf085c62c5455a53a
3 .github/scripts/delete_old_branches.py vendored
@@ -1,10 +1,11 @@
 # Delete old branches
 import os
 import re
+from collections.abc import Callable
 from datetime import datetime
 from functools import lru_cache
 from pathlib import Path
-from typing import Any, Callable
+from typing import Any

 from github_utils import gh_fetch_json_dict, gh_graphql
 from gitutils import GitRepo
3 .github/scripts/filter_test_configs.py vendored
@@ -8,10 +8,11 @@ import re
 import subprocess
 import sys
 import warnings
+from collections.abc import Callable
 from enum import Enum
 from functools import cache
 from logging import info
-from typing import Any, Callable, Optional
+from typing import Any, Optional
 from urllib.request import Request, urlopen

 import yaml
3 .github/scripts/get_workflow_job_id.py vendored
@@ -11,7 +11,8 @@ import sys
 import time
 import urllib
 import urllib.parse
-from typing import Any, Callable, Optional
+from collections.abc import Callable
+from typing import Any, Optional
 from urllib.request import Request, urlopen
3 .github/scripts/github_utils.py vendored
@@ -3,8 +3,9 @@
 import json
 import os
 import warnings
+from collections.abc import Callable
 from dataclasses import dataclass
-from typing import Any, Callable, cast, Optional, Union
+from typing import Any, cast, Optional, Union
 from urllib.error import HTTPError
 from urllib.parse import quote
 from urllib.request import Request, urlopen
4 .github/scripts/gitutils.py vendored
@@ -4,10 +4,10 @@ import os
 import re
 import tempfile
 from collections import defaultdict
-from collections.abc import Iterator
+from collections.abc import Callable, Iterator
 from datetime import datetime
 from functools import wraps
-from typing import Any, Callable, cast, Optional, TypeVar, Union
+from typing import Any, cast, Optional, TypeVar, Union


 T = TypeVar("T")
4 .github/scripts/trymerge.py vendored
@@ -17,12 +17,12 @@ import re
 import time
 import urllib.parse
 from collections import defaultdict
-from collections.abc import Iterable
+from collections.abc import Callable, Iterable
 from dataclasses import dataclass
 from functools import cache
 from pathlib import Path
 from re import Pattern
-from typing import Any, Callable, cast, NamedTuple, Optional
+from typing import Any, cast, NamedTuple, Optional
 from warnings import warn

 import yaml
2 .github/workflows/docker-builds.yml vendored
@@ -67,7 +67,7 @@ jobs:
          pytorch-linux-jammy-py3.10-gcc11,
          pytorch-linux-jammy-py3-gcc11-inductor-benchmarks,
          pytorch-linux-jammy-py3.12-halide,
-          pytorch-linux-jammy-cuda13.0-py3.12-pallas,
+          pytorch-linux-jammy-cuda12.8-py3.12-pallas,
          pytorch-linux-jammy-xpu-n-1-py3,
          pytorch-linux-noble-xpu-n-py3,
          pytorch-linux-noble-xpu-n-py3-inductor-benchmarks,
4 .github/workflows/inductor-unittest.yml vendored
@@ -86,8 +86,8 @@ jobs:
    uses: ./.github/workflows/_linux-build.yml
    needs: get-label-type
    with:
-      build-environment: linux-jammy-py3.12-gcc11
-      docker-image-name: ci-image:pytorch-linux-jammy-cuda13.0-py3.12-pallas
+      build-environment: linux-jammy-cuda12.8-py3.12-gcc11
+      docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-py3.12-pallas
      cuda-arch-list: '8.9'
      runner: linux.8xlarge.memory
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
1 .gitignore vendored
@@ -127,6 +127,7 @@ torch/test/
 torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h
 torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h
 torch/version.py
+torch/_inductor/kernel/vendored_templates/*
 minifier_launcher.py
 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_fwd_d*
 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd_d*
@@ -94,6 +94,11 @@ TORCH_API inline void resetPeakStats(c10::DeviceIndex device_index) {
   at::getDeviceAllocator(device_type)->resetPeakStats(device_index);
 }

+TORCH_API inline std::pair<size_t, size_t> getMemoryInfo(
+    c10::DeviceIndex device_index) {
+  const auto device_type = getAccelerator(true).value();
+  return at::getDeviceAllocator(device_type)->getMemoryInfo(device_index);
+}
 } // namespace at::accelerator

 namespace at {
@@ -157,6 +157,8 @@ constexpr DispatchKeySet kKeysToPropagateToWrapper({
     DispatchKey::Negative,
     DispatchKey::Conjugate,
     DispatchKey::XLA,
+    DispatchKey::XPU,
+    DispatchKey::HPU,
     DispatchKey::CUDA,
     DispatchKey::CPU,
     DispatchKey::PrivateUse1,
@@ -4292,6 +4292,7 @@
   dispatch:
     SparseCPU: sparse_sparse_matmul_cpu
     SparseCUDA: sparse_sparse_matmul_cuda
+    SparseMPS: sparse_sparse_matmul_mps
   autogen: _sparse_sparse_matmul.out

 - func: mode(Tensor self, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices)
@@ -9832,7 +9833,7 @@
   structured_delegate: erfinv.out
   variants: method, function
   dispatch:
-    SparseCPU, SparseCUDA: erfinv_sparse
+    SparseCPU, SparseCUDA, SparseMPS: erfinv_sparse
     SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: erfinv_sparse_csr
   tags: pointwise
@@ -9841,7 +9842,7 @@
   structured_delegate: erfinv.out
   variants: method
   dispatch:
-    SparseCPU, SparseCUDA: erfinv_sparse_
+    SparseCPU, SparseCUDA, SparseMPS: erfinv_sparse_
     SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: erfinv_sparse_csr_
   tags: pointwise
@@ -9851,7 +9852,7 @@
   structured_inherits: TensorIteratorBase
   dispatch:
     CPU, CUDA, MPS: erfinv_out
-    SparseCPU, SparseCUDA: erfinv_sparse_out
+    SparseCPU, SparseCUDA, SparseMPS: erfinv_sparse_out
     SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: erfinv_sparse_csr_out
   tags: pointwise
@@ -10,6 +10,10 @@
 #include <ATen/NativeFunctions.h>
 #else
 #include <ATen/ops/_coalesce_native.h>
+#include <ATen/ops/repeat_interleave_native.h>
+#include <ATen/ops/cumsum.h>
+#include <ATen/ops/_sparse_sparse_matmul_native.h>
 #include <ATen/ops/_sparse_coo_tensor_unsafe.h>
 #include <ATen/ops/_sparse_coo_tensor_unsafe_native.h>
+#include <ATen/ops/cat.h>
 #include <ATen/ops/add_native.h>
@@ -888,5 +892,114 @@ static void sparse_mask_intersection_out_mps_kernel(
       /*coalesce_mask=*/false);
 }

+Tensor sparse_sparse_matmul_mps(const Tensor& mat1_, const Tensor& mat2_) {
+  TORCH_CHECK(mat1_.is_sparse() && mat2_.is_sparse(),
+              "sparse_sparse_matmul_mps: both inputs must be sparse COO tensors");
+  TORCH_CHECK(mat1_.is_mps() && mat2_.is_mps(),
+              "sparse_sparse_matmul_mps: both inputs must be on MPS device");
+  TORCH_CHECK(mat1_.dim() == 2 && mat2_.dim() == 2,
+              "sparse_sparse_matmul_mps: both inputs must be 2D matrices");
+  TORCH_CHECK(mat1_.dense_dim() == 0 && mat2_.dense_dim() == 0,
+              "sparse_sparse_matmul_mps: only scalar values supported (dense_dim == 0)");
+  TORCH_CHECK(mat1_.size(1) == mat2_.size(0),
+              "mat1 and mat2 shapes cannot be multiplied (", mat1_.size(0), "x", mat1_.size(1), " and ", mat2_.size(0), "x", mat2_.size(1), ")");
+  TORCH_CHECK(mat1_.scalar_type() == mat2_.scalar_type(),
+              "sparse_sparse_matmul_mps: mat1 dtype ", mat1_.scalar_type(),
+              " does not match mat2 dtype ", mat2_.scalar_type());
+
+  const auto device = mat1_.device();
+
+  auto A = mat1_.coalesce();
+  auto B = mat2_.coalesce();
+
+  const auto I = A.size(0);
+  const auto K = A.size(1);
+  const auto N = B.size(1);
+
+  const auto nnzA = A._nnz();
+  const auto nnzB = B._nnz();
+
+  // Early empty result, return an empty, coalesced tensor
+  if (I == 0 || N == 0 || K == 0 || nnzA == 0 || nnzB == 0) {
+    auto empty_idx = at::empty({2, 0}, at::device(device).dtype(at::kLong));
+    auto empty_val = at::empty({0}, at::device(device).dtype(mat1_.scalar_type()));
+    auto out = _sparse_coo_tensor_unsafe(empty_idx, empty_val, {I, N}, mat1_.options());
+    out._coalesced_(true);
+    return out;
+  }
+
+  const auto computeDtype = at::result_type(mat1_, mat2_);
+
+  auto A_idx = A._indices().contiguous();
+  auto A_val = A._values().to(computeDtype).contiguous();
+  auto A_i = A_idx.select(0, 0).contiguous();
+  auto A_k = A_idx.select(0, 1).contiguous();
+
+  auto B_idx = B._indices().contiguous();
+  auto B_val = B._values().to(computeDtype).contiguous();
+  auto B_k = B_idx.select(0, 0).contiguous();
+  auto B_j = B_idx.select(0, 1).contiguous();
+
+  // csr-style row pointers for B by k (the shared dimension)
+  Tensor row_ptr_B;
+  {
+    auto batch_ptr = at::tensor({0LL, nnzB}, at::device(device).dtype(at::kLong));
+    row_ptr_B = at::empty({K + 1}, at::device(device).dtype(at::kLong));
+    build_row_ptr_per_batch_mps(B_k, batch_ptr, /*B=*/1, /*I=*/K, row_ptr_B);
+  }
+
+  auto row_ptr_B_lo = row_ptr_B.narrow(0, 0, K);
+  auto row_ptr_B_hi = row_ptr_B.narrow(0, 1, K);
+  auto deg_B = row_ptr_B_hi.sub(row_ptr_B_lo);
+
+  auto counts = deg_B.index_select(0, A_k);
+
+  const int64_t P = counts.sum().item<int64_t>();
+  if (P == 0) {
+    auto empty_idx = at::empty({2, 0}, at::device(device).dtype(at::kLong));
+    auto empty_val = at::empty({0}, at::device(device).dtype(mat1_.scalar_type()));
+    auto out = _sparse_coo_tensor_unsafe(empty_idx, empty_val, {I, N}, mat1_.options());
+    out._coalesced_(true);
+    return out;
+  }
+
+  auto group_ids = repeat_interleave_mps(counts);
+
+  // exclusive cumsum of counts
+  auto offsets = cumsum(counts, /*dim=*/0).sub(counts);
+  auto offsets_gather = offsets.index_select(0, group_ids);
+  auto within = at::arange(P, at::device(device).dtype(at::kLong)).sub(offsets_gather);
+
+  // Map each output element to its source B row and position
+  auto k_per_out = A_k.index_select(0, group_ids);
+  auto start_in_B = row_ptr_B.index_select(0, k_per_out);
+  auto seg_index = start_in_B.add(within);
+
+  // Assemble candidate coo pairs and values
+  auto i_out = A_i.index_select(0, group_ids).contiguous();
+  auto j_out = B_j.index_select(0, seg_index).contiguous();
+  auto vA_out = A_val.index_select(0, group_ids).contiguous();
+  auto vB_out = B_val.index_select(0, seg_index).contiguous();
+  auto v_out = vA_out.mul(vB_out);
+
+  // build (2, P) indices
+  auto out_indices = at::empty({2, P}, at::device(device).dtype(at::kLong)).contiguous();
+  out_indices.select(0, 0).copy_(i_out);
+  out_indices.select(0, 1).copy_(j_out);
+
+  auto result = _sparse_coo_tensor_unsafe(
+      out_indices, v_out, {I, N}, mat1_.options().dtype(computeDtype));
+
+  result = result.coalesce();
+
+  if (result.scalar_type() != mat1_.scalar_type()) {
+    auto cast_vals = result._values().to(mat1_.scalar_type());
+    auto out = _sparse_coo_tensor_unsafe(result._indices(), cast_vals, {I, N}, mat1_.options());
+    out._coalesced_(true);
+    return out;
+  }
+  return result;
+}
+
 REGISTER_MPS_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_mps_kernel);
 } // namespace at::native
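A short usage sketch of the path this kernel enables — multiplying two sparse COO matrices directly on an MPS device. The shapes and values are illustrative, and the snippet assumes a build with MPS support (it falls back to CPU otherwise):

```python
import torch

device = "mps" if torch.backends.mps.is_available() else "cpu"

# a = [[0, 2], [3, 0]], b = diag(4, 5), both stored as sparse COO
a = torch.sparse_coo_tensor([[0, 1], [1, 0]], [2.0, 3.0], (2, 2), device=device)
b = torch.sparse_coo_tensor([[0, 1], [0, 1]], [4.0, 5.0], (2, 2), device=device)

c = torch.sparse.mm(a, b)  # sparse @ sparse -> sparse COO result
print(c.to_dense())        # [[0, 10], [12, 0]]
```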
@@ -96,6 +96,10 @@ struct C10_API DeviceAllocator : public c10::Allocator {

   // Resets peak memory usage statistics for the specified device
   virtual void resetPeakStats(c10::DeviceIndex device) = 0;
+
+  // Return the free memory size and total memory size in bytes for the
+  // specified device.
+  virtual std::pair<size_t, size_t> getMemoryInfo(c10::DeviceIndex device) = 0;
 };

 // This function is used to get the DeviceAllocator for a specific device type
@@ -345,6 +345,13 @@ class CUDAAllocator : public DeviceAllocator {
       c10::DeviceIndex device,
       std::shared_ptr<AllocatorState> pps) = 0;
   virtual std::string name() = 0;
+  std::pair<size_t, size_t> getMemoryInfo(c10::DeviceIndex device) override {
+    c10::DeviceGuard device_guard({at::kCUDA, device});
+    size_t free = 0;
+    size_t total = 0;
+    C10_CUDA_CHECK(cudaMemGetInfo(&free, &total));
+    return {free, total};
+  }
 };

 // Allocator object, statically initialized
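On the Python side the same free/total pair has long been observable via `torch.cuda.mem_get_info`, which this override now routes through the generic allocator interface. A quick sanity check, assuming a CUDA build:

```python
import torch

# Query free/total device memory in bytes, mirroring what the C++
# getMemoryInfo() override above returns from cudaMemGetInfo.
if torch.cuda.is_available():
    free_b, total_b = torch.cuda.mem_get_info(0)
    print(f"free {free_b / 2**30:.2f} GiB of {total_b / 2**30:.2f} GiB")
```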
@@ -926,15 +926,14 @@ class DeviceCachingAllocator {
           (release_cached_blocks() && alloc_block(params, true));
     }
     if (!block_found) {
-      c10::xpu::DeviceProp device_prop;
-      c10::xpu::get_device_properties(&device_prop, device);
-      auto device_total = device_prop.global_mem_size;
+      const auto& raw_device = c10::xpu::get_raw_device(device);
+      const auto device_total =
+          raw_device.get_info<sycl::info::device::global_mem_size>();
       // Estimate the available device memory when the SYCL runtime does not
       // support the corresponding aspect (ext_intel_free_memory).
-      size_t device_free = device_prop.global_mem_size -
+      size_t device_free = device_total -
           stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)]
               .current;
-      auto& raw_device = c10::xpu::get_raw_device(device);
       // TODO: Remove the aspect check once the SYCL runtime bug is fixed on
       // affected devices.
       if (raw_device.has(sycl::aspect::ext_intel_free_memory)) {
@@ -1052,21 +1051,37 @@ class DeviceCachingAllocator {
     }
   }

+  std::pair<size_t, size_t> getMemoryInfo() {
+    const auto& device = c10::xpu::get_raw_device(device_index);
+    const size_t total = device.get_info<sycl::info::device::global_mem_size>();
+    TORCH_CHECK(
+        device.has(sycl::aspect::ext_intel_free_memory),
+        "The device (",
+        device.get_info<sycl::info::device::name>(),
+        ") doesn't support querying the available free memory. ",
+        "You can file an issue at https://github.com/pytorch/pytorch/issues ",
+        "to help us prioritize its implementation.");
+    const size_t free =
+        device.get_info<sycl::ext::intel::info::device::free_memory>();
+    return {free, total};
+  }
+
   double getMemoryFraction() {
     if (!set_fraction) {
       return 1.0;
     }
-
-    c10::xpu::DeviceProp device_prop;
-    c10::xpu::get_device_properties(&device_prop, device_index);
+    const auto device_total =
+        xpu::get_raw_device(device_index)
+            .get_info<sycl::info::device::global_mem_size>();
     return static_cast<double>(allowed_memory_maximum) /
-        static_cast<double>(device_prop.global_mem_size);
+        static_cast<double>(device_total);
   }

   void setMemoryFraction(double fraction) {
-    c10::xpu::DeviceProp device_prop;
-    c10::xpu::get_device_properties(&device_prop, device_index);
-    auto device_total = device_prop.global_mem_size;
+    const auto device_total =
+        xpu::get_raw_device(device_index)
+            .get_info<sycl::info::device::global_mem_size>();
     allowed_memory_maximum = static_cast<size_t>(fraction * device_total);
     set_fraction = true;
   }
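The fraction logic above caps reservations at `fraction * device_total`. The XPU Python surface mirrors the CUDA one here; a sketch of the contract using the long-standing CUDA API (the XPU analog implemented above follows the same scheme):

```python
import torch

# Cap this process at ~50% of device 0's total memory; allocations beyond
# the cap fail instead of growing the cached pool.
if torch.cuda.is_available():
    torch.cuda.set_per_process_memory_fraction(0.5, device=0)
    total = torch.cuda.get_device_properties(0).total_memory
    print(f"cap is about {0.5 * total / 2**30:.2f} GiB")
```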
@@ -1240,6 +1255,11 @@ class XPUAllocator : public DeviceAllocator {
         c10::xpu::get_raw_device(dev_to_access));
   }

+  std::pair<size_t, size_t> getMemoryInfo(DeviceIndex device) override {
+    assertValidDevice(device);
+    return device_allocators[device]->getMemoryInfo();
+  }
+
   double getMemoryFraction(DeviceIndex device) {
     assertValidDevice(device);
     return device_allocators[device]->getMemoryFraction();
@@ -40,6 +40,7 @@
    :nosignatures:

     empty_cache
+    get_memory_info
     max_memory_allocated
     max_memory_reserved
     memory_allocated
@@ -382,20 +382,6 @@ coverage_ignore_functions = [
    # torch.ao.quantization.backend_config.tensorrt
    "get_tensorrt_backend_config",
    "get_tensorrt_backend_config_dict",
    # torch.ao.quantization.backend_config.utils
    "entry_to_pretty_str",
    "get_fused_module_classes",
    "get_fuser_method_mapping",
    "get_fusion_pattern_to_extra_inputs_getter",
    "get_fusion_pattern_to_root_node_getter",
    "get_module_to_qat_module",
    "get_pattern_to_dtype_configs",
    "get_pattern_to_input_type_to_index",
    "get_qat_module_classes",
    "get_root_module_to_quantized_reference_module",
    "pattern_to_human_readable",
    "remove_boolean_dispatch_from_name",
    # torch.ao.quantization.backend_config.x86
    "get_x86_backend_config",
    # torch.ao.quantization.fuse_modules
    "fuse_known_modules",

@@ -426,25 +412,6 @@ coverage_ignore_functions = [
    "insert_observers_for_model",
    "prepare",
    "propagate_dtypes_for_known_nodes",
    # torch.ao.quantization.fx.utils
    "all_node_args_except_first",
    "all_node_args_have_no_tensors",
    "assert_and_get_unique_device",
    "collect_producer_nodes",
    "create_getattr_from_value",
    "create_node_from_old_node_preserve_meta",
    "get_custom_module_class_keys",
    "get_linear_prepack_op_for_dtype",
    "get_new_attr_name_with_prefix",
    "get_non_observable_arg_indexes_and_types",
    "get_qconv_prepack_op",
    "get_skipped_module_name_and_classes",
    "graph_module_from_producer_nodes",
    "maybe_get_next_module",
    "node_arg_is_bias",
    "node_arg_is_weight",
    "return_arg_list",
    # torch.ao.quantization.pt2e.graph_utils
    "bfs_trace_with_node_process",
    "find_sequential_partitions",
    "get_equivalent_types",

@@ -860,80 +827,10 @@ coverage_ignore_functions = [
    "get_latency_of_one_partition",
    "get_latency_of_partitioned_graph",
    "get_partition_to_latency_mapping",
    # torch.fx.experimental.proxy_tensor
    "decompose",
    "disable_autocast_cache",
    "disable_proxy_modes_tracing",
    "dispatch_trace",
    "extract_val",
    "fake_signature",
    "fetch_sym_proxy",
    "fetch_object_proxy",
    "get_innermost_proxy_mode",
    "get_isolated_graphmodule",
    "get_proxy_slot",
    "get_torch_dispatch_modes",
    "has_proxy_slot",
    "is_sym_node",
    "maybe_handle_decomp",
    "proxy_call",
    "set_meta",
    "set_original_aten_op",
    "set_proxy_slot",
    "snapshot_fake",
    "thunkify",
    "track_tensor",
    "track_tensor_tree",
    "wrap_key",
    "wrapper_and_args_for_make_fx",
    # torch.fx.experimental.recording
    "record_shapeenv_event",
    "replay_shape_env_events",
    "shape_env_check_state_equal",
    # torch.fx.experimental.sym_node
    "ceil_impl",
    "floor_ceil_helper",
    "floor_impl",
    "method_to_operator",
    "sympy_is_channels_last_contiguous_2d",
    "sympy_is_channels_last_contiguous_3d",
    "sympy_is_channels_last_strides_2d",
    "sympy_is_channels_last_strides_3d",
    "sympy_is_channels_last_strides_generic",
    "sympy_is_contiguous",
    "sympy_is_contiguous_generic",
    "to_node",
    "wrap_node",
    "sym_sqrt",
    # torch.fx.experimental.symbolic_shapes
    "bind_symbols",
    "cast_symbool_to_symint_guardless",
    "create_contiguous",
    "error",
    "eval_guards",
    "eval_is_non_overlapping_and_dense",
    "expect_true",
    "find_symbol_binding_fx_nodes",
    "free_symbols",
    "free_unbacked_symbols",
    "fx_placeholder_targets",
    "fx_placeholder_vals",
    "guard_bool",
    "guard_float",
    "guard_int",
    "guard_scalar",
    "has_hint",
    "has_symbolic_sizes_strides",
    "is_channels_last_contiguous_2d",
    "is_channels_last_contiguous_3d",
    "is_channels_last_strides_2d",
    "is_channels_last_strides_3d",
    "is_contiguous",
    "is_non_overlapping_and_dense_indicator",
    "is_nested_int",
    "is_symbol_binding_fx_node",
    "is_symbolic",
    # torch.fx.experimental.unification.core
    "reify",
    # torch.fx.experimental.unification.match
    "edge",

@@ -971,24 +868,6 @@ coverage_ignore_functions = [
    "reverse_dict",
    # torch.fx.experimental.unification.multipledispatch.variadic
    "isvariadic",
    # torch.fx.experimental.unification.unification_tools
    "assoc",
    "assoc_in",
    "dissoc",
    "first",
    "get_in",
    "getter",
    "groupby",
    "itemfilter",
    "itemmap",
    "keyfilter",
    "keymap",
    "merge",
    "merge_with",
    "update_in",
    "valfilter",
    "valmap",
    # torch.fx.experimental.unification.utils
    "freeze",
    "hashable",
    "raises",

@@ -1429,319 +1308,8 @@ coverage_ignore_functions = [
    # torch.onnx.symbolic_opset7
    "max",
    "min",
    # torch.onnx.symbolic_opset8
    "addmm",
    "bmm",
    "empty",
    "empty_like",
    "flatten",
    "full",
    "full_like",
    "gt",
    "lt",
    "matmul",
    "mm",
    "ones",
    "ones_like",
    "prelu",
    "repeat",
    "zeros",
    "zeros_like",
    # torch.onnx.symbolic_opset9
    "abs",
    "acos",
    "adaptive_avg_pool1d",
    "adaptive_avg_pool2d",
    "adaptive_avg_pool3d",
    "adaptive_max_pool1d",
    "adaptive_max_pool2d",
    "adaptive_max_pool3d",
    "add",
    "addcmul",
    "addmm",
    "alias",
    "amax",
    "amin",
    "aminmax",
    "arange",
    "argmax",
    "argmin",
    "as_strided",
    "as_tensor",
    "asin",
    "atan",
    "atan2",
    "avg_pool1d",
    "avg_pool2d",
    "avg_pool3d",
    "baddbmm",
    "batch_norm",
    "bernoulli",
    "bitwise_not",
    "bitwise_or",
    "bmm",
    "broadcast_tensors",
    "broadcast_to",
    "bucketize",
    "cat",
    "cdist",
    "ceil",
    "clamp",
    "clamp_max",
    "clamp_min",
    "clone",
    "constant_pad_nd",
    "contiguous",
    "conv1d",
    "conv2d",
    "conv3d",
    "conv_tbc",
    "conv_transpose1d",
    "conv_transpose2d",
    "conv_transpose3d",
    "convert_element_type",
    "convolution",
    "cos",
    "cosine_similarity",
    "cross",
    "cumsum",
    "detach",
    "dim",
    "div",
    "dot",
    "dropout",
    "elu",
    "embedding",
    "embedding_bag",
    "empty",
    "empty_like",
    "eq",
    "erf",
    "exp",
    "expand",
    "expand_as",
    "eye",
    "fill",
    "flatten",
    "floor",
    "floor_divide",
    "floordiv",
    "frobenius_norm",
    "full",
    "full_like",
    "gather",
    "ge",
    "gelu",
    "get_pool_ceil_padding",
    "glu",
    "group_norm",
    "gru",
    "gt",
    "hann_window",
    "hardshrink",
    "hardsigmoid",
    "hardswish",
    "hardtanh",
    "index",
    "index_add",
    "index_copy",
    "index_fill",
    "index_put",
    "index_select",
    "instance_norm",
    "is_floating_point",
    "is_pinned",
    "isnan",
    "item",
    "kl_div",
    "layer_norm",
    "le",
    "leaky_relu",
    "lerp",
    "lift",
    "linalg_cross",
    "linalg_matrix_norm",
    "linalg_norm",
    "linalg_vector_norm",
    "linear",
    "linspace",
    "log",
    "log10",
    "log1p",
    "log2",
    "log_sigmoid",
    "log_softmax",
    "logical_and",
    "logical_not",
    "logical_or",
    "logical_xor",
    "logit",
    "logsumexp",
    "lstm",
    "lstm_cell",
    "lt",
    "masked_fill",
    "masked_fill_",
    "matmul",
    "max",
    "max_pool1d",
    "max_pool1d_with_indices",
    "max_pool2d",
    "max_pool2d_with_indices",
    "max_pool3d",
    "max_pool3d_with_indices",
    "maximum",
    "meshgrid",
    "min",
    "minimum",
    "mish",
    "mm",
    "movedim",
    "mse_loss",
    "mul",
    "multinomial",
    "mv",
    "narrow",
    "native_layer_norm",
    "ne",
    "neg",
    "new_empty",
    "new_full",
    "new_ones",
    "new_zeros",
    "nonzero",
    "nonzero_numpy",
    "noop_complex_operators",
    "norm",
    "numel",
    "numpy_T",
    "one_hot",
    "ones",
    "ones_like",
    "onnx_placeholder",
    "overload_by_arg_count",
    "pad",
    "pairwise_distance",
    "permute",
    "pixel_shuffle",
    "pixel_unshuffle",
    "pow",
    "prelu",
    "prim_constant",
    "prim_constant_chunk",
    "prim_constant_split",
    "prim_data",
    "prim_device",
    "prim_dtype",
    "prim_if",
    "prim_layout",
    "prim_list_construct",
    "prim_list_unpack",
    "prim_loop",
    "prim_max",
    "prim_min",
    "prim_shape",
    "prim_tolist",
    "prim_tuple_construct",
    "prim_type",
    "prim_unchecked_cast",
    "prim_uninitialized",
    "rand",
    "rand_like",
    "randint",
    "randint_like",
    "randn",
    "randn_like",
    "reciprocal",
    "reflection_pad",
    "relu",
    "relu6",
    "remainder",
    "repeat",
    "repeat_interleave",
    "replication_pad",
    "reshape",
    "reshape_as",
    "rnn_relu",
    "rnn_tanh",
    "roll",
    "rrelu",
    "rsqrt",
    "rsub",
    "scalar_tensor",
    "scatter",
    "scatter_add",
    "select",
    "selu",
    "sigmoid",
    "sign",
    "silu",
    "sin",
    "size",
    "slice",
    "softmax",
    "softplus",
    "softshrink",
    "sort",
    "split",
    "split_with_sizes",
    "sqrt",
    "square",
    "squeeze",
    "stack",
    "std",
    "std_mean",
    "sub",
    "t",
    "take",
    "tan",
    "tanh",
    "tanhshrink",
    "tensor",
    "threshold",
    "to",
    "topk",
    "transpose",
    "true_divide",
    "type_as",
    "unbind",
    "unfold",
    "unsafe_chunk",
    "unsafe_split",
    "unsafe_split_with_sizes",
    "unsqueeze",
    "unsupported_complex_operators",
    "unused",
    "upsample_bilinear2d",
    "upsample_linear1d",
    "upsample_nearest1d",
    "upsample_nearest2d",
    "upsample_nearest3d",
    "upsample_trilinear3d",
    "var",
    "var_mean",
    "view",
    "view_as",
    "where",
    "wrap_logical_op_with_cast_to",
    "wrap_logical_op_with_negation",
    "zero",
    "zeros",
    "zeros_like",
    # torch.onnx.utils
    "disable_apex_o2_state_dict_hook",
    "export",
    "export_to_pretty_string",
    "exporter_context",
    "is_in_onnx_export",
    "model_signature",
    "register_custom_op_symbolic",
    "select_model_mode_for_export",
    "setup_onnx_logging",
    "unconvertible_ops",
    "unpack_quantized_tensor",
    "warn_on_static_input_change",
    # torch.onnx.verification
    "check_export_model_diff",
    "verify",
    "verify_aten_graph",

@@ -1832,32 +1400,6 @@ coverage_ignore_functions = [
    "noop_context_fn",
    "set_checkpoint_early_stop",
    "set_device_states",
    # torch.utils.collect_env
    "check_release_file",
    "get_cachingallocator_config",
    "get_clang_version",
    "get_cmake_version",
    "get_conda_packages",
    "get_cpu_info",
    "get_cuda_module_loading_config",
    "get_cudnn_version",
    "get_env_info",
    "get_gcc_version",
    "get_gpu_info",
    "get_libc_version",
    "get_lsb_version",
    "get_mac_version",
    "get_nvidia_driver_version",
    "get_nvidia_smi",
    "get_os",
    "get_pip_packages",
    "get_platform",
    "get_pretty_env_info",
    "get_python_platform",
    "get_running_cuda_version",
    "get_windows_version",
    "is_xnnpack_available",
    "pretty_str",
    # torch.utils.cpp_backtrace
    "get_cpp_backtrace",
    # torch.utils.cpp_extension

@@ -1921,52 +1463,6 @@ coverage_ignore_functions = [
    "apply_shuffle_seed",
    "apply_shuffle_settings",
    "get_all_graph_pipes",
    # torch.utils.flop_counter
    "addmm_flop",
    "baddbmm_flop",
    "bmm_flop",
    "conv_backward_flop",
    "conv_flop",
    "conv_flop_count",
    "convert_num_with_suffix",
    "get_shape",
    "get_suffix_str",
    "mm_flop",
    "normalize_tuple",
    "register_flop_formula",
    "sdpa_backward_flop",
    "sdpa_backward_flop_count",
    "sdpa_flop",
    "sdpa_flop_count",
    "shape_wrapper",
    "transpose_shape",
    # torch.utils.hipify.hipify_python
    "add_dim3",
    "compute_stats",
    "extract_arguments",
    "file_add_header",
    "file_specific_replacement",
    "find_bracket_group",
    "find_closure_group",
    "find_parentheses_group",
    "fix_static_global_kernels",
    "get_hip_file_path",
    "hip_header_magic",
    "hipify",
    "is_caffe2_gpu_file",
    "is_cusparse_file",
    "is_out_of_place",
    "is_pytorch_file",
    "is_special_file",
    "match_extensions",
    "matched_files_iter",
    "openf",
    "preprocess_file_and_save_result",
    "preprocessor",
    "processKernelLaunches",
    "replace_extern_shared",
    "replace_math_functions",
    "str2bool",
    # torch.utils.hooks
    "unserializable_hook",
    "warn_if_has_hooks",
@@ -12,6 +12,37 @@ These APIs are experimental and subject to change without notice.
.. autoclass:: torch.fx.experimental.sym_node.DynamicInt
```

## torch.fx.experimental.sym_node

```{eval-rst}
.. currentmodule:: torch.fx.experimental.sym_node
```

```{eval-rst}
.. automodule:: torch.fx.experimental.sym_node
```

```{eval-rst}
.. autosummary::
    :toctree: generated
    :nosignatures:

    is_channels_last_contiguous_2d
    is_channels_last_contiguous_3d
    is_channels_last_strides_2d
    is_channels_last_strides_3d
    is_contiguous
    is_non_overlapping_and_dense_indicator
    method_to_operator
    sympy_is_channels_last_contiguous_2d
    sympy_is_channels_last_contiguous_3d
    sympy_is_channels_last_strides_2d
    sympy_is_channels_last_strides_3d
    sympy_is_channels_last_strides_generic
    sympy_is_contiguous
    sympy_is_contiguous_generic
```

## torch.fx.experimental.symbolic_shapes

```{eval-rst}

@@ -69,6 +100,25 @@ These APIs are experimental and subject to change without notice.
    rebind_unbacked
    resolve_unbacked_bindings
    is_accessor_node
    cast_symbool_to_symint_guardless
    create_contiguous
    error
    eval_guards
    eval_is_non_overlapping_and_dense
    find_symbol_binding_fx_nodes
    free_symbols
    free_unbacked_symbols
    fx_placeholder_targets
    fx_placeholder_vals
    guard_bool
    guard_float
    guard_int
    guard_scalar
    has_hint
    has_symbolic_sizes_strides
    is_nested_int
    is_symbol_binding_fx_node
    is_symbolic
```

## torch.fx.experimental.proxy_tensor

@@ -91,4 +141,46 @@ These APIs are experimental and subject to change without notice.
    get_proxy_mode
    maybe_enable_thunkify
    maybe_disable_thunkify
    decompose
    disable_autocast_cache
    disable_proxy_modes_tracing
    extract_val
    fake_signature
    fetch_object_proxy
    fetch_sym_proxy
    has_proxy_slot
    is_sym_node
    maybe_handle_decomp
    proxy_call
    set_meta
    set_original_aten_op
    set_proxy_slot
    snapshot_fake
```

## torch.fx.experimental.unification.unification_tools

```{eval-rst}
.. currentmodule:: torch.fx.experimental.unification.unification_tools
```

```{eval-rst}
.. automodule:: torch.fx.experimental.unification.unification_tools
```

```{eval-rst}
.. autosummary::
    :toctree: generated
    :nosignatures:

    assoc
    assoc_in
    dissoc
    first
    keyfilter
    keymap
    merge
    merge_with
    update_in
    valfilter
    valmap
```
@@ -1134,7 +1134,6 @@ The set of leaf modules can be customized by overriding
 .. py:module:: torch.fx.experimental.refinement_types
 .. py:module:: torch.fx.experimental.rewriter
 .. py:module:: torch.fx.experimental.schema_type_annotation
-.. py:module:: torch.fx.experimental.sym_node
 .. py:module:: torch.fx.experimental.unification.core
 .. py:module:: torch.fx.experimental.unification.dispatch
 .. py:module:: torch.fx.experimental.unification.match

@@ -1144,7 +1143,6 @@ The set of leaf modules can be customized by overriding
 .. py:module:: torch.fx.experimental.unification.multipledispatch.dispatcher
 .. py:module:: torch.fx.experimental.unification.multipledispatch.utils
 .. py:module:: torch.fx.experimental.unification.multipledispatch.variadic
-.. py:module:: torch.fx.experimental.unification.unification_tools
 .. py:module:: torch.fx.experimental.unification.utils
 .. py:module:: torch.fx.experimental.unification.variable
 .. py:module:: torch.fx.experimental.unify_refinements
@@ -134,6 +134,23 @@ Quantization to work with this as well.
    ObservationType
```

## torch.ao.quantization.backend_config.utils

```{eval-rst}
.. currentmodule:: torch.ao.quantization.backend_config.utils
```

```{eval-rst}
.. autosummary::
    :toctree: generated
    :nosignatures:
    :template: classtemplate.rst

    entry_to_pretty_str
    pattern_to_human_readable
    remove_boolean_dispatch_from_name
```

## torch.ao.quantization.fx.custom_config

This module contains a few CustomConfig classes that's used in both eager mode and FX graph mode quantization

@@ -154,6 +171,30 @@ This module contains a few CustomConfig classes that's used in both eager mode a
    StandaloneModuleConfigEntry
```

## torch.ao.quantization.fx.utils

```{eval-rst}
.. currentmodule:: torch.ao.quantization.fx.utils
```

```{eval-rst}
.. autosummary::
    :toctree: generated
    :nosignatures:
    :template: classtemplate.rst

    all_node_args_except_first
    all_node_args_have_no_tensors
    collect_producer_nodes
    create_getattr_from_value
    create_node_from_old_node_preserve_meta
    graph_module_from_producer_nodes
    maybe_get_next_module
    node_arg_is_bias
    node_arg_is_weight
    return_arg_list
```

## torch.ao.quantization.quantizer

```{eval-rst}
@@ -19,6 +19,91 @@
    swap_tensors
```

# torch.utils.collect_env
```{eval-rst}
.. automodule:: torch.utils.collect_env
```

```{eval-rst}
.. currentmodule:: torch.utils.collect_env
```

```{eval-rst}
.. autosummary::
    :toctree: generated
    :nosignatures:

    check_release_file
    is_xnnpack_available
    pretty_str
```

# torch.utils.flop_counter
```{eval-rst}
.. automodule:: torch.utils.flop_counter
```

```{eval-rst}
.. currentmodule:: torch.utils.flop_counter
```

```{eval-rst}
.. autosummary::
    :toctree: generated
    :nosignatures:

    baddbmm_flop
    bmm_flop
    conv_backward_flop
    conv_flop
    conv_flop_count
    register_flop_formula
    sdpa_backward_flop
    sdpa_backward_flop_count
    sdpa_flop
    sdpa_flop_count
    shape_wrapper
```

# torch.utils.hipify.hipify_python
```{eval-rst}
.. automodule:: torch.utils.hipify.hipify_python
```

```{eval-rst}
.. currentmodule:: torch.utils.hipify.hipify_python
```

```{eval-rst}
.. autosummary::
    :toctree: generated
    :nosignatures:

    compute_stats
    extract_arguments
    file_add_header
    file_specific_replacement
    find_bracket_group
    find_closure_group
    find_parentheses_group
    fix_static_global_kernels
    hip_header_magic
    hipify
    is_caffe2_gpu_file
    is_cusparse_file
    is_out_of_place
    is_pytorch_file
    is_special_file
    openf
    preprocess_file_and_save_result
    preprocessor
    processKernelLaunches
    replace_extern_shared
    replace_math_functions
    str2bool
```

<!-- This module needs to be documented. Adding here in the meantime
for tracking purposes -->
```{eval-rst}

@@ -43,7 +128,6 @@ for tracking purposes -->
 .. py:module:: torch.utils.benchmark.utils.valgrind_wrapper.timer_interface
 .. py:module:: torch.utils.bundled_inputs
 .. py:module:: torch.utils.checkpoint
-.. py:module:: torch.utils.collect_env
 .. py:module:: torch.utils.cpp_backtrace
 .. py:module:: torch.utils.cpp_extension
 .. py:module:: torch.utils.data.backward_compatibility

@@ -80,10 +164,8 @@ for tracking purposes -->
 .. py:module:: torch.utils.data.sampler
 .. py:module:: torch.utils.dlpack
 .. py:module:: torch.utils.file_baton
-.. py:module:: torch.utils.flop_counter
 .. py:module:: torch.utils.hipify.constants
 .. py:module:: torch.utils.hipify.cuda_to_hip_mappings
-.. py:module:: torch.utils.hipify.hipify_python
 .. py:module:: torch.utils.hipify.version
 .. py:module:: torch.utils.hooks
 .. py:module:: torch.utils.jit.log_extract
@@ -184,7 +184,6 @@ ignore = [
     "TC006",
     # TODO: Remove Python-3.10 specific suppressions
     "B905",
-    "UP035",
 ]
 select = [
     "B",
33 setup.py
@@ -630,6 +630,37 @@ def mirror_files_into_torchgen() -> None:
         raise RuntimeError("Check the file paths in `mirror_files_into_torchgen()`")


+def mirror_inductor_external_kernels() -> None:
+    """
+    Copy external kernels into Inductor so they are importable.
+    """
+    paths = [
+        (
+            CWD / "torch/_inductor/kernel/vendored_templates/cutedsl_grouped_gemm.py",
+            CWD
+            / "third_party/cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py",
+        ),
+    ]
+    for new_path, orig_path in paths:
+        # Create the dirs involved in new_path if they don't exist
+        if not new_path.exists():
+            new_path.parent.mkdir(parents=True, exist_ok=True)
+
+        # Copy the files from the orig location to the new location
+        if orig_path.is_file():
+            shutil.copyfile(orig_path, new_path)
+            continue
+        if orig_path.is_dir():
+            if new_path.exists():
+                # copytree fails if the tree exists already, so remove it.
+                shutil.rmtree(new_path)
+            shutil.copytree(orig_path, new_path)
+            continue
+        raise RuntimeError(
+            "Check the file paths in `mirror_inductor_external_kernels()`"
+        )
+
+
 # ATTENTION: THIS IS AI SLOP
 def extract_variant_from_version(version: str) -> str:
     """Extract variant from version string, defaulting to 'cpu'."""

@@ -1615,6 +1646,7 @@ def main() -> None:
     mirror_files_into_torchgen()
     if RUN_BUILD_DEPS:
         build_deps()
+    mirror_inductor_external_kernels()

     (
         ext_modules,

@@ -1649,6 +1681,7 @@ def main() -> None:
                 "_inductor/codegen/aoti_runtime/*.cpp",
                 "_inductor/script.ld",
                 "_inductor/kernel/flex/templates/*.jinja",
+                "_inductor/kernel/templates/*.jinja",
                 "_export/serde/*.yaml",
                 "_export/serde/*.thrift",
                 "share/cmake/ATen/*.cmake",
@@ -1,6 +1,5 @@
 # Owner(s): ["module: unknown"]

-import os
 import tempfile

 from backend import get_custom_backend_library_path, Model, to_custom_backend

@@ -41,14 +40,11 @@ class TestCustomBackend(TestCase):
         self.test_execute()

         # Save and load.
-        f = tempfile.NamedTemporaryFile(delete=False)
-        try:
+        with tempfile.NamedTemporaryFile() as f:
             f.close()
             torch.jit.save(self.model, f.name)
             loaded = torch.jit.load(f.name)
-        finally:
-            os.unlink(f.name)
-        self.model = loaded
+            self.model = loaded

         # Test execution again.
         self.test_execute()
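For reference, a standalone sketch of the Windows-safe temp-file round-trip the old code spelled out by hand (plain file I/O stands in for `torch.jit.save`/`torch.jit.load`):

```python
import os
import tempfile

# NamedTemporaryFile keeps the file open, and Windows refuses to open it a
# second time by name, so close the handle first and clean up manually.
f = tempfile.NamedTemporaryFile(delete=False)
try:
    f.close()
    with open(f.name, "w") as out:   # stand-in for torch.jit.save(model, f.name)
        out.write("payload")
    with open(f.name) as inp:        # stand-in for torch.jit.load(f.name)
        assert inp.read() == "payload"
finally:
    os.unlink(f.name)
```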
@@ -1,6 +1,5 @@
 # Owner(s): ["module: unknown"]

-import os.path
 import sys
 import tempfile
 import unittest

@@ -144,16 +143,13 @@ def forward(self, arg0_1):
         # Ideally we would like to not have to manually delete the file, but NamedTemporaryFile
         # opens the file, and it cannot be opened multiple times in Windows. To support Windows,
         # close the file after creation and try to remove it manually.
-        file = tempfile.NamedTemporaryFile(delete=False)
-        try:
+        with tempfile.NamedTemporaryFile() as file:
             file.close()
             model.save(file.name)
             loaded = torch.jit.load(file.name)
-        finally:
-            os.unlink(file.name)
-
-        output = loaded.forward(torch.ones(5))
-        self.assertTrue(output.allclose(torch.ones(5) + 1))
+            output = loaded.forward(torch.ones(5))
+            self.assertTrue(output.allclose(torch.ones(5) + 1))


 if __name__ == "__main__":
@@ -204,14 +204,16 @@ class DistConvolutionOpsTest(DTensorTestBase):
         self.assertTrue(b_dt.grad is not None)
         self.assertTrue(x_dt.grad is None)

-    def _run_single_arg_fwd(self, model, arg) -> tuple[torch.Tensor, torch.Tensor]:
+    def _run_single_arg_fwd(
+        self, model, arg, placements=None
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         """Given model and arg, runs fwd model local and distbuted given device_mesh"""
         device_mesh = self.build_device_mesh()
         model_copy = copy.deepcopy(model).to(device=self.device_type)
         dist_model = distribute_module(model, device_mesh, _conv_fn)
-        arg_dt = DTensor.from_local(arg, device_mesh, [Replicate()])
+        arg_dt = DTensor.from_local(arg, device_mesh, placements)
         out_dt = dist_model(arg_dt.to(device=self.device_type))
-        out = model_copy(arg)
+        out = model_copy(arg_dt.full_tensor())
         return (out_dt.full_tensor(), out)

     @with_comms

@@ -219,22 +221,20 @@ class DistConvolutionOpsTest(DTensorTestBase):
         model = nn.Conv1d(64, 64, 3, padding=1)
         x = torch.randn(1, 64, 8, device=self.device_type)
         out_dt, out = self._run_single_arg_fwd(model, x)
-        self.assertEqual(out_dt.shape, out.shape)
+        self.assertEqual(out_dt, out)

     @with_comms
     def test_conv3d(self):
         model = nn.Conv3d(64, 64, 3, padding=1)
         x = torch.randn(1, 64, 8, 8, 8, device=self.device_type)
-        out_dt, out = self._run_single_arg_fwd(model, x)
-        self.assertEqual(out_dt.shape, out.shape)
+        out_dt, out = self._run_single_arg_fwd(model, x, [Shard(0)])
+        self.assertEqual(out_dt, out)


 DistConvolutionOpsTestWithLocalTensor = create_local_tensor_test_class(
     DistConvolutionOpsTest,
     # Send / recv ops are not supported
     skipped_tests=[
         "test_conv1d",
         "test_conv3d",
         "test_conv_backward_none_grad_inp",
         "test_depthwise_convolution",
         "test_downsampling_convolution",
@@ -2,7 +2,6 @@
# Owner(s): ["oncall: distributed"]

import contextlib
import copy
import itertools
import unittest

@@ -22,9 +21,8 @@ from torch.distributed.tensor import (
 )
 from torch.distributed.tensor._collective_utils import shard_dim_alltoall
 from torch.distributed.tensor._dtensor_spec import ShardOrderEntry
-from torch.distributed.tensor._redistribute import redistribute_local_tensor
 from torch.distributed.tensor.debug import CommDebugMode
-from torch.distributed.tensor.placement_types import _StridedShard
+from torch.distributed.tensor.placement_types import _StridedShard, MaskPartial
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,

@@ -35,7 +33,11 @@ from torch.testing._internal.common_utils import (
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     create_local_tensor_test_class,
     DTensorTestBase,
+    generate_shard_orders,
+    make_full_tensor,
     map_local_tensor_for_rank,
+    patched_distribute_tensor as _distribute_tensor,
+    redistribute,
     with_comms,
 )
 from torch.utils._debug_mode import DebugMode
@@ -785,88 +787,6 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
         else:
             return ""

-    # TODO(zpcore): remove once the native redistribute supports shard_order arg
-    def redistribute(
-        self,
-        dtensor_input,
-        device_mesh,
-        placements,
-        shard_order,
-        use_graph_based_transform=True,
-    ):
-        """
-        wrapper function to support shard_order for redistribution
-        This is a simpler version of Redistribute, only considers the forward.
-        """
-        if placements is None:
-            placements = self._shard_order_to_placement(shard_order, device_mesh)
-        placements = tuple(placements)
-        old_spec = dtensor_input._spec
-        new_spec = copy.deepcopy(old_spec)
-        new_spec.placements = placements
-        if shard_order is not None:
-            new_spec.shard_order = shard_order
-        else:
-            new_spec.shard_order = ()
-        if old_spec == new_spec:
-            return dtensor_input
-        dtensor_input = DTensor.from_local(
-            redistribute_local_tensor(
-                dtensor_input.to_local(),
-                old_spec,
-                new_spec,
-                use_graph_based_transform=use_graph_based_transform,
-            ),
-            device_mesh,
-        )
-        dtensor_input._spec = copy.deepcopy(new_spec)
-        return dtensor_input  # returns DTensor
-
-    # TODO(zpcore): remove once the native distribute_tensor supports
-    # shard_order arg
-    def distribute_tensor(
-        self,
-        input_tensor,
-        device_mesh,
-        placements,
-        shard_order,
-        use_graph_based_transform=True,
-    ):
-        """wrapper function to support shard_order for tensor distribution"""
-        if placements is None:
-            placements = self._shard_order_to_placement(shard_order, device_mesh)
-        placements = tuple(placements)
-        tensor_dt = distribute_tensor(input_tensor, device_mesh, placements)
-        # fix the shard order
-        return self.redistribute(
-            tensor_dt, device_mesh, placements, shard_order, use_graph_based_transform
-        )
-
-    # TODO(zpcore): remove once the native redistribute supports shard_order arg
-    def full_tensor(self, dtensor_input):
-        """wrapper function to support DTensor.full_tensor"""
-        return self.redistribute(
-            dtensor_input, dtensor_input.device_mesh, placements=None, shard_order=()
-        ).to_local()
-
-    def _shard_order_to_placement(self, shard_order, mesh):
-        """convert shard_order to placement with only Replicate() and Shard()"""
-        placements = [Replicate() for _ in range(mesh.ndim)]
-        if shard_order is not None:
-            for entry in shard_order:
-                tensor_dim = entry.tensor_dim
-                mesh_dims = entry.mesh_dims
-                for mesh_dim in mesh_dims:
-                    placements[mesh_dim] = Shard(tensor_dim)
-        return tuple(placements)
-
-    def _convert_shard_order_dict_to_ShardOrder(self, shard_order):
-        """Convert shard_order dict to ShardOrder"""
-        return tuple(
-            ShardOrderEntry(tensor_dim=tensor_dim, mesh_dims=tuple(mesh_dims))
-            for tensor_dim, mesh_dims in shard_order.items()
-        )
-
     @with_comms
     def test_ordered_redistribute(self):
         """Test ordered redistribution with various sharding syntaxes"""
@@ -927,13 +847,11 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
         for idx, ((src_placement, src_order), (dst_placement, dst_order)) in enumerate(
             sharding_src_dst_pairs_with_expected_trace
         ):
-            sharded_dt = self.distribute_tensor(
+            sharded_dt = _distribute_tensor(
                 input_data.clone(), mesh, src_placement, shard_order=src_order
             )
             with DebugMode(record_torchfunction=False) as debug_mode:
-                sharded_dt = self.redistribute(
-                    sharded_dt, mesh, dst_placement, dst_order
-                )
+                sharded_dt = redistribute(sharded_dt, mesh, dst_placement, dst_order)
             trace_str = self._extract_redistribute_trace_from_debug_mode(
                 debug_mode.debug_string()
             )
@@ -957,49 +875,11 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
                 trace_str,
                 """S(0)[0]S(0)[1]R->S(0)S(1)R->RS(1)R->RS(1)S(0)""",
             )
-            expected_dt = self.distribute_tensor(
+            expected_dt = _distribute_tensor(
                 input_data.clone(), mesh, dst_placement, shard_order=dst_order
             )
             self.assertEqual(sharded_dt.to_local(), expected_dt.to_local())

-    def generate_shard_orders(self, mesh, tensor_rank):
-        # Generate all possible sharding placement of tensor with rank
-        # `tensor_rank` over mesh.
-        def _split_list(lst: list, N: int):
-            def compositions(n, k):
-                if k == 1:
-                    yield [n]
-                else:
-                    for i in range(1, n - k + 2):
-                        for tail in compositions(n - i, k - 1):
-                            yield [i] + tail
-
-            length = len(lst)
-            for comp in compositions(length, N):
-                result = []
-                start = 0
-                for size in comp:
-                    result.append(lst[start : start + size])
-                    start += size
-                yield result
-
-        all_mesh = list(range(mesh.ndim))
-        all_device_order = list(itertools.permutations(all_mesh))
-        for device_order in all_device_order:
-            # split on device orders, and assign each device order segment to a tensor dim
-            for num_split in range(1, mesh.ndim + 1):
-                for splitted_list in _split_list(list(range(mesh.ndim)), num_split):
-                    for tensor_dims in itertools.combinations(
-                        range(tensor_rank), len(splitted_list)
-                    ):
-                        shard_order = {}
-                        assert len(tensor_dims) == len(splitted_list)
-                        for tensor_dim, mesh_dims in zip(tensor_dims, splitted_list):
-                            shard_order[tensor_dim] = device_order[
-                                mesh_dims[0] : mesh_dims[-1] + 1
-                            ]
-                        yield self._convert_shard_order_dict_to_ShardOrder(shard_order)
-
     @with_comms
     def test_generate_shard_orders(self):
         """Check if `generate_shard_orders` generates unique sharding combinations"""
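The helper removed above (now imported from `common_dtensor`) enumerates integer compositions to split an ordered list of mesh dims into contiguous segments. A standalone sketch of that inner generator:

```python
def compositions(n: int, k: int):
    """Yield every way to write n as an ordered sum of k positive integers."""
    if k == 1:
        yield [n]
    else:
        for i in range(1, n - k + 2):
            for tail in compositions(n - i, k - 1):
                yield [i] + tail


# Splitting a 3-dim mesh into 2 contiguous segments:
print(list(compositions(3, 2)))  # [[1, 2], [2, 1]]
```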
@@ -1012,7 +892,7 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
         ]
         for test_input in test_inputs:
             all_combinations = []
-            for shard_order in self.generate_shard_orders(
+            for shard_order in generate_shard_orders(
                 test_input["mesh"], test_input["tensor_rank"]
             ):
                 all_combinations.append(shard_order)  # noqa: PERF402
@@ -1062,12 +942,12 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
         input_data = torch.randn(tensor_shape, device=self.device_type)
         tensor_rank = input_data.ndim
         with maybe_disable_local_tensor_mode():
-            shard_orders = self.generate_shard_orders(mesh, tensor_rank)
+            shard_orders = generate_shard_orders(mesh, tensor_rank)
         for shard_order in shard_orders:
-            sharded_dt = self.distribute_tensor(
+            sharded_dt = _distribute_tensor(
                 input_data.clone(), mesh, placements=None, shard_order=shard_order
             )
-            self.assertEqual(self.full_tensor(sharded_dt), input_data)
+            self.assertEqual(make_full_tensor(sharded_dt), input_data)

         # 2. Verify the correctness of redistribution from DTensor to DTensor.
         # This test repeatedly redistributes a DTensor to various ordered
@ -1078,20 +958,20 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
|
||||
tensor_rank = input_data.ndim
|
||||
prev_sharded_dt = None
|
||||
with maybe_disable_local_tensor_mode():
|
||||
shard_orders = self.generate_shard_orders(mesh, tensor_rank)
|
||||
shard_orders = generate_shard_orders(mesh, tensor_rank)
|
||||
for shard_order in shard_orders:
|
||||
if prev_sharded_dt is None:
|
||||
prev_sharded_dt = self.distribute_tensor(
|
||||
prev_sharded_dt = _distribute_tensor(
|
||||
input_data.clone(),
|
||||
mesh,
|
||||
placements=None,
|
||||
shard_order=shard_order,
|
||||
)
|
||||
else:
|
||||
sharded_dt = self.redistribute(
|
||||
sharded_dt = redistribute(
|
||||
prev_sharded_dt, mesh, placements=None, shard_order=shard_order
|
||||
)
|
||||
self.assertEqual(self.full_tensor(sharded_dt), input_data)
|
||||
self.assertEqual(make_full_tensor(sharded_dt), input_data)
|
||||
prev_sharded_dt = sharded_dt
|
||||
|
||||
@with_comms
|
||||
@ -1136,13 +1016,13 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
|
||||
local_tensor = torch.randn(shape, device=self.device_type)
|
||||
full_tensor = DTensor.from_local(local_tensor, mesh, placements)
|
||||
with maybe_disable_local_tensor_mode():
|
||||
shard_orders = self.generate_shard_orders(mesh, len(shape))
|
||||
shard_orders = generate_shard_orders(mesh, len(shape))
|
||||
for shard_order in shard_orders:
|
||||
sharded_dt = self.redistribute(
|
||||
sharded_dt = redistribute(
|
||||
full_tensor, mesh, placements=None, shard_order=shard_order
|
||||
)
|
||||
self.assertEqual(
|
||||
self.full_tensor(sharded_dt), self.full_tensor(full_tensor)
|
||||
make_full_tensor(sharded_dt), make_full_tensor(full_tensor)
|
||||
)
|
||||
|
||||
@unittest.skip(
|
||||
@ -1152,24 +1032,20 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
|
||||
@with_comms
|
||||
def test_ordered_redistribute_for_special_placement(self):
|
||||
"""Test ordered redistribution with special placement"""
|
||||
from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
|
||||
|
||||
torch.manual_seed(21)
|
||||
mesh = init_device_mesh(self.device_type, (8,))
|
||||
input_data = torch.randn((8, 8), device=self.device_type)
|
||||
src_placement = [Shard(1)]
|
||||
tgt_placement = [
|
||||
(_MaskPartial(offset_shape=torch.Size([10, 20]), offset_dim=0),)
|
||||
(MaskPartial(offset_shape=torch.Size([10, 20]), offset_dim=0),)
|
||||
]
|
||||
sharded_dt = self.distribute_tensor(
|
||||
sharded_dt = _distribute_tensor(
|
||||
input_data.clone(),
|
||||
mesh,
|
||||
src_placement,
|
||||
shard_order=(ShardOrderEntry(tensor_dim=1, mesh_dims=(0,)),),
|
||||
)
|
||||
sharded_dt = self.redistribute(
|
||||
sharded_dt, mesh, tgt_placement, shard_order=None
|
||||
)
|
||||
sharded_dt = redistribute(sharded_dt, mesh, tgt_placement, shard_order=None)
|
||||
|
||||
@with_comms
|
||||
def test_shard_order_same_data_as_strided_shard(self):
|
||||
@ -1179,7 +1055,7 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
|
||||
strided_placement = [_StridedShard(-2, split_factor=2), Shard(-2)]
|
||||
x_strided_dt = distribute_tensor(x, device_mesh, strided_placement)
|
||||
# specify right-to-left order use ordered shard
|
||||
x_ordered_dt = self.distribute_tensor(
|
||||
x_ordered_dt = _distribute_tensor(
|
||||
x,
|
||||
device_mesh,
|
||||
placements=[Shard(0), Shard(0)],
|
||||
|
||||
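The `generate_shard_orders` helper removed above (it moves into common_dtensor, per the import hunk below) enumerates shard orders by splitting the mesh dims with integer compositions. A minimal standalone sketch of that inner helper, lifted from the hunk so it can be run in isolation:

def compositions(n, k):
    # All ways to write n as an ordered sum of k positive integers.
    if k == 1:
        yield [n]
    else:
        for i in range(1, n - k + 2):
            for tail in compositions(n - i, k - 1):
                yield [i] + tail

# e.g. list(compositions(3, 2)) == [[1, 2], [2, 1]]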
@ -34,6 +34,10 @@ from torch.distributed.tensor.placement_types import (
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
generate_shard_orders,
LocalDTensorTestBase,
patched_distribute_tensor as _distribute_tensor,
shard_order_to_placement,
with_comms,
)

@ -774,6 +778,63 @@ class TestStridedSharding(DTensorTestBase):
self.assertEqual(dtensor.full_tensor(), tensor)


class Test_StridedShard_with_shard_order(LocalDTensorTestBase):
@property
def world_size(self) -> int:
return 32

@with_comms
def test_StridedShard_to_shard_order(self):
with LocalTensorMode(ranks=self.world_size):
mesh = DeviceMesh("cpu", torch.arange(self.world_size).view(2, 2, 2, 2, 2))
shard_iter = generate_shard_orders(mesh, 3)
# It takes ~4.8h to complete total 2520 shard order combinations here
# using LocalTensor. So we only randomly pick 25 shard orders to test.
all_shard_order = list(shard_iter)
import random

random.seed(42)
shard_order_choices = random.sample(
all_shard_order, min(25, len(all_shard_order))
)

x = torch.randn(32, 32, 32)
for shard_order in shard_order_choices:
a = _distribute_tensor(x, mesh, None, shard_order)

placement_without_stridedshard = shard_order_to_placement(
shard_order, mesh
)
placements_with_stridedshard = (
DTensorSpec._convert_shard_order_to_StridedShard(
shard_order, placement_without_stridedshard, mesh
)
)
b = distribute_tensor(x, mesh, placements_with_stridedshard)
shard_order_from_stridedshard = (
DTensorSpec._maybe_convert_StridedShard_to_shard_order(
placements_with_stridedshard, mesh
)
)
self.assertEqual(shard_order, shard_order_from_stridedshard)
self.assertEqual(a.to_local(), b.to_local())

@with_comms
def test_StridedShard_not_convertible_to_shard_order(self):
with LocalTensorMode(ranks=self.world_size):
mesh = DeviceMesh("cpu", torch.arange(self.world_size).view(4, 8))
unconvertible_placements_list = [
[_StridedShard(0, split_factor=2), _StridedShard(1, split_factor=2)],
[_StridedShard(0, split_factor=2), Shard(1)],
[_StridedShard(1, split_factor=16), Shard(1)],
]
for placements in unconvertible_placements_list:
shard_order = DTensorSpec._maybe_convert_StridedShard_to_shard_order(
tuple(placements), mesh
)
self.assertIsNone(shard_order)


class Test2DStridedLocalShard(DTensorTestBase):
@property
def world_size(self):
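Because exercising all 2520 shard-order combinations under LocalTensor would take hours (the comment in the hunk above cites ~4.8h), the test seeds `random` and samples a fixed 25-case subset. A tiny sketch of that idiom, with a hypothetical case list standing in for the real shard orders:

import random

all_cases = list(range(2520))  # stand-in for the generated shard orders
random.seed(42)                # fixed seed keeps the subset reproducible across runs
chosen = random.sample(all_cases, min(25, len(all_cases)))
assert len(chosen) == 25 and len(set(chosen)) == 25  # distinct picks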
@ -861,7 +861,7 @@ TRACE FX call mul from test_logging.py:N in fn (LoggingTests.test_trace_call_pre
def test_logs_out(self):
import tempfile

with tempfile.NamedTemporaryFile(delete=False) as tmp:
with tempfile.NamedTemporaryFile(delete=True) as tmp:
file_path = _as_posix_path(tmp.name)
"""
NamedTemporaryFile will include a file open operation.
@ -888,10 +888,6 @@ fn(torch.randn(5))
file_path, encoding="utf-8"
) as fd:  # encoding file to UTF-8 for Windows.
lines = fd.read()
fd.close()
os.remove(
file_path
)  # Delete temp file manually, due to setup NamedTemporaryFile as delete=False.
orig_maxDiff = unittest.TestCase.maxDiff
unittest.TestCase.maxDiff = None
try:
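The hunks above flip `NamedTemporaryFile(delete=False)` to `delete=True`, which is why the manual `os.remove` lines can be dropped. A minimal stdlib-only sketch of the difference (illustrative, not the test's code):

import os
import tempfile

# delete=False: the file outlives the context manager; the caller must remove it.
with tempfile.NamedTemporaryFile(delete=False) as tmp:
    path = tmp.name
assert os.path.exists(path)
os.remove(path)  # manual cleanup, as the removed test lines did

# delete=True (the default): the file is unlinked when it is closed.
with tempfile.NamedTemporaryFile(delete=True) as tmp:
    path = tmp.name
assert not os.path.exists(path)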
@ -2,7 +2,6 @@


import copy
import pathlib
import tempfile
import unittest

@ -97,55 +96,55 @@ def run_with_nativert(ep):
MODEL_NAME = "forward"

# TODO Does named tempfile have collision?
with tempfile.NamedTemporaryFile(suffix=".pt2", delete=False) as f:
with tempfile.NamedTemporaryFile(suffix=".pt2") as f:
torch.export.pt2_archive._package.package_pt2(
f, exported_programs={MODEL_NAME: ep_infer}
)
filename = f.name

try:
ep_args, ep_kwargs = ep_infer.example_inputs
ep_args_copied, ep_kwargs_copied = (
copy.deepcopy(ep_args),
copy.deepcopy(ep_kwargs),
)
torch.manual_seed(0)
try:
flat_expected = pytree.tree_leaves(
ep_infer.module()(*ep_args_copied, **ep_kwargs_copied)
ep_args, ep_kwargs = ep_infer.example_inputs
ep_args_copied, ep_kwargs_copied = (
copy.deepcopy(ep_args),
copy.deepcopy(ep_kwargs),
)
except Exception as e:
raise unittest.case.SkipTest(str(e)) from e
torch.manual_seed(0)
try:
flat_expected = pytree.tree_leaves(
ep_infer.module()(*ep_args_copied, **ep_kwargs_copied)
)
except Exception as e:
raise unittest.case.SkipTest(str(e)) from e

model_runner = PyModelRunner(filename, MODEL_NAME)
torch.manual_seed(0)
if _is_supported_types((ep_args, ep_kwargs)):
results = model_runner.run(*ep_args, **ep_kwargs)
else:
results = model_runner.run_with_flat_inputs_and_outputs(
*pytree.tree_leaves((ep_args, ep_kwargs))
)
flat_results = pytree.tree_leaves(results)
assert len(flat_results) == len(flat_expected)
for result, expected in zip(flat_results, flat_expected):
assert type(result) is type(expected)
if isinstance(result, torch.Tensor) and isinstance(expected, torch.Tensor):
assert result.shape == expected.shape
assert result.dtype == expected.dtype
assert result.device == expected.device
torch.testing.assert_close(result, expected, equal_nan=True)
model_runner = PyModelRunner(filename, MODEL_NAME)
torch.manual_seed(0)
if _is_supported_types((ep_args, ep_kwargs)):
results = model_runner.run(*ep_args, **ep_kwargs)
else:
assert result == expected
except RuntimeError as e:
# User need to register pytree type on the cpp side, which
# cannot be tested in python unittest.
if "Unknown pytree node type" in str(e):
pass
else:
raise e
finally:
pathlib.Path(filename).unlink(missing_ok=True)
return ep
results = model_runner.run_with_flat_inputs_and_outputs(
*pytree.tree_leaves((ep_args, ep_kwargs))
)
flat_results = pytree.tree_leaves(results)
assert len(flat_results) == len(flat_expected)
for result, expected in zip(flat_results, flat_expected):
assert type(result) is type(expected)
if isinstance(result, torch.Tensor) and isinstance(
expected, torch.Tensor
):
assert result.shape == expected.shape
assert result.dtype == expected.dtype
assert result.device == expected.device
torch.testing.assert_close(result, expected, equal_nan=True)
else:
assert result == expected
except RuntimeError as e:
# User need to register pytree type on the cpp side, which
# cannot be tested in python unittest.
if "Unknown pytree node type" in str(e):
pass
else:
raise e
return ep


def mocked_nativert_export_strict(*args, **kwargs):
@ -287,7 +286,7 @@ class TestNativeRT(TestCase):
)

# package everything needed for the NativeRT to execute the AOTI delegate
with tempfile.NamedTemporaryFile(suffix=".pt2", delete=False) as f:
with tempfile.NamedTemporaryFile(suffix=".pt2") as f:
package_nativert_with_aoti_delegate(
f,
MODEL_NAME,
@ -298,50 +297,48 @@ class TestNativeRT(TestCase):
)
filename = f.name

try:
ep_args, ep_kwargs = aoti_delegate_ep.example_inputs
ep_args_copied, ep_kwargs_copied = (
copy.deepcopy(ep_args),
copy.deepcopy(ep_kwargs),
)
torch.manual_seed(0)
try:
flat_expected = pytree.tree_leaves(
aoti_delegate_ep.module()(*ep_args_copied, **ep_kwargs_copied)
ep_args, ep_kwargs = aoti_delegate_ep.example_inputs
ep_args_copied, ep_kwargs_copied = (
copy.deepcopy(ep_args),
copy.deepcopy(ep_kwargs),
)
except Exception as e:
raise unittest.case.SkipTest(str(e)) from e
torch.manual_seed(0)
try:
flat_expected = pytree.tree_leaves(
aoti_delegate_ep.module()(*ep_args_copied, **ep_kwargs_copied)
)
except Exception as e:
raise unittest.case.SkipTest(str(e)) from e

model_runner = PyModelRunner(filename, f"{MODEL_NAME}-{BACKEND_ID}")
torch.manual_seed(0)
if _is_supported_types((ep_args, ep_kwargs)):
results = model_runner.run(*ep_args, **ep_kwargs)
else:
results = model_runner.run_with_flat_inputs_and_outputs(
*pytree.tree_leaves((ep_args, ep_kwargs))
)
flat_results = pytree.tree_leaves(results)
assert len(flat_results) == len(flat_expected)
for result, expected in zip(flat_results, flat_expected):
assert type(result) is type(expected)
if isinstance(result, torch.Tensor) and isinstance(
expected, torch.Tensor
):
assert result.shape == expected.shape
assert result.dtype == expected.dtype
assert result.device == expected.device
torch.testing.assert_close(result, expected, equal_nan=True)
model_runner = PyModelRunner(filename, f"{MODEL_NAME}-{BACKEND_ID}")
torch.manual_seed(0)
if _is_supported_types((ep_args, ep_kwargs)):
results = model_runner.run(*ep_args, **ep_kwargs)
else:
assert result == expected
except RuntimeError as e:
# User need to register pytree type on the cpp side, which
# cannot be tested in python unittest.
if "Unknown pytree node type" in str(e):
pass
else:
raise e
finally:
pathlib.Path(filename).unlink(missing_ok=True)
results = model_runner.run_with_flat_inputs_and_outputs(
*pytree.tree_leaves((ep_args, ep_kwargs))
)
flat_results = pytree.tree_leaves(results)
assert len(flat_results) == len(flat_expected)
for result, expected in zip(flat_results, flat_expected):
assert type(result) is type(expected)
if isinstance(result, torch.Tensor) and isinstance(
expected, torch.Tensor
):
assert result.shape == expected.shape
assert result.dtype == expected.dtype
assert result.device == expected.device
torch.testing.assert_close(result, expected, equal_nan=True)
else:
assert result == expected
except RuntimeError as e:
# User need to register pytree type on the cpp side, which
# cannot be tested in python unittest.
if "Unknown pytree node type" in str(e):
pass
else:
raise e


if is_fbcode():
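The refactor above keeps the same verification pattern throughout: flatten both expected and actual outputs with pytree and compare leaf by leaf. A small self-contained sketch of that pattern (illustrative inputs, not the test's exported programs):

import torch
import torch.utils._pytree as pytree

expected = {"out": (torch.ones(2), torch.zeros(2))}
results = {"out": (torch.ones(2), torch.zeros(2))}

flat_expected = pytree.tree_leaves(expected)
flat_results = pytree.tree_leaves(results)
assert len(flat_results) == len(flat_expected)
for result, exp in zip(flat_results, flat_expected):
    assert type(result) is type(exp)
    torch.testing.assert_close(result, exp, equal_nan=True)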
@ -206,6 +206,10 @@ class TestPyCodeCache(TestCase):
.decode()
.strip()
)
# XPU have extra lines, so get the last line, refer https://github.com/intel/torch-xpu-ops/issues/2261
if torch.xpu.is_available():
wrapper_path = wrapper_path.splitlines()[-1]
hit = hit.splitlines()[-1]
self.assertEqual(hit, "1")

with open(wrapper_path) as f:
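The XPU workaround above keeps only the last line of each captured string before asserting, since XPU runs can emit extra diagnostic lines. A one-line sketch with a made-up capture value:

wrapper_path = "extra diagnostic line\n/tmp/torchinductor/xy/cxy123.py"  # hypothetical capture
wrapper_path = wrapper_path.splitlines()[-1]
assert wrapper_path == "/tmp/torchinductor/xy/cxy123.py"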
154
test/inductor/test_cutedsl_grouped_mm.py
Normal file
@ -0,0 +1,154 @@
# Owner(s): ["module: inductor"]


import unittest

import torch
from torch import Tensor
from torch._inductor import config
from torch._inductor.codegen.cuda.cuda_env import is_datacenter_blackwell_arch
from torch._inductor.test_case import run_tests, TestCase as InductorTestCase
from torch._inductor.utils import ensure_cute_available
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
)


@unittest.skipIf(
not (ensure_cute_available() and is_datacenter_blackwell_arch()),
"CuTeDSL library or Blackwell device not available",
)
@instantiate_parametrized_tests
class TestCuTeDSLGroupedGemm(InductorTestCase):
def _get_inputs(
self,
group_size: int,
M_hint: int,
K: int,
N: int,
device: str,
dtype: torch.dtype,
alignment: int = 16,
) -> tuple[Tensor, Tensor, Tensor]:
# --- Random, tile-aligned M sizes ---
M_sizes = (
torch.randint(1, (M_hint // alignment) + 1, (group_size,), dtype=torch.int)
* alignment
)

M_total = torch.sum(M_sizes).item()

# --- Construct input tensors ---
A = torch.randn(int(M_total), K, dtype=dtype, device=device) * 0.1
B = torch.randn((group_size, K, N), dtype=dtype, device=device) * 0.01

# --- Build offsets (no leading zero, strictly increasing) ---
offsets = torch.cumsum(M_sizes, dim=0).to(dtype=torch.int32, device=device)

return (A, B, offsets)

@parametrize("group_size", (2, 8))
@parametrize("M_hint", (256, 1024))
@parametrize("K", (64, 128))
@parametrize("N", (128, 256))
def test_grouped_gemm_basic(self, group_size: int, M_hint: int, K: int, N: int):
device = "cuda"
dtype = torch.bfloat16

A, B, offsets = self._get_inputs(group_size, M_hint, K, N, device, dtype)

def grouped_gemm_fn(A_packed, B_batched, offs):
return torch._grouped_mm(A_packed, B_batched, offs=offs)

# Eager execution
c_eager = grouped_gemm_fn(A, B, offsets)

# Test with Cute backend
with config.patch(
{
"max_autotune": True,
"max_autotune_gemm_backends": "CUTEDSL",
"test_configs.autotune_choice_name_regex": "cutedsl",
"autotune_fallback_to_aten": False,
}
):
grouped_gemm_compiled = torch.compile(
grouped_gemm_fn, backend="inductor", dynamic=False
)
c_compiled = grouped_gemm_compiled(A, B, offsets)

self.assertEqual(c_eager.dtype, dtype)
self.assertEqual(c_compiled.dtype, dtype)
torch.testing.assert_close(c_eager, c_compiled)

@parametrize("layout_A", ("contiguous", "offset", "padded", "view"))
@parametrize("layout_B", ("contiguous", "broadcasted"))
def test_grouped_gemm_assorted_layouts(
self,
layout_A: str,
layout_B: str,
):
device = "cuda"
dtype = torch.bfloat16

G, K, N = 8, 64, 128
M_sizes = [128] * G
sum_M = sum(M_sizes)
offsets = torch.tensor(
[sum(M_sizes[: i + 1]) for i in range(G)], dtype=torch.int32, device=device
)

A_base = torch.randn(sum_M, K, device=device, dtype=dtype)
A = A_base

if layout_A == "offset":
# allocate bigger buffer than needed, use nonzero storage offset
storage = torch.randn(sum_M * K + 512, device=device, dtype=dtype)
offset = 128  # skip first 128 elements
A = torch.as_strided(storage[offset:], (sum_M, K), (K, 1))
elif layout_A == "padded":
# simulate row pitch > K (row_stride = K + pad)
row_pitch = K + 8
storage = torch.randn(sum_M * row_pitch, device=device, dtype=dtype)
A = torch.as_strided(storage, (sum_M, K), (row_pitch, 1))
elif layout_A == "view":
A_storage = torch.randn(sum_M * K, device=device, dtype=dtype)
A = A_storage.view(sum_M, K)
assert A._base is not None
assert A.shape == (sum_M, K)

B = torch.randn((G, K, N), dtype=dtype, device=device) * 0.01

if layout_B == "broadcasted":
# Broadcast B across groups (zero stride along G)
B = B[0].expand(G, K, N)
assert B.stride(0) == 0

def grouped_gemm_fn(A_packed, B_batched, offs):
return torch._grouped_mm(A_packed, B_batched, offs=offs)

# --- eager ---
c_eager = grouped_gemm_fn(A, B, offsets)

# --- compiled (CUTE backend) ---
with config.patch(
{
"max_autotune": True,
"max_autotune_gemm_backends": "CUTEDSL",
"test_configs.autotune_choice_name_regex": "cutedsl",
"autotune_fallback_to_aten": False,
}
):
grouped_gemm_compiled = torch.compile(
grouped_gemm_fn, backend="inductor", dynamic=False
)
c_compiled = grouped_gemm_compiled(A, B, offsets)

self.assertEqual(c_eager.dtype, dtype)
self.assertEqual(c_compiled.dtype, dtype)
torch.testing.assert_close(c_eager, c_compiled)


if __name__ == "__main__":
run_tests()
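The new test exercises `torch._grouped_mm`, where A packs every group's rows along dim 0, `offsets[g]` is the cumulative end row of group g (no leading zero, per the helper's comment), and B carries one (K, N) operand per group. A plain-PyTorch reference sketch of that contract (written here for illustration; the test relies on the fused op itself):

import torch

def grouped_mm_reference(A, B, offsets):
    # A: (sum_M, K) packed rows; B: (G, K, N); offsets: (G,) cumulative end rows.
    out, start = [], 0
    for g in range(B.shape[0]):
        end = int(offsets[g])
        out.append(A[start:end] @ B[g])  # (M_g, K) @ (K, N) -> (M_g, N)
        start = end
    return torch.cat(out, dim=0)  # (sum_M, N), matching the packed layout of A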
@ -3,7 +3,8 @@
import functools
import weakref
from collections import Counter
from typing import Callable, Optional
from collections.abc import Callable
from typing import Optional

import torch
from torch._inductor.fx_passes.memory_estimator import (
@ -28,7 +29,7 @@ def device_filter(device):


class FakeTensorMemoryProfilerMode(TorchDispatchMode):
def __init__(self, device_filter: Optional[Callable[torch.device, bool]] = None):
def __init__(self, device_filter: Optional[Callable[[torch.device], bool]] = None):
# counter of storage ids to live references
self.storage_count: dict[int, int] = Counter()
# live fake tensors
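The hunk above fixes an invalid `Callable` annotation: the parameter types must themselves be a list, so `Callable[torch.device, bool]` becomes `Callable[[torch.device], bool]`. A minimal sketch of the corrected form:

from collections.abc import Callable
from typing import Optional

import torch

# One torch.device parameter, returning bool.
DeviceFilter = Optional[Callable[[torch.device], bool]]

def cuda_only(device: torch.device) -> bool:
    return device.type == "cuda"

f: DeviceFilter = cuda_only  # type-checks under the fixed annotation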
@ -482,8 +482,8 @@ class TestExecutionTrace(TestCase):

@unittest.skipIf(IS_WINDOWS, "torch.compile does not support WINDOWS")
@unittest.skipIf(
(not has_triton()) or (not TEST_CUDA and not TEST_XPU),
"need triton and device(CUDA or XPU) availability to run",
(not has_triton()) or (not TEST_CUDA),
"need triton and device CUDA availability to run",
)
@skipCPUIf(True, "skip CPU device for testing profiling triton")
def test_triton_fx_graph_with_et(self, device):
@ -2005,6 +2005,10 @@ assert KinetoStepTracker.current_step() == initial_step + 2 * niters
report = json.load(f)
self._validate_basic_json(report["traceEvents"], with_cuda)

@unittest.skipIf(
torch.xpu.is_available(),
"XPU Trace event ends too late! Refer https://github.com/intel/torch-xpu-ops/issues/2263",
)
@unittest.skipIf(not kineto_available(), "Kineto is required")
@skipIfTorchDynamo("profiler gets ignored if dynamo activated")
def test_basic_chrome_trace(self):
@ -2158,7 +2162,10 @@ assert KinetoStepTracker.current_step() == initial_step + 2 * niters
@skipIfTorchDynamo("profiler gets ignored if dynamo activated")
def test_basic_profile(self):
# test a really basic profile to make sure no erroneous aten ops are run
x = torch.randn(4, device="cuda")
acc = torch.accelerator.current_accelerator()
self.assertIsNotNone(acc)
device = acc.type
x = torch.randn(4, device=device)
with torch.profiler.profile(with_stack=True) as p:
x *= 2
names = [e.name for e in p.events()]
@ -2225,6 +2232,7 @@ assert KinetoStepTracker.current_step() == initial_step + 2 * niters
@unittest.skipIf(
torch.cuda.is_available(), "CUDA complains about forking after init"
)
@unittest.skipIf(torch.xpu.is_available(), "XPU complains about forking after init")
@unittest.skipIf(IS_WINDOWS, "can't use os.fork() on Windows")
def test_forked_process(self):
# Induce a pid cache by running the profiler with payload
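The `test_basic_profile` hunk above replaces a hard-coded `device="cuda"` with the device-agnostic accelerator API. A short sketch of that pattern (torch.accelerator is available in recent PyTorch builds; the fallback branch below is an assumption for older ones):

import torch

acc = getattr(torch, "accelerator", None)
device = "cpu"
if acc is not None:
    current = torch.accelerator.current_accelerator()
    if current is not None:
        device = current.type  # e.g. "cuda" or "xpu"
x = torch.randn(4, device=device)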
@ -263,13 +263,7 @@ S390X_BLOCKLIST = [

XPU_BLOCKLIST = [
"test_autograd",
"profiler/test_cpp_thread",
"profiler/test_execution_trace",
"profiler/test_memory_profiler",
"profiler/test_profiler",
"profiler/test_profiler_tree",
"profiler/test_record_function",
"profiler/test_torch_tidy",
"test_openreg",
]

@ -1,245 +1,236 @@
{
"EndToEndLSTM (__main__.RNNTest)": 207.89400227864584,
"MultiheadAttention (__main__.ModulesTest)": 141.1396687825521,
"test_AllenaiLongformerBase_repro_cpu_halide (__main__.HalideCpuTests)": 214.02366638183594,
"test__adaptive_avg_pool2d (__main__.CPUReproTests)": 77.26125049591064,
"test_adaptive_max_pool2d1_cpu_halide (__main__.HalideCpuTests)": 116.37000020345052,
"test_after_aot_cpu_runtime_error (__main__.MinifierIsolateTests)": 69.25722334120009,
"test_after_aot_gpu_runtime_error (__main__.MinifierIsolateTests)": 65.84466807047527,
"test_alexnet_prefix_cpu_halide (__main__.HalideCpuTests)": 178.41399637858072,
"test_aot_autograd_disable_functionalization_symbolic_exhaustive_linalg_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 63.55014337812151,
"test_aot_autograd_disable_functionalization_symbolic_exhaustive_nn_functional_max_pool1d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 122.18047623407273,
"test_aot_autograd_disable_functionalization_symbolic_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 192.6405719575428,
"test_aot_autograd_disable_functionalization_symbolic_exhaustive_nn_functional_max_pool3d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 111.27904801141648,
"test_aot_autograd_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 60.906999588012695,
"test_aot_autograd_symbolic_exhaustive_linalg_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 62.244998931884766,
"test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool1d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 150.04100036621094,
"test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 191.85050201416016,
"test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool3d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 111.9276631673177,
"test_aot_autograd_symbolic_exhaustive_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 67.31450271606445,
"test_aot_autograd_symbolic_module_exhaustive_nn_TransformerDecoderLayer_cpu_float32 (__main__.TestEagerFusionModuleInfoCPU)": 125.24066416422527,
"test_associative_scan_partial_grad_combine_mode_generic_compile_mode_compile_dynamic_shape_reverse_False_cpu (__main__.AssociativeScanTests)": 86.47783279418945,
"test_associative_scan_partial_grad_combine_mode_generic_compile_mode_compile_dynamic_shape_reverse_True_cpu (__main__.AssociativeScanTests)": 100.46250025431316,
"test_avg_pool3d_backward2_cpu (__main__.CpuTests)": 1031.0534973144531,
"test_avg_pool3d_backward2_cuda (__main__.GPUTests)": 239.67400105794272,
"test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 495.0447726779514,
"test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 490.18524169921875,
"test_avg_pool3d_backward2_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 144.06477737426758,
"test_avg_pool3d_backward2_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 342.20416259765625,
"test_avg_pool3d_backward_cpu_halide (__main__.HalideCpuTests)": 62.01366678873698,
"test_backward_nn_functional_multi_head_attention_forward_cpu_float32 (__main__.TestCompositeComplianceCPU)": 71.07200050354004,
"test_backward_nn_functional_multi_head_attention_forward_cuda_float32 (__main__.TestCompositeComplianceCUDA)": 73.9221674601237,
"test_basic_cpu (__main__.EfficientConvBNEvalCpuTests)": 226.0122528076172,
"test_basic_cuda (__main__.EfficientConvBNEvalGpuTests)": 144.97249857584634,
"test_checkpointing_without_reentrant_input_requires_grad_False (__main__.TestAutogradWithCompiledAutograd)": 303.20537185668945,
"test_checkpointing_without_reentrant_input_requires_grad_True (__main__.TestAutogradWithCompiledAutograd)": 386.0518798828125,
"test_collect_callgrind (__main__.TestBenchmarkUtils)": 291.2442270914714,
"test_comprehensive_diff_cuda_complex128 (__main__.TestDecompCUDA)": 95.87866719563802,
"test_comprehensive_diff_cuda_complex64 (__main__.TestDecompCUDA)": 98.38716634114583,
"test_comprehensive_diff_cuda_float32 (__main__.TestDecompCUDA)": 69.08016649881999,
"test_comprehensive_diff_cuda_float64 (__main__.TestDecompCUDA)": 69.88233311971028,
"test_comprehensive_grid_sampler_2d_cpu_bfloat16 (__main__.TestDecompCPU)": 104.17599995930989,
"test_comprehensive_grid_sampler_2d_cpu_float16 (__main__.TestDecompCPU)": 97.41800308227539,
"test_comprehensive_grid_sampler_2d_cpu_float32 (__main__.TestDecompCPU)": 474.6719970703125,
"test_comprehensive_grid_sampler_2d_cpu_float64 (__main__.TestDecompCPU)": 440.4375,
"test_comprehensive_grid_sampler_2d_cuda_bfloat16 (__main__.TestDecompCUDA)": 293.3983332316081,
"test_comprehensive_grid_sampler_2d_cuda_float16 (__main__.TestDecompCUDA)": 238.7328338623047,
"test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestDecompCUDA)": 1218.4906717936199,
"test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 68.73516782124837,
"test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestDecompCUDA)": 1156.0123494466145,
"test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 72.13916714986165,
"test_comprehensive_linalg_lu_solve_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 74.90450032552083,
"test_comprehensive_linalg_lu_solve_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 70.42100016276042,
"test_comprehensive_linalg_solve_triangular_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 72.98883310953777,
"test_comprehensive_linalg_solve_triangular_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 73.34433364868164,
"test_comprehensive_linalg_svd_cuda_complex128 (__main__.TestDecompCUDA)": 61.38016573588053,
"test_comprehensive_linalg_svd_cuda_complex64 (__main__.TestDecompCUDA)": 67.52783330281575,
"test_comprehensive_masked_norm_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 111.06333287556966,
"test_comprehensive_masked_norm_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 110.19833374023438,
"test_comprehensive_masked_norm_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 113.10083134969075,
"test_comprehensive_nn_functional_conv_transpose3d_cuda_complex128 (__main__.TestDecompCUDA)": 63.23766644795736,
"test_comprehensive_nn_functional_conv_transpose3d_cuda_complex64 (__main__.TestDecompCUDA)": 70.18666712443034,
"test_comprehensive_nn_functional_gaussian_nll_loss_cpu_float32 (__main__.TestDecompCPU)": 62.61399841308594,
"test_comprehensive_nn_functional_gaussian_nll_loss_cpu_float64 (__main__.TestDecompCPU)": 67.7816670735677,
"test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestDecompCUDA)": 121.6183344523112,
"test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float64 (__main__.TestDecompCUDA)": 107.30266698201497,
"test_comprehensive_nn_functional_grid_sample_cpu_float32 (__main__.TestDecompCPU)": 130.8143310546875,
"test_comprehensive_nn_functional_grid_sample_cpu_float64 (__main__.TestDecompCPU)": 127.27633412679036,
"test_comprehensive_nn_functional_grid_sample_cuda_float32 (__main__.TestDecompCUDA)": 303.55183664957684,
"test_comprehensive_nn_functional_grid_sample_cuda_float64 (__main__.TestDecompCUDA)": 234.41216532389322,
"test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 (__main__.TestDecompCUDA)": 85.3436673482259,
"test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 80.9688326517741,
"test_comprehensive_nn_functional_interpolate_bicubic_cuda_float64 (__main__.TestDecompCUDA)": 82.55149968465169,
"test_comprehensive_nn_functional_interpolate_bicubic_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 82.37966791788737,
"test_comprehensive_nn_functional_interpolate_trilinear_cuda_float32 (__main__.TestDecompCUDA)": 129.88233184814453,
"test_comprehensive_nn_functional_interpolate_trilinear_cuda_float64 (__main__.TestDecompCUDA)": 129.4015007019043,
"test_comprehensive_nn_functional_max_pool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 1282.3826497395833,
"test_comprehensive_nn_functional_max_pool2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 1270.64599609375,
"test_comprehensive_nn_functional_max_pool2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 1297.9046630859375,
"test_comprehensive_nn_functional_max_pool3d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 545.2034962972006,
"test_comprehensive_nn_functional_max_pool3d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 572.5616760253906,
"test_comprehensive_nn_functional_max_unpool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 64.40316645304362,
"test_comprehensive_nn_functional_max_unpool2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 64.68383344014485,
"test_comprehensive_nn_functional_max_unpool2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 61.48333422342936,
"test_comprehensive_ormqr_cpu_complex64 (__main__.TestDecompCPU)": 61.959999084472656,
"test_comprehensive_ormqr_cuda_complex128 (__main__.TestDecompCUDA)": 105.79100036621094,
"test_comprehensive_ormqr_cuda_complex64 (__main__.TestDecompCUDA)": 122.34666570027669,
"test_comprehensive_ormqr_cuda_float32 (__main__.TestDecompCUDA)": 68.7205015818278,
"test_comprehensive_ormqr_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 74.2183329264323,
"test_comprehensive_ormqr_cuda_float64 (__main__.TestDecompCUDA)": 66.86883227030437,
"test_comprehensive_svd_cuda_complex128 (__main__.TestDecompCUDA)": 77.48183314005534,
"test_comprehensive_svd_cuda_complex64 (__main__.TestDecompCUDA)": 79.1564998626709,
"test_constructor_autograd_SparseBSC_cuda (__main__.TestSparseAnyCUDA)": 160.41250228881836,
"test_constructor_autograd_SparseBSR_cuda (__main__.TestSparseAnyCUDA)": 79.10633341471355,
"test_constructor_autograd_SparseCSC_cuda (__main__.TestSparseAnyCUDA)": 60.106833140055336,
"test_conv1d_basic (__main__.TestXNNPACKConv1dTransformPass)": 221.3586196899414,
"test_conv1d_with_relu_fc (__main__.TestXNNPACKConv1dTransformPass)": 504.3203754425049,
"test_conv2d_binary_broadcast_shapes_cpu (__main__.TestPatternMatcherGenericCPU)": 78.03233337402344,
"test_conv3d_binary_broadcast_shapes_cpu (__main__.TestPatternMatcherGenericCPU)": 152.302001953125,
"test_conv3d_cuda (__main__.AOTInductorTestABICompatibleGpu)": 152.99433390299478,
"test_conv_bn_fuse_cpu (__main__.CpuTests)": 96.25399971008301,
"test_conv_bn_fuse_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 75.70275068283081,
"test_conv_transpose_with_output_size_and_no_batch_dim_ConvTranspose3d_cuda (__main__.TestConvolutionNNDeviceTypeCUDA)": 139.14399747674665,
"test_conv_unary_fusion_nnc (__main__.TestMkldnnFusion)": 72.7847490310669,
"test_correctness_AdamW_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 91.59966786702473,
"test_correctness_Adam_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 87.57833353678386,
"test_count_nonzero_all (__main__.TestBool)": 664.9986343383789,
"test_cp_flex_attention_document_mask (__main__.CPFlexAttentionTest)": 78.31500244140625,
"test_ddp_uneven_inputs (__main__.TestDistBackendWithSpawn)": 385.24249792099,
"test_dispatch_symbolic_meta_outplace_all_strides_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestMetaCUDA)": 84.70466740926106,
"test_dtensor_op_db_nn_functional_gaussian_nll_loss_cpu_float32 (__main__.TestLocalDTensorOpsCPU)": 685.0679931640625,
"test_dtensor_op_db_nn_functional_gaussian_nll_loss_cpu_float32 (__main__.TestMultiThreadedDTensorOpsCPU)": 86.26266733805339,
"test_eig_check_magma_cuda_float32 (__main__.TestLinalgCUDA)": 292.93699645996094,
"test_error_detection_and_propagation (__main__.NcclErrorHandlingTest)": 66.84199905395508,
"test_fail_arithmetic_ops.py (__main__.TestTyping)": 69.56212568283081,
"test_fail_creation_ops.py (__main__.TestTyping)": 69.80560022989908,
"test_fn_fwgrad_bwgrad_cumprod_cuda_complex128 (__main__.TestFwdGradientsCUDA)": 73.36666552225749,
"test_fn_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 90.40366744995117,
"test_fuse_large_params_cpu (__main__.CpuTests)": 132.73199844360352,
"test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 150.16662406921387,
"test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 159.28499794006348,
"test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 165.19283294677734,
"test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 151.12366739908853,
"test_grad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 84.61699930826823,
"test_gradgrad_nn_LSTM_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 110.00600179036458,
"test_gradgrad_nn_LSTM_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 122.3759994506836,
"test_gradgrad_nn_TransformerDecoderLayer_cuda_float64 (__main__.TestModuleCUDA)": 190.89249674479166,
"test_gradgrad_nn_TransformerEncoder_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 149.6598358154297,
"test_gradgrad_nn_TransformerEncoder_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 146.07766723632812,
"test_gradgrad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 532.8139902750651,
"test_graph_partition_refcount_cuda (__main__.GPUTests)": 69.78400001525878,
"test_graph_partition_refcount_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 267.04988850487604,
"test_graph_partition_refcount_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 273.54955800374347,
"test_grid_sampler_2d_cpu_halide (__main__.HalideCpuTests)": 195.84733072916666,
"test_indirect_device_assert (__main__.TritonCodeGenTests)": 326.0143330891927,
"test_inductor_no_recursionerror_on_for_loops_dynamic_shapes (__main__.DynamicShapesReproTests)": 66.96037435531616,
"test_inplace_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 77.44933319091797,
"test_inputs_overlapping_with_mutation_stress_dynamic_shapes (__main__.DynamicShapesAotAutogradFallbackTests)": 126.81488884819879,
"test_jit_cuda_archflags (__main__.TestCppExtensionJIT)": 118.70199839274089,
"test_linalg_solve_triangular_large_cuda_complex128 (__main__.TestLinalgCUDA)": 129.20266723632812,
"test_linalg_solve_triangular_large_cuda_complex64 (__main__.TestLinalgCUDA)": 97.18800099690755,
"test_linear_binary_cpp_wrapper (__main__.TestCppWrapper)": 130.3183339436849,
"test_linear_binary_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 140.43233235677084,
"test_list_clearing_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 293.122774971856,
"test_lobpcg_ortho_cuda_float64 (__main__.TestLinalgCUDA)": 63.835832277933754,
"test_longformer_chunk_dynamic_shapes (__main__.DynamicShapesReproTests)": 106.77049922943115,
"test_lstm_cpu (__main__.TestMkldnnCPU)": 100.89649963378906,
"test_many_overlapping_inputs_does_not_explode_guards_dynamic_shapes (__main__.DynamicShapesReproTests)": 140.07424926757812,
"test_max_autotune_addmm_max_autotune_gemm_backends_CK_x_shape2 (__main__.TestCKBackend)": 72.90299733479817,
"test_max_autotune_addmm_search_space_EXHAUSTIVE_dynamic_True (__main__.TestMaxAutotuneSubproc)": 82.62433369954427,
"test_max_autotune_precompile_matmul_max_autotune_gemm_backends_CKTILE_autotune_in_subproc_False_use_aoti_False (__main__.TestCKBackend)": 87.51499938964844,
"test_max_autotune_precompile_matmul_max_autotune_gemm_backends_CKTILE_autotune_in_subproc_True_use_aoti_True (__main__.TestCKBackend)": 71.22416591644287,
"test_max_pool2d2_cpu_halide (__main__.HalideCpuTests)": 424.50966389973956,
"test_max_pool2d3_cpu_halide (__main__.HalideCpuTests)": 134.14600626627603,
"test_max_pool2d5_cpu_halide (__main__.HalideCpuTests)": 358.88099161783856,
"test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 63.58866712782118,
"test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 62.68674945831299,
"test_memory_format_operators_cuda (__main__.TestTorchDeviceTypeCUDA)": 65.85794713936355,
"test_ordered_distribute_all_combination (__main__.DistributeWithDeviceOrderTest)": 103.6923344930013,
"test_ordered_redistribute_with_partial (__main__.DistributeWithDeviceOrderTest)": 187.6953328450521,
"test_ordered_redistribute_with_partial (__main__.DistributeWithDeviceOrderTestWithLocalTensor)": 370.27442932128906,
"test_proper_exit (__main__.TestDataLoader)": 227.83111148410373,
"test_proper_exit (__main__.TestDataLoaderPersistentWorkers)": 227.1901126437717,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 105.52099990844727,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 106.50249862670898,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True (__main__.TestPatternMatcher)": 92.52400207519531,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 111.75499725341797,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 107.40500259399414,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False (__main__.TestPatternMatcher)": 83.80450057983398,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 107.46599833170573,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 96.65650177001953,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True (__main__.TestPatternMatcher)": 83.4114990234375,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 107.47100067138672,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 108.55533345540364,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False (__main__.TestPatternMatcher)": 89.23666381835938,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 105.13900375366211,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 100.14550018310547,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 107.33649826049805,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 102.08150100708008,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 97.59600067138672,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 104.82933553059895,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 114.43099721272786,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 110.40333302815755,
"test_quick_core_backward__unsafe_masked_index_cpu_float64 (__main__.TestDecompCPU)": 567.2765197753906,
"test_quick_core_backward__unsafe_masked_index_cuda_float64 (__main__.TestDecompCUDA)": 1032.5083312988281,
"test_quick_core_backward__unsafe_masked_index_put_accumulate_cpu_float64 (__main__.TestDecompCPU)": 852.7170003255209,
"test_quick_core_backward__unsafe_masked_index_put_accumulate_cuda_float64 (__main__.TestDecompCUDA)": 1361.954854329427,
"test_quick_core_backward_nn_functional_max_unpool3d_grad_cpu_float64 (__main__.TestDecompCPU)": 77.385498046875,
"test_quick_core_backward_nn_functional_max_unpool3d_grad_cuda_float64 (__main__.TestDecompCUDA)": 265.0193354288737,
"test_quick_core_backward_roll_cpu_float64 (__main__.TestDecompCPU)": 115.31749725341797,
"test_quick_core_backward_roll_cuda_float64 (__main__.TestDecompCUDA)": 245.27666727701822,
"test_quick_core_backward_select_scatter_cpu_float64 (__main__.TestDecompCPU)": 71.75300216674805,
"test_quick_core_backward_select_scatter_cuda_float64 (__main__.TestDecompCUDA)": 141.8895009358724,
"test_quick_core_backward_split_cuda_float64 (__main__.TestDecompCUDA)": 71.15749994913737,
"test_quick_core_backward_split_with_sizes_copy_cpu_float64 (__main__.TestDecompCPU)": 90.59066772460938,
"test_quick_core_backward_split_with_sizes_copy_cuda_float64 (__main__.TestDecompCUDA)": 173.73916625976562,
"test_quick_core_backward_std_cuda_float64 (__main__.TestDecompCUDA)": 110.65066655476888,
"test_register_spills_cuda (__main__.BenchmarkFusionCudaTest)": 99.21799850463867,
"test_replicatepad_64bit_indexing_cuda_float16 (__main__.TestNNDeviceTypeCUDA)": 90.86299896240234,
"test_rosenbrock_sparse_with_lrsched_False_SGD_cuda_float64 (__main__.TestOptimRenewedCUDA)": 66.57050196329753,
"test_rosenbrock_sparse_with_lrsched_True_SGD_cuda_float64 (__main__.TestOptimRenewedCUDA)": 69.65149958928426,
"test_runtime_checks_large_cpu (__main__.AOTInductorTestABICompatibleCpu)": 78.13350168863933,
"test_runtime_checks_large_cpu_with_stack_allocation (__main__.AOTInductorTestABICompatibleCpuWithStackAllocation)": 76.85255601671007,
"test_runtime_checks_large_cuda (__main__.AOTInductorTestABICompatibleGpu)": 333.04866282145184,
"test_save_load_large_string_attribute (__main__.TestSaveLoad)": 146.96599833170572,
"test_sdpa_kernel_ctx_manager2_dynamic_shapes (__main__.DynamicShapesCtxManagerTests)": 160.4881100124783,
"test_shuffler_iterdatapipe (__main__.IntegrationTestDataLoaderDataPipe)": 124.10055626763238,
"test_slow_tasks (__main__.TestFunctionalAutogradBenchmark)": 117.38410907321506,
"test_sort_dynamic_shape_with_check_cuda (__main__.TestInductorDynamicCUDA)": 710.2327779134115,
"test_sort_stable_cpu (__main__.CpuTritonTests)": 1324.4399820963542,
"test_sort_stable_cuda (__main__.GPUTests)": 76.83109970092774,
"test_split_cumsum_cpu (__main__.CpuTritonTests)": 88.58433532714844,
"test_svd_lowrank_cuda_complex128 (__main__.TestLinalgCUDA)": 160.1271684964498,
"test_tensor_split (__main__.TestVmapOperators)": 79.18955569393519,
"test_terminate_handler_on_crash (__main__.TestTorch)": 111.30388899644215,
"test_terminate_signal (__main__.ForkTest)": 132.3458870516883,
"test_terminate_signal (__main__.ParallelForkServerShouldWorkTest)": 132.2043343567186,
"test_terminate_signal (__main__.SpawnTest)": 136.1005539894104,
"test_torchvision_smoke (__main__.TestTensorBoardPytorchGraph)": 76.20899939537048,
"test_train_parity_multi_group_unshard_async_op (__main__.TestFullyShard1DTrainingCore)": 63.82099969046457,
"test_triton_bsr_scatter_mm_blocksize_64_cuda_bfloat16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 61.925000508626304,
"test_triton_bsr_scatter_mm_blocksize_64_cuda_float16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 60.89849980672201,
"test_triton_bsr_scatter_mm_blocksize_64_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 66.88233375549316,
"test_triton_bsr_softmax_cuda_bfloat16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 144.9854990641276,
"test_triton_bsr_softmax_cuda_float16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 144.4044977823893,
"test_triton_bsr_softmax_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 108.19166437784831,
"test_unary_ops (__main__.TestTEFuserDynamic)": 96.32655514611139,
"test_unary_ops (__main__.TestTEFuserStatic)": 105.33362591266632,
"test_upsample_bicubic2d_cpu_halide (__main__.HalideCpuTests)": 97.8336664835612,
"test_variant_consistency_jit_nn_functional_max_pool2d_cpu_float32 (__main__.TestJitCPU)": 82.86566925048828,
"test_variant_consistency_jit_nn_functional_max_pool2d_cuda_float32 (__main__.TestJitCUDA)": 68.26500002543132,
"test_views1_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 97.1120007832845,
"test_vmapjvpvjp_linalg_lstsq_grad_oriented_cpu_float32 (__main__.TestOperatorsCPU)": 88.24766794840495,
"test_vmapjvpvjp_linalg_lstsq_grad_oriented_cuda_float32 (__main__.TestOperatorsCUDA)": 65.41266759236653,
"test_vmapjvpvjp_linalg_lu_solve_cuda_float32 (__main__.TestOperatorsCUDA)": 74.75533294677734,
"test_vmapjvpvjp_linalg_multi_dot_cuda_float32 (__main__.TestOperatorsCUDA)": 73.52500089009602,
"test_vmapjvpvjp_linalg_svd_cuda_float32 (__main__.TestOperatorsCUDA)": 73.85466639200847,
"test_vmapjvpvjp_max_pool2d_with_indices_backward_cuda_float32 (__main__.TestOperatorsCUDA)": 98.39650090535481,
"test_vmapjvpvjp_nn_functional_conv2d_cpu_float32 (__main__.TestOperatorsCPU)": 61.39695285615467,
"test_vmapjvpvjp_nn_functional_max_pool2d_cuda_float32 (__main__.TestOperatorsCUDA)": 77.88249842325847,
"test_vmapjvpvjp_svd_cuda_float32 (__main__.TestOperatorsCUDA)": 73.0695006052653,
"test_vmapjvpvjp_unbind_cuda_float32 (__main__.TestOperatorsCUDA)": 81.86250114440918,
"test_vmapvjpvjp_meshgrid_list_of_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 98.63116455078125,
"test_vmapvjpvjp_meshgrid_variadic_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 94.85683314005534,
"test_vmapvjpvjp_nn_functional_bilinear_cuda_float32 (__main__.TestOperatorsCUDA)": 173.00183614095053
"EndToEndLSTM (__main__.RNNTest)": 190.48799641927084,
"MultiheadAttention (__main__.ModulesTest)": 141.2663370768229,
"test__adaptive_avg_pool2d (__main__.CPUReproTests)": 82.87333234151204,
"test_after_aot_cpu_runtime_error (__main__.MinifierIsolateTests)": 70.6538565499442,
"test_aot_autograd_disable_functionalization_symbolic_exhaustive_nn_functional_max_pool1d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 123.34033711751302,
"test_aot_autograd_disable_functionalization_symbolic_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 171.25450134277344,
"test_aot_autograd_disable_functionalization_symbolic_exhaustive_nn_functional_max_pool3d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 119.71899922688802,
"test_aot_autograd_disable_functionalization_symbolic_exhaustive_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 69.35733322870163,
"test_aot_autograd_symbolic_exhaustive_linalg_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 63.64533233642578,
"test_aot_autograd_symbolic_exhaustive_masked_norm_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 63.672952016194664,
"test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool1d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 138.04000091552734,
"test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 172.1344985961914,
"test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool3d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 114.02050018310547,
"test_aot_autograd_symbolic_exhaustive_ormqr_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 67.25642830984933,
"test_aot_autograd_symbolic_exhaustive_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 65.3350003560384,
"test_aot_autograd_symbolic_module_exhaustive_nn_TransformerDecoderLayer_cpu_float32 (__main__.TestEagerFusionModuleInfoCPU)": 120.95249938964844,
"test_associative_scan_partial_grad_combine_mode_generic_compile_mode_compile_dynamic_shape_reverse_False_cpu (__main__.AssociativeScanTests)": 86.97774887084961,
"test_associative_scan_partial_grad_combine_mode_generic_compile_mode_compile_dynamic_shape_reverse_True_cpu (__main__.AssociativeScanTests)": 100.90774917602539,
"test_avg_pool3d_backward2_cpu (__main__.CpuTests)": 1144.3935089111328,
"test_avg_pool3d_backward2_cuda (__main__.GPUTests)": 222.58500061035156,
"test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 501.10033162434894,
"test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 517.1875050862631,
"test_avg_pool3d_backward2_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 113.88125228881836,
"test_avg_pool3d_backward2_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 235.77350616455078,
"test_backward_nn_functional_multi_head_attention_forward_cpu_float32 (__main__.TestCompositeComplianceCPU)": 74.6155014038086,
"test_backward_nn_functional_multi_head_attention_forward_cuda_float32 (__main__.TestCompositeComplianceCUDA)": 66.63325119018555,
"test_basic_cpu (__main__.EfficientConvBNEvalCpuTests)": 216.2968317667643,
"test_basic_cuda (__main__.EfficientConvBNEvalGpuTests)": 153.0915012359619,
"test_cat_2k_args (__main__.TestTEFuserDynamic)": 108.80471753561869,
"test_cat_2k_args (__main__.TestTEFuserStatic)": 102.20949847949669,
"test_checkpointing_without_reentrant_input_requires_grad_False (__main__.TestAutogradWithCompiledAutograd)": 311.7026621500651,
"test_checkpointing_without_reentrant_input_requires_grad_True (__main__.TestAutogradWithCompiledAutograd)": 395.0001729329427,
"test_collect_callgrind (__main__.TestBenchmarkUtils)": 348.6218566894531,
"test_comprehensive_diff_cuda_complex128 (__main__.TestDecompCUDA)": 98.71574974060059,
"test_comprehensive_diff_cuda_complex64 (__main__.TestDecompCUDA)": 97.68499946594238,
"test_comprehensive_diff_cuda_float32 (__main__.TestDecompCUDA)": 65.0557508468628,
"test_comprehensive_diff_cuda_float64 (__main__.TestDecompCUDA)": 65.86899948120117,
"test_comprehensive_gradient_cuda_complex64 (__main__.TestDecompCUDA)": 97.15880012512207,
"test_comprehensive_grid_sampler_2d_cpu_bfloat16 (__main__.TestDecompCPU)": 103.20700073242188,
"test_comprehensive_grid_sampler_2d_cpu_float16 (__main__.TestDecompCPU)": 102.74033610026042,
"test_comprehensive_grid_sampler_2d_cpu_float32 (__main__.TestDecompCPU)": 460.4286702473958,
"test_comprehensive_grid_sampler_2d_cpu_float64 (__main__.TestDecompCPU)": 435.62066650390625,
"test_comprehensive_grid_sampler_2d_cuda_bfloat16 (__main__.TestDecompCUDA)": 287.3090057373047,
"test_comprehensive_grid_sampler_2d_cuda_float16 (__main__.TestDecompCUDA)": 265.1860008239746,
"test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestDecompCUDA)": 1235.7365112304688,
"test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 68.20825004577637,
"test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestDecompCUDA)": 1281.2615051269531,
"test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 71.90750026702881,
"test_comprehensive_linalg_householder_product_cuda_complex64 (__main__.TestDecompCUDA)": 79.04633331298828,
"test_comprehensive_linalg_lu_factor_ex_cuda_complex128 (__main__.TestDecompCUDA)": 68.10879821777344,
"test_comprehensive_linalg_lu_solve_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 71.43025207519531,
"test_comprehensive_linalg_lu_solve_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 68.94575023651123,
"test_comprehensive_linalg_solve_triangular_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 72.93649864196777,
"test_comprehensive_linalg_solve_triangular_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 72.46275043487549,
"test_comprehensive_linalg_svd_cuda_complex128 (__main__.TestDecompCUDA)": 64.10650062561035,
"test_comprehensive_linalg_svd_cuda_complex64 (__main__.TestDecompCUDA)": 67.03124904632568,
"test_comprehensive_linalg_svd_cuda_float64 (__main__.TestDecompCUDA)": 64.32800025939942,
"test_comprehensive_linalg_vector_norm_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 96.41353665865384,
"test_comprehensive_linalg_vector_norm_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 100.17661388103778,
"test_comprehensive_masked_norm_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 110.95025062561035,
"test_comprehensive_masked_norm_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 108.06550025939941,
"test_comprehensive_masked_norm_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 104.24150085449219,
"test_comprehensive_nn_functional_conv_transpose3d_cuda_complex128 (__main__.TestDecompCUDA)": 63.453749656677246,
"test_comprehensive_nn_functional_conv_transpose3d_cuda_complex64 (__main__.TestDecompCUDA)": 61.739999771118164,
"test_comprehensive_nn_functional_gaussian_nll_loss_cpu_float32 (__main__.TestDecompCPU)": 69.96549987792969,
"test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestDecompCUDA)": 113.65749931335449,
"test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float64 (__main__.TestDecompCUDA)": 106.57500076293945,
"test_comprehensive_nn_functional_grid_sample_cpu_float32 (__main__.TestDecompCPU)": 117.54049682617188,
"test_comprehensive_nn_functional_grid_sample_cpu_float64 (__main__.TestDecompCPU)": 116.19766489664714,
"test_comprehensive_nn_functional_grid_sample_cuda_float32 (__main__.TestDecompCUDA)": 272.48475646972656,
"test_comprehensive_nn_functional_grid_sample_cuda_float64 (__main__.TestDecompCUDA)": 248.12175369262695,
"test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 (__main__.TestDecompCUDA)": 79.66900062561035,
"test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 81.52649879455566,
"test_comprehensive_nn_functional_interpolate_bicubic_cuda_float64 (__main__.TestDecompCUDA)": 79.29400062561035,
"test_comprehensive_nn_functional_interpolate_bicubic_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 82.40349960327148,
"test_comprehensive_nn_functional_interpolate_trilinear_cuda_float32 (__main__.TestDecompCUDA)": 128.42924880981445,
"test_comprehensive_nn_functional_interpolate_trilinear_cuda_float64 (__main__.TestDecompCUDA)": 125.03675079345703,
||||
"test_comprehensive_nn_functional_max_pool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 1264.9732360839844,
|
||||
"test_comprehensive_nn_functional_max_pool2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 1250.7332458496094,
|
||||
"test_comprehensive_nn_functional_max_pool2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 1255.0684814453125,
|
||||
"test_comprehensive_nn_functional_max_pool3d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 574.4627532958984,
|
||||
"test_comprehensive_nn_functional_max_pool3d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 581.7282485961914,
|
||||
"test_comprehensive_nn_functional_max_unpool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 65.052001953125,
|
||||
"test_comprehensive_nn_functional_max_unpool2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 61.19200134277344,
|
||||
"test_comprehensive_nn_functional_max_unpool2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 63.16874885559082,
|
||||
"test_comprehensive_ormqr_cpu_complex64 (__main__.TestDecompCPU)": 62.39250183105469,
|
||||
"test_comprehensive_ormqr_cuda_complex128 (__main__.TestDecompCUDA)": 113.32574844360352,
|
||||
"test_comprehensive_ormqr_cuda_complex64 (__main__.TestDecompCUDA)": 113.91499900817871,
|
||||
"test_comprehensive_ormqr_cuda_float32 (__main__.TestDecompCUDA)": 74.42549800872803,
|
||||
"test_comprehensive_ormqr_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 76.1560001373291,
|
||||
"test_comprehensive_ormqr_cuda_float64 (__main__.TestDecompCUDA)": 66.76750087738037,
|
||||
"test_comprehensive_svd_cuda_complex128 (__main__.TestDecompCUDA)": 70.69724941253662,
|
||||
"test_comprehensive_svd_cuda_complex64 (__main__.TestDecompCUDA)": 69.87625026702881,
|
||||
"test_constructor_autograd_SparseBSC_cuda (__main__.TestSparseAnyCUDA)": 80.2542495727539,
|
||||
"test_constructor_autograd_SparseBSR_cuda (__main__.TestSparseAnyCUDA)": 69.0419979095459,
|
||||
"test_conv1d_basic (__main__.TestXNNPACKConv1dTransformPass)": 117.03342655726841,
|
||||
"test_conv1d_with_relu_fc (__main__.TestXNNPACKConv1dTransformPass)": 289.50213841029574,
|
||||
"test_conv2d_binary_broadcast_shapes_cpu (__main__.TestPatternMatcherGenericCPU)": 67.38800048828125,
|
||||
"test_conv3d_binary_broadcast_shapes_cpu (__main__.TestPatternMatcherGenericCPU)": 145.27399444580078,
|
||||
"test_conv3d_binary_dynamic_shapes_cpu (__main__.TestDynamicPatternMatcherGenericCPU)": 66.9245999654134,
|
||||
"test_conv3d_cuda (__main__.AOTInductorTestABICompatibleGpu)": 151.91099548339844,
|
||||
"test_conv_bn_fuse_cpu (__main__.CpuTests)": 92.79549789428711,
|
||||
"test_conv_bn_fuse_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 64.60149955749512,
|
||||
"test_conv_transpose_with_output_size_and_no_batch_dim_ConvTranspose3d_cuda (__main__.TestConvolutionNNDeviceTypeCUDA)": 69.27724676392972,
|
||||
"test_conv_unary_fusion_nnc (__main__.TestMkldnnFusion)": 76.24971498761859,
|
||||
"test_correctness_AdamW_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 81.93449974060059,
|
||||
"test_correctness_Adam_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 78.87700080871582,
|
||||
"test_count_nonzero_all (__main__.TestBool)": 631.2585144042969,
|
||||
"test_diff_hyperparams_sharding_strategy_str_full_shard (__main__.TestFSDPUseOrigParamsMultipleParamGroups)": 61.042999267578125,
|
||||
"test_dispatch_symbolic_meta_outplace_all_strides_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestMetaCUDA)": 84.49850082397461,
|
||||
"test_dtensor_op_db_nn_functional_poisson_nll_loss_cpu_float32 (__main__.TestLocalDTensorOpsCPU)": 93.03299713134766,
|
||||
"test_eager_sequence_nr_dynamic_shapes (__main__.DynamicShapesAotAutogradFallbackTests)": 228.46711820714614,
|
||||
"test_eig_check_magma_cuda_float32 (__main__.TestLinalgCUDA)": 286.29998779296875,
|
||||
"test_fail_arithmetic_ops.py (__main__.TestTyping)": 68.43842806134906,
|
||||
"test_fail_random.py (__main__.TestTyping)": 74.83523060725285,
|
||||
"test_fn_fwgrad_bwgrad_cumprod_cuda_complex128 (__main__.TestFwdGradientsCUDA)": 72.84900093078613,
|
||||
"test_fn_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 75.86675071716309,
|
||||
"test_fuse_large_params_cpu (__main__.CpuTests)": 151.4199981689453,
|
||||
"test_fuse_large_params_cuda (__main__.GPUTests)": 60.351999282836914,
|
||||
"test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 158.3622828892299,
|
||||
"test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 149.6796646118164,
|
||||
"test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 139.97800064086914,
|
||||
"test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 114.8385009765625,
|
||||
"test_grad_nn_Transformer_cpu_float64 (__main__.TestModuleCPU)": 84.69736822027909,
|
||||
"test_grad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 84.62700080871582,
|
||||
"test_gradgrad_nn_LSTM_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 89.197998046875,
|
||||
"test_gradgrad_nn_LSTM_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 96.46900177001953,
|
||||
"test_gradgrad_nn_TransformerDecoderLayer_cuda_float64 (__main__.TestModuleCUDA)": 187.83824920654297,
|
||||
"test_gradgrad_nn_TransformerEncoder_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 110.49449920654297,
|
||||
"test_gradgrad_nn_TransformerEncoder_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 124.90424919128418,
|
||||
"test_gradgrad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 518.4157485961914,
|
||||
"test_indirect_device_assert (__main__.TritonCodeGenTests)": 304.6440022786458,
|
||||
"test_inductor_dynamic_shapes_broadcasting_dynamic_shapes (__main__.DynamicShapesReproTests)": 143.82052836698645,
|
||||
"test_inductor_no_recursionerror_on_for_loops_dynamic_shapes (__main__.DynamicShapesReproTests)": 77.4985705784389,
|
||||
"test_inplace_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 76.06225109100342,
|
||||
"test_inputs_overlapping_with_mutation_stress_dynamic_shapes (__main__.DynamicShapesAotAutogradFallbackTests)": 138.9222858973912,
|
||||
"test_jit_cuda_archflags (__main__.TestCppExtensionJIT)": 120.62233225504558,
|
||||
"test_linalg_solve_triangular_large_cuda_complex128 (__main__.TestLinalgCUDA)": 148.1219940185547,
|
||||
"test_linalg_solve_triangular_large_cuda_complex64 (__main__.TestLinalgCUDA)": 109.34200286865234,
|
||||
"test_linear_binary_cpp_wrapper (__main__.TestCppWrapper)": 119.36233266194661,
|
||||
"test_linear_binary_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 127.95700073242188,
|
||||
"test_list_clearing_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 61.64850175380707,
|
||||
"test_longformer_chunk_dynamic_shapes (__main__.DynamicShapesReproTests)": 105.3174296787807,
|
||||
"test_low_memory_max_pool_dilation_1_dim_3_cpu_halide (__main__.HalideCpuTests)": 585.9210001627604,
|
||||
"test_low_memory_max_pool_dilation_2_dim_3_cpu_halide (__main__.HalideCpuTests)": 504.3250020345052,
|
||||
"test_lstm_cpu (__main__.TestMkldnnCPU)": 86.21566645304362,
|
||||
"test_many_overlapping_inputs_does_not_explode_guards_dynamic_shapes (__main__.DynamicShapesReproTests)": 129.277715410505,
|
||||
"test_max_autotune_addmm_max_autotune_gemm_backends_CK_x_shape2 (__main__.TestCKBackend)": 64.24800109863281,
|
||||
"test_max_autotune_precompile_matmul_max_autotune_gemm_backends_CKTILE_autotune_in_subproc_False_use_aoti_False (__main__.TestCKBackend)": 77.23899841308594,
|
||||
"test_max_autotune_precompile_matmul_max_autotune_gemm_backends_CKTILE_autotune_in_subproc_False_use_aoti_True (__main__.TestCKBackend)": 65.15649795532227,
|
||||
"test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 62.579833984375,
|
||||
"test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 64.6555004119873,
|
||||
"test_pattern_matcher_multi_user_cpu (__main__.CpuTritonTests)": 142.21566772460938,
|
||||
"test_proper_exit (__main__.TestDataLoader)": 267.74214717320035,
|
||||
"test_proper_exit (__main__.TestDataLoaderPersistentWorkers)": 266.6539971487863,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 101.97100067138672,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 97.3346659342448,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True (__main__.TestPatternMatcher)": 81.50300216674805,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 104.61333465576172,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 99.41133371988933,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False (__main__.TestPatternMatcher)": 73.37100219726562,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 95.30900065104167,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 96.61750030517578,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True (__main__.TestPatternMatcher)": 79.33600234985352,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 101.2393315633138,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 103.18400192260742,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False (__main__.TestPatternMatcher)": 75.4114990234375,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 96.52833302815755,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 99.72700119018555,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 100.61966705322266,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 102.2750015258789,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 95.17449951171875,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 97.96749877929688,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 106.44049835205078,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 101.7173334757487,
|
||||
"test_quick_core_backward__unsafe_masked_index_cpu_float64 (__main__.TestDecompCPU)": 531.5236612955729,
|
||||
"test_quick_core_backward__unsafe_masked_index_cuda_float64 (__main__.TestDecompCUDA)": 1077.4210205078125,
|
||||
"test_quick_core_backward__unsafe_masked_index_put_accumulate_cpu_float64 (__main__.TestDecompCPU)": 812.0880126953125,
|
||||
"test_quick_core_backward__unsafe_masked_index_put_accumulate_cuda_float64 (__main__.TestDecompCUDA)": 1347.9365234375,
|
||||
"test_quick_core_backward_nn_functional_max_unpool3d_grad_cpu_float64 (__main__.TestDecompCPU)": 88.93533070882161,
|
||||
"test_quick_core_backward_nn_functional_max_unpool3d_grad_cuda_float64 (__main__.TestDecompCUDA)": 269.01949310302734,
|
||||
"test_quick_core_backward_roll_cpu_float64 (__main__.TestDecompCPU)": 131.99799601236978,
|
||||
"test_quick_core_backward_roll_cuda_float64 (__main__.TestDecompCUDA)": 232.36275100708008,
|
||||
"test_quick_core_backward_select_scatter_cpu_float64 (__main__.TestDecompCPU)": 69.80400085449219,
|
||||
"test_quick_core_backward_select_scatter_cuda_float64 (__main__.TestDecompCUDA)": 134.3415012359619,
|
||||
"test_quick_core_backward_split_cuda_float64 (__main__.TestDecompCUDA)": 67.51749992370605,
|
||||
"test_quick_core_backward_split_with_sizes_copy_cpu_float64 (__main__.TestDecompCPU)": 91.21066792805989,
|
||||
"test_quick_core_backward_split_with_sizes_copy_cuda_float64 (__main__.TestDecompCUDA)": 170.97775268554688,
|
||||
"test_quick_core_backward_std_cpu_float64 (__main__.TestDecompCPU)": 61.608266321818036,
|
||||
"test_quick_core_backward_std_cuda_float64 (__main__.TestDecompCUDA)": 110.62575149536133,
|
||||
"test_register_spills_cuda (__main__.BenchmarkFusionGpuTest)": 63.59499969482422,
|
||||
"test_replicatepad_64bit_indexing_cuda_float16 (__main__.TestNNDeviceTypeCUDA)": 88.68299865722656,
|
||||
"test_rnn_decomp_module_nn_LSTM_train_mode_cuda_float32 (__main__.TestDecompCUDA)": 91.50320053100586,
|
||||
"test_runtime_checks_large_cpu (__main__.AOTInductorTestABICompatibleCpu)": 66.10774898529053,
|
||||
"test_runtime_checks_large_cpu_with_stack_allocation (__main__.AOTInductorTestABICompatibleCpuWithStackAllocation)": 66.20533180236816,
|
||||
"test_runtime_checks_large_cuda (__main__.AOTInductorTestABICompatibleGpu)": 243.1092529296875,
|
||||
"test_save_load_large_string_attribute (__main__.TestSaveLoad)": 105.01200103759766,
|
||||
"test_sdpa_kernel_ctx_manager2_dynamic_shapes (__main__.DynamicShapesCtxManagerTests)": 107.93685695103237,
|
||||
"test_shuffler_iterdatapipe (__main__.IntegrationTestDataLoaderDataPipe)": 142.38899993896484,
|
||||
"test_slow_tasks (__main__.TestFunctionalAutogradBenchmark)": 119.90166600545247,
|
||||
"test_sort_bool_cpu (__main__.CpuTritonTests)": 346.2856750488281,
|
||||
"test_sort_dynamic_shape_with_check_cuda (__main__.TestInductorDynamicCUDA)": 423.09974098205566,
|
||||
"test_sort_stable_cuda (__main__.GPUTests)": 117.61659927368164,
|
||||
"test_sort_transpose_cpu (__main__.CpuTritonTests)": 378.31200154622394,
|
||||
"test_svd_lowrank_cuda_complex128 (__main__.TestLinalgCUDA)": 222.822007894516,
|
||||
"test_terminate_handler_on_crash (__main__.TestTorch)": 143.31728431156702,
|
||||
"test_terminate_signal (__main__.ForkTest)": 168.20485967184817,
|
||||
"test_terminate_signal (__main__.ParallelForkServerShouldWorkTest)": 168.19242484867573,
|
||||
"test_terminate_signal (__main__.SpawnTest)": 172.16428443363733,
|
||||
"test_thnn_conv_strided_padded_dilated (__main__.TestConvolutionNN)": 93.30639710426331,
|
||||
"test_train_parity_multi_group (__main__.TestFullyShard1DTrainingCore)": 163.89743041992188,
|
||||
"test_train_parity_with_activation_checkpointing (__main__.TestFullyShard1DTrainingCompose)": 60.47671399797712,
|
||||
"test_triton_bsr_scatter_mm_blocksize_64_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 63.39550018310547,
|
||||
"test_triton_bsr_softmax_cuda_bfloat16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 173.53924942016602,
|
||||
"test_triton_bsr_softmax_cuda_float16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 175.3212537765503,
|
||||
"test_triton_bsr_softmax_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 122.20649909973145,
|
||||
"test_variant_consistency_jit_nn_functional_max_pool2d_cpu_float32 (__main__.TestJitCPU)": 99.9885025024414,
|
||||
"test_variant_consistency_jit_nn_functional_max_pool2d_cuda_float32 (__main__.TestJitCUDA)": 71.64024829864502,
|
||||
"test_view_ops (__main__.TestViewOpsWithLocalTensor)": 73.45887422561646,
|
||||
"test_vmapjvpvjp_linalg_lstsq_grad_oriented_cpu_float32 (__main__.TestOperatorsCPU)": 95.75249862670898,
|
||||
"test_vmapjvpvjp_linalg_lstsq_grad_oriented_cuda_float32 (__main__.TestOperatorsCUDA)": 61.858001708984375,
|
||||
"test_vmapjvpvjp_linalg_lu_solve_cpu_float32 (__main__.TestOperatorsCPU)": 65.11023766653878,
|
||||
"test_vmapjvpvjp_linalg_lu_solve_cuda_float32 (__main__.TestOperatorsCUDA)": 66.35274982452393,
|
||||
"test_vmapjvpvjp_linalg_svd_cuda_float32 (__main__.TestOperatorsCUDA)": 61.196499824523926,
|
||||
"test_vmapjvpvjp_max_pool2d_with_indices_backward_cpu_float32 (__main__.TestOperatorsCPU)": 73.75380906604585,
|
||||
"test_vmapjvpvjp_max_pool2d_with_indices_backward_cuda_float32 (__main__.TestOperatorsCUDA)": 73.64649868011475,
|
||||
"test_vmapjvpvjp_nn_functional_max_pool2d_cpu_float32 (__main__.TestOperatorsCPU)": 75.09799966358003,
|
||||
"test_vmapjvpvjp_nn_functional_max_pool2d_cuda_float32 (__main__.TestOperatorsCUDA)": 70.51450157165527,
|
||||
"test_vmapjvpvjp_unbind_cpu_float32 (__main__.TestOperatorsCPU)": 66.21433276221866,
|
||||
"test_vmapjvpvjp_unbind_cuda_float32 (__main__.TestOperatorsCUDA)": 73.20024871826172,
|
||||
"test_vmapvjpvjp_linalg_lstsq_cuda_float32 (__main__.TestOperatorsCUDA)": 88.1349983215332,
|
||||
"test_vmapvjpvjp_meshgrid_list_of_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 76.89924907684326,
|
||||
"test_vmapvjpvjp_meshgrid_variadic_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 77.32975196838379,
|
||||
"test_vmapvjpvjp_nn_functional_bilinear_cuda_float32 (__main__.TestOperatorsCUDA)": 120.09600067138672
|
||||
}
|
||||
@@ -239,6 +239,12 @@ class TestAccelerator(TestCase):
         self.assertEqual(torch.accelerator.max_memory_allocated(), prev_max_allocated)
         self.assertEqual(torch.accelerator.max_memory_reserved(), prev_max_reserved)

+    @unittest.skipIf(TEST_MPS, "MPS doesn't support torch.accelerator memory API!")
+    def test_get_memory_info(self):
+        free_bytes, total_bytes = torch.accelerator.get_memory_info()
+        self.assertGreaterEqual(free_bytes, 0)
+        self.assertGreaterEqual(total_bytes, 0)
+

 if __name__ == "__main__":
     run_tests()

@@ -313,15 +313,17 @@ class SerializationMixin:
     def test_serialization_gzip(self):
         # Test serialization with gzip file
         b = self._test_serialization_data()
-        f1 = tempfile.NamedTemporaryFile(delete=False)
-        f2 = tempfile.NamedTemporaryFile(delete=False)
-        torch.save(b, f1)
-        with open(f1.name, 'rb') as f_in, gzip.open(f2.name, 'wb') as f_out:
-            shutil.copyfileobj(f_in, f_out)
+        with tempfile.NamedTemporaryFile() as f1, tempfile.NamedTemporaryFile(delete=False) as f2:
+            torch.save(b, f1)
+            f1.seek(0)
+            with gzip.open(f2.name, 'wb') as f_out:
+                shutil.copyfileobj(f1, f_out)

-        with gzip.open(f2.name, 'rb') as f:
-            c = torch.load(f)
-        self._test_serialization_assert(b, c)
+            with gzip.open(f2.name, 'rb') as f:
+                c = torch.load(f)
+            self._test_serialization_assert(b, c)
+        f2.close()
+        os.unlink(f2.name)

     @unittest.skipIf(
         not TEST_DILL or HAS_DILL_AT_LEAST_0_3_1,
@@ -382,19 +384,19 @@ class SerializationMixin:
     def test_serialization_offset_gzip(self):
        a = torch.randn(5, 5)
        i = 41
-        f1 = tempfile.NamedTemporaryFile(delete=False)
         f2 = tempfile.NamedTemporaryFile(delete=False)
-        with open(f1.name, 'wb') as f:
-            pickle.dump(i, f)
-            torch.save(a, f)
-        with open(f1.name, 'rb') as f_in, gzip.open(f2.name, 'wb') as f_out:
-            shutil.copyfileobj(f_in, f_out)
+        with tempfile.NamedTemporaryFile() as f1:
+            pickle.dump(i, f1)
+            torch.save(a, f1)
+            f1.seek(0)
+            with gzip.open(f2.name, 'wb') as f_out:
+                shutil.copyfileobj(f1, f_out)

-        with gzip.open(f2.name, 'rb') as f:
-            j = pickle.load(f)
-            b = torch.load(f)
-        self.assertTrue(torch.equal(a, b))
-        self.assertEqual(i, j)
+            with gzip.open(f2.name, 'rb') as f:
+                j = pickle.load(f)
+                b = torch.load(f)
+            self.assertTrue(torch.equal(a, b))
+            self.assertEqual(i, j)

     def _test_serialization_sparse(self, weights_only):
         def _test_serialization(conversion):
@@ -3728,7 +3728,6 @@ class TestSparse(TestSparseBase):
     @coalescedonoff
     @dtypes(*floating_and_complex_types())
     @dtypesIfMPS(*all_mps_types())
-    @expectedFailureMPS
     @dtypesIfCUDA(*floating_types_and(*[torch.half] if SM53OrLater and not TEST_WITH_ROCM else [],
                                       *[torch.bfloat16] if SM80OrLater and not TEST_WITH_ROCM else [],
                                       torch.complex64,

@@ -3825,9 +3824,9 @@
         def different_dtypes():
             a, i_a, v_a = self._gen_sparse(2, 10, [2, 2], dtype, device, coalesced)
             b, i_b, v_b = self._gen_sparse(2, 10, [2, 2], dtype, device, coalesced)
-            r2 = torch.sparse.mm(a.to(torch.float64), a.to(torch.float32))
+            r2 = torch.sparse.mm(a.to(torch.float32), a.to(torch.float16))

-        self.assertRaisesRegex(RuntimeError, 'mat1 dtype Double does not match mat2 dtype Float', different_dtypes)
+        self.assertRaisesRegex(RuntimeError, 'mat1 dtype Float does not match mat2 dtype Half', different_dtypes)

         def test_backward_noncontiguous():
             # Sparse.mm backward used to be wrong with non-contiguous grads,
@@ -206,7 +206,8 @@ if __name__ == "__main__":
     test_multi_process(model, input)
     print(torch.xpu.device_count())
 """
-        rc = check_output(test_script)
+        # XPU output has extra lines, so take the last one; see https://github.com/intel/torch-xpu-ops/issues/2261
+        rc = check_output(test_script).splitlines()[-1]
         self.assertEqual(rc, str(torch.xpu.device_count()))

     def test_streams(self):
@@ -3,9 +3,9 @@
 # mypy: allow-untyped-defs
 # ruff: noqa: F401,PYI054

-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 from types import EllipsisType
-from typing import Any, Callable, Literal, overload, TypeVar
+from typing import Any, Literal, overload, TypeVar

 import torch
 from torch import (
@@ -2491,6 +2491,7 @@ def _accelerator_emptyCache() -> None: ...
 def _accelerator_getDeviceStats(device_index: _int) -> dict[str, Any]: ...
 def _accelerator_resetAccumulatedStats(device_index: _int) -> None: ...
 def _accelerator_resetPeakStats(device_index: _int) -> None: ...
+def _accelerator_getMemoryInfo(device_index: _int) -> tuple[_int, _int]: ...
 def _accelerator_setAllocatorSettings(env: str) -> None: ...

 # Defined in torch/csrc/jit/python/python_tracer.cpp
@@ -22,8 +22,8 @@ import inspect
 import sys
 import warnings
 from collections.abc import Callable, Sequence, Sized
-from contextlib import ExitStack
-from typing import Any, ContextManager, Optional, TYPE_CHECKING, Union
+from contextlib import AbstractContextManager, ExitStack
+from typing import Any, Optional, TYPE_CHECKING, Union

 import torch._C
 from torch._guards import Guard
@@ -163,7 +163,7 @@ class ContextWrappingVariable(VariableTracker):
 class GenericContextWrappingVariable(UserDefinedObjectVariable):
     # Some methods in ContextWrappingVariable assume the arguments are
     # Python constants, which might not always be the case here.
-    def __init__(self, cm_obj: ContextManager[Any], **kwargs: Any) -> None:
+    def __init__(self, cm_obj: AbstractContextManager[Any], **kwargs: Any) -> None:
         assert cm_obj is not None
         super().__init__(
             value=cm_obj,
@@ -18,7 +18,8 @@ import collections
 import inspect
 import operator
 import sys
-from typing import Any, Optional, Sequence, TYPE_CHECKING
+from collections.abc import Sequence
+from typing import Any, Optional, TYPE_CHECKING

 import torch
 import torch.fx
@@ -2,7 +2,7 @@
 # fmt: off
 # This file was generated by AutoHeuristic. Do not modify it manually!
 # To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/mixed_mm/
-from typing import List, Optional, Tuple
+from typing import Optional

 from torch._inductor.autoheuristic.autoheuristic_utils import (
     AHContext,
@@ -550,6 +550,10 @@ max_autotune_flex_search_space: Literal["DEFAULT", "EXHAUSTIVE"] = os.environ.ge
     "TORCHINDUCTOR_MAX_AUTOTUNE_FLEX_SEARCH_SPACE", "DEFAULT"
 ).upper()  # type: ignore[assignment]

+cutedsl_enable_autotuning: bool = (
+    os.environ.get("CUTEDSL_ENABLE_AUTOTUNING", "0") == "1"
+)

 # DEPRECATED. This setting is ignored.
 autotune_fallback_to_aten = False
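A minimal usage sketch (an editorial assumption, not part of the diff): since the flag is read from the environment once at import time, it has to be set before torch._inductor.config is first imported.

    import os
    os.environ["CUTEDSL_ENABLE_AUTOTUNING"] = "1"  # set before the first inductor import

    import torch._inductor.config as inductor_config
    assert inductor_config.cutedsl_enable_autotuning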
@@ -1,6 +1,8 @@
 # mypy: allow-untyped-defs
 import logging
 from collections.abc import Sequence
+from functools import partial
+from pathlib import Path
 from typing import Any

 import torch
@@ -12,6 +14,7 @@ from torch.fx.experimental.symbolic_shapes import has_free_unbacked_symbols
 from .. import config
 from ..codegen.wrapper import PythonWrapperCodegen
 from ..ir import _IntLike, Layout, TensorBox
+from ..utils import load_template


 log = logging.getLogger(__name__)
@@ -254,3 +257,7 @@ def is_batch_stride_largest_or_zero(mat1, mat2, layout) -> bool:
         return False

     return True


+_KERNEL_TEMPLATE_DIR = Path(__file__).parent / "templates"
+load_kernel_template = partial(load_template, template_dir=_KERNEL_TEMPLATE_DIR)
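As a sketch of how this helper is used later in the diff, the grouped-MM lowering resolves its Jinja source by file name relative to the new templates directory:

    # Loads torch/_inductor/kernel/templates/cutedsl_mm_grouped.py.jinja
    source = load_kernel_template("cutedsl_mm_grouped")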
@@ -1,11 +1,13 @@
 # mypy: allow-untyped-defs
 import logging
-from dataclasses import dataclass
+from dataclasses import asdict, dataclass
 from typing import Any, Optional

 import torch
 from torch._dynamo.utils import counters
+from torch._inductor.codegen.cutedsl.cutedsl_template import CuteDSLTemplate
 from torch._inductor.runtime.triton_compat import tl
+from torch._inductor.template_heuristics.cutedsl import get_groupgemm_configs
 from torch._inductor.virtualized import V
 from torch.utils._triton import has_triton
@@ -22,11 +24,13 @@ from ..utils import (
     get_num_sms,
     has_free_symbols,
     use_aten_gemm_kernels,
+    use_blackwell_cutedsl_grouped_mm,
     use_triton_template,
 )
 from .mm_common import (
     _is_static_problem,
     check_supported_striding,
+    load_kernel_template,
     persistent_grouped_mm_grid,
 )
@@ -513,6 +517,11 @@ triton_scaled_grouped_mm_template = TritonTemplate(
     source=triton_grouped_mm_source,
 )

+cutedsl_grouped_mm_template = CuteDSLTemplate(
+    name="grouped_gemm_cutedsl",
+    source=load_kernel_template("cutedsl_mm_grouped"),
+)
+

 def grouped_mm_args(
     mat1: TensorBox,
@@ -714,43 +723,44 @@ def _tuned_grouped_mm_common(
     # Checking only for the equality of corresponding dims of
     # multiplicands here, relying on meta function checks for
     # everything else.
+    if len(m1_size) == 2:
+        if len(m2_size) == 2:
+            m, k1 = m1_size
+            k2, _ = m2_size
+            # pyrefly: ignore [missing-attribute]
+            g = offs.get_size()[0]
+            V.graph.sizevars.check_equals(k1, k2)
+            a_is_2d, b_is_2d = True, True
+        else:
+            # pyrefly: ignore [missing-attribute]
+            g1 = offs.layout.size[0]
+            m, k1 = m1_size
+            g2, k2, _ = m2_size
+            g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
+            V.graph.sizevars.check_equals(k1, k2)
+            a_is_2d, b_is_2d = True, False
+    else:
+        if len(m2_size) == 2:
+            # pyrefly: ignore [missing-attribute]
+            g1 = offs.layout.size[0]
+            g2, m, k1 = m1_size
+            k2, _ = m2_size
+            g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
+            V.graph.sizevars.check_equals(k1, k2)
+            a_is_2d, b_is_2d = False, True
+        else:
+            g1, m, k1 = m1_size
+            g2, k2, _ = m2_size
+            g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
+            V.graph.sizevars.check_equals(k1, k2)
+            a_is_2d, b_is_2d = False, False

     if (
         is_nonzero
         and use_triton_template(layout)
         and can_use_triton_kernel(mat_a, mat_b, offs, bias, scale_result)
     ):
         scaled = scale_a is not None
-        if len(m1_size) == 2:
-            if len(m2_size) == 2:
-                m, k1 = m1_size
-                k2, _ = m2_size
-                # pyrefly: ignore [missing-attribute]
-                g = offs.get_size()[0]
-                V.graph.sizevars.check_equals(k1, k2)
-                a_is_2d, b_is_2d = True, True
-            else:
-                # pyrefly: ignore [missing-attribute]
-                g1 = offs.layout.size[0]
-                m, k1 = m1_size
-                g2, k2, _ = m2_size
-                g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
-                V.graph.sizevars.check_equals(k1, k2)
-                a_is_2d, b_is_2d = True, False
-        else:
-            if len(m2_size) == 2:
-                # pyrefly: ignore [missing-attribute]
-                g1 = offs.layout.size[0]
-                g2, m, k1 = m1_size
-                k2, _ = m2_size
-                g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
-                V.graph.sizevars.check_equals(k1, k2)
-                a_is_2d, b_is_2d = False, True
-            else:
-                g1, m, k1 = m1_size
-                g2, k2, _ = m2_size
-                g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
-                V.graph.sizevars.check_equals(k1, k2)
-                a_is_2d, b_is_2d = False, False

         a_is_k_major = mat_a.get_stride()[-1] == 1
         b_is_k_major = mat_b.get_stride()[-2] == 1
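An illustration of the shape cases above, with assumed example values (this lowering serves the private torch._grouped_mm op; the exact invocation is illustrative):

    import torch

    A = torch.randn(12, 16, device="cuda", dtype=torch.bfloat16)       # 2d: (sum_m, k1)
    B = torch.randn(3, 16, 8, device="cuda", dtype=torch.bfloat16)     # 3d: (g, k2, n)
    offs = torch.tensor([4, 9, 12], device="cuda", dtype=torch.int32)  # cumulative group ends
    # a_is_2d=True, b_is_2d=False: g comes from offs/B's leading dim, and k1 must equal k2.
    out = torch._grouped_mm(A, B, offs=offs)  # out shape: (12, 8)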
@@ -788,6 +798,22 @@ def _tuned_grouped_mm_common(
                 **config.kwargs,
             )

+    if use_blackwell_cutedsl_grouped_mm(
+        mat_a, mat_b, layout, a_is_2d, b_is_2d, offs, bias, scale_result
+    ):
+        for config in get_groupgemm_configs():
+            kwargs = dict(
+                ACC_DTYPE="cutlass.Float32",
+            )
+
+            cutedsl_grouped_mm_template.maybe_append_choice(
+                choices,
+                input_nodes=input_nodes,
+                layout=layout,
+                **kwargs,
+                **asdict(config),
+            )
+
     input_gen_fns = {
         4: lambda x: create_offsets(
             x, m1_size, m2_size, offs.get_size() if offs is not None else None
torch/_inductor/kernel/templates/cutedsl_mm_grouped.py.jinja (new file, 333 lines)
@@ -0,0 +1,333 @@
import functools
from torch._inductor.runtime.runtime_utils import ceildiv
from cutlass.utils import TensorMapUpdateMode
{{gen_defines()}}
# ---- Import GroupedGemm implementation, copied on PyTorch build from Cutlass repository: cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py ----
from torch._inductor.kernel.vendored_templates.cutedsl_grouped_gemm import (
    GroupedGemmKernel,
)


# Note about caching:
# Each instantiated CuTeDSL grouped GEMM kernel file generated by Inductor
# maintains its own local caching system. At this stage, all compile-time
# constexprs (e.g., TILE_M, TILE_N, CLUSTER_M/N, USE_2_CTA) and the kernel
# name itself ({{kernel_name}}) are permanently baked into the file, so they
# do not need to be included in any cache key.
#
# The caching mechanism is split into two levels:
#
# 1. prep_cache
#    Caches the compiled executor for build_group_ptrs_from_bases(). This
#    kernel depends only on the tensor shapes, strides, and dtypes of A/B/C,
#    and can therefore be safely reused across runs with different group
#    partitioning (`offs`).
#
# 2. gemm_cache
#    Caches the compiled Grouped GEMM executor. Its key extends the prep
#    cache key with hardware- and grid-specific parameters:
#    (prep_cache_key, max_active_clusters, total_num_clusters).
#    This is necessary because different `offs` tensors can change the
#    per-group problem sizes and thus alter `total_num_clusters`, which in
#    turn changes the grid shape and persistent scheduler configuration.
#    Kernels compiled for one grid cannot be safely reused for another.
#
# Additionally, note the @lru_cache decorator on get_hardware_info(). Empirically,
# hw.get_max_active_clusters() triggers significant MLIR recompilation overhead,
# despite depending only on the GPU type. We cache this function to mitigate
# redundant recompiles even when shape/stride/dtype cache misses force kernel
# regeneration. A follow-up study will investigate the root cause.

prep_cache = {}
gemm_cache = {}

@functools.lru_cache
def get_hardware_info():
    hw = cutlass.utils.HardwareInfo()
    sm_count = hw.get_max_active_clusters(1)
    max_active_clusters = hw.get_max_active_clusters(CLUSTER_M * CLUSTER_N)

    return (sm_count, max_active_clusters)


def get_prep_cache_key(input_a, input_b, output):
    """
    Returns a tuple key for caching the preprocessing kernel executor based on kernel name,
    shapes, strides, and dtypes of input/output tensors.
    """
    return (
        tuple(input_a.shape),
        tuple(input_a.stride()),
        input_a.dtype,
        tuple(input_b.shape),
        tuple(input_b.stride()),
        input_b.dtype,
        tuple(output.shape),
        tuple(output.stride()),
        output.dtype,
    )


def get_gemm_cache_key(prep_cache_key, max_active_clusters, total_num_clusters):
    """
    Returns a tuple key for caching the gemm kernel executor by extending the
    prep cache key with hardware- and grid-specific parameters.
    """
    return (
        prep_cache_key,
        max_active_clusters,
        total_num_clusters,
    )

@cute.kernel
def build_group_ptrs_from_bases_kernel(
    base_A_u64: cutlass.Int64,  # device addr of input_a (bytes)
    base_B_u64: cutlass.Int64,  # device addr of input_b (bytes)
    base_C_u64: cutlass.Int64,  # device addr of Output (bytes)
    offs: cute.Tensor,  # [G], cutlass.Int32/64 cumulative
    K: cutlass.Constexpr,
    N: cutlass.Constexpr,
    sizeof_element: cutlass.Int32,  # bytes
    # -------- STRIDES (in ELEMENTS) --------
    stride_A_m_elems: cutlass.Constexpr,  # A.stride(0)
    stride_A_k_elems: cutlass.Constexpr,  # A.stride(1)
    stride_B0_elems: cutlass.Constexpr,  # B.stride(0)
    stride_Bk_elems: cutlass.Constexpr,  # B.stride(1)
    stride_Bn_elems: cutlass.Constexpr,  # B.stride(2)
    stride_C_m_elems: cutlass.Constexpr,  # C.stride(0)
    stride_C_n_elems: cutlass.Constexpr,  # C.stride(1)
    # -------- OUTPUTS --------
    out_ptrs: cute.Tensor,  # [G,3] cutlass.Int64: (A_ptr, B_ptr, C_ptr)
    out_problem: cute.Tensor,  # [G,4] cutlass.Int32: (m_g, n, k, 1)
    out_strides_abc: cute.Tensor,  # [G,3,2] cutlass.Int32 [[A_m,A_k],[B_n,B_k],[C_m,C_n]]
):
    tidx, _, _ = cute.arch.thread_idx()
    g = tidx

    m_beg_i32 = 0
    if g > 0:
        m_beg_i32 = offs[g - 1]
    m_end_i32 = offs[g]
    m_g_i32 = m_end_i32 - m_beg_i32

    a_byte_off = (
        cutlass.Int64(m_beg_i32) * stride_A_m_elems * cutlass.Int64(sizeof_element)
    )
    c_byte_off = (
        cutlass.Int64(m_beg_i32) * stride_C_m_elems * cutlass.Int64(sizeof_element)
    )
    b_byte_off = cutlass.Int64(g) * stride_B0_elems * cutlass.Int64(sizeof_element)

    # ---- pointers ----
    out_ptrs[g, 0] = base_A_u64 + a_byte_off
    out_ptrs[g, 1] = base_B_u64 + b_byte_off
    out_ptrs[g, 2] = base_C_u64 + c_byte_off

    # ---- (m, n, k, 1) ----
    out_problem[g, 0] = m_g_i32
    out_problem[g, 1] = N
    out_problem[g, 2] = K
    out_problem[g, 3] = cutlass.Int32(1)

    # ---- strides ----
    out_strides_abc[g, 0, 0] = cutlass.Int32(stride_A_m_elems)
    out_strides_abc[g, 0, 1] = cutlass.Int32(stride_A_k_elems)
    out_strides_abc[g, 1, 0] = cutlass.Int32(stride_Bn_elems)
    out_strides_abc[g, 1, 1] = cutlass.Int32(stride_Bk_elems)
    out_strides_abc[g, 2, 0] = cutlass.Int32(stride_C_m_elems)
    out_strides_abc[g, 2, 1] = cutlass.Int32(stride_C_n_elems)
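# Illustration with assumed values (an editorial note, not emitted by the
# template): offs = [4, 9, 12] encodes G = 3 groups with row counts
# m_g = [4, 5, 3]; group g advances the A/C base pointers by offs[g - 1] rows
# and the B base pointer by g whole (k, n) matrices.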
@cute.jit
def launch_build_group_ptrs_from_bases(
    base_A_u64: cutlass.Int64,
    base_B_u64: cutlass.Int64,
    base_C_u64: cutlass.Int64,
    offs: cute.Tensor,
    G: cutlass.Constexpr,
    K: cutlass.Constexpr,
    N: cutlass.Constexpr,
    sizeof_element: cutlass.Constexpr,
    stride_A_m_elems: cutlass.Constexpr,
    stride_A_k_elems: cutlass.Constexpr,
    stride_B0_elems: cutlass.Constexpr,
    stride_Bk_elems: cutlass.Constexpr,
    stride_Bn_elems: cutlass.Constexpr,
    stride_C_m_elems: cutlass.Constexpr,
    stride_C_n_elems: cutlass.Constexpr,
    out_ptrs: cute.Tensor,  # [G,3] cutlass.Int64
    out_problem: cute.Tensor,  # [G,4] cutlass.Int32
    out_strides_abc: cute.Tensor,  # [G,3,2] cutlass.Int32
    stream: cuda.CUstream,
):
    build_group_ptrs_from_bases_kernel(
        base_A_u64,
        base_B_u64,
        base_C_u64,
        offs,
        K,
        N,
        sizeof_element,
        stride_A_m_elems,
        stride_A_k_elems,
        stride_B0_elems,
        stride_Bk_elems,
        stride_Bn_elems,
        stride_C_m_elems,
        stride_C_n_elems,
        out_ptrs,
        out_problem,
        out_strides_abc,
    ).launch(grid=(1, 1, 1), block=(G, 1, 1), stream=stream)

{{def_kernel("input_a", "input_b", "input_a_offs")}}
    stream = cuda.CUstream(stream)

    input_b = input_b.transpose(1, 2)

    sumM, K = input_a.shape
    G, N, Kb = input_b.shape

    dev = input_a.device

    base_A_u64 = int(input_a.data_ptr())
    base_B_u64 = int(input_b.data_ptr())
    base_C_u64 = int({{get_output()}}.data_ptr())

    ptrs_t = torch.empty((G, 3), device=dev, dtype=torch.int64)
    probs_t = torch.empty((G, 4), device=dev, dtype=torch.int32)
    strides_t = torch.empty((G, 3, 2), device=dev, dtype=torch.int32)
    ptrs = from_dlpack(ptrs_t)
    probs = from_dlpack(probs_t)
    strides = from_dlpack(strides_t)

    prep_cache_key = get_prep_cache_key(input_a, input_b, {{get_output()}})
    prep_executor = prep_cache.get(prep_cache_key)

    if prep_executor is None:
        sizeof_element = int(input_a.element_size())
        sA_m, sA_k = map(int, input_a.stride())
        sB_0, sB_n, sB_k = map(int, input_b.stride())
        sC_m, sC_n = map(int, {{get_output()}}.stride())

        prep_executor = cute.compile(
            launch_build_group_ptrs_from_bases,
            base_A_u64=base_A_u64,
            base_B_u64=base_B_u64,
            base_C_u64=base_C_u64,
            offs=from_dlpack(input_a_offs),
            G=int(G),
            K=int(K),
            N=int(N),
            sizeof_element=sizeof_element,
            stride_A_m_elems=sA_m,
            stride_A_k_elems=sA_k,
            stride_B0_elems=sB_0,
            stride_Bk_elems=sB_k,
            stride_Bn_elems=sB_n,
            stride_C_m_elems=sC_m,
            stride_C_n_elems=sC_n,
            out_ptrs=ptrs,
            out_problem=probs,
            out_strides_abc=strides,
            stream=stream,
        )

        prep_cache[prep_cache_key] = prep_executor

    prep_executor(
        base_A_u64=base_A_u64,
        base_B_u64=base_B_u64,
        base_C_u64=base_C_u64,
        offs=from_dlpack(input_a_offs),
        out_ptrs=ptrs,
        out_problem=probs,
        out_strides_abc=strides,
        stream=stream,
    )

    # --- Tensormap workspace per SM ---
    num_tensormap_buffers, max_active_clusters = get_hardware_info()
    tensormap_shape = (
        num_tensormap_buffers,
        GroupedGemmKernel.num_tensormaps,
        GroupedGemmKernel.bytes_per_tensormap // 8,
    )
    tensormap_workspace_t = torch.empty(tensormap_shape, device=dev, dtype=torch.int64)
    tensormap_workspace = from_dlpack(tensormap_workspace_t)

    # --- Total clusters ---
    def compute_total_num_clusters(
        problem_sizes_mnkl,
        cluster_tile_shape_mn,
    ):
        total_num_clusters = 0
        for m, n, _, _ in problem_sizes_mnkl:
            num_clusters_mn = tuple(
                ceildiv(x, y) for x, y in zip((m, n), cluster_tile_shape_mn)
            )
            total_num_clusters += functools.reduce(lambda x, y: x * y, num_clusters_mn)
        return total_num_clusters
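    # Illustration with assumed values (an editorial note, not emitted by the
    # template): TILE_M=128, TILE_N=192, CLUSTER_M=2, CLUSTER_N=1 and
    # USE_2_CTA=False give a cluster tile of (128 * 2, 192 * 1) = (256, 192);
    # a group with (m, n) = (300, 400) then contributes
    # ceildiv(300, 256) * ceildiv(400, 192) = 2 * 3 = 6 clusters to the total.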
    # Compute cluster tile shape
    def compute_cluster_tile_shape(
        mma_tiler_mn,
        cluster_shape_mn,
        use_2cta_instrs,
    ):
        cta_tile_shape_mn = list(mma_tiler_mn)
        if use_2cta_instrs:
            cta_tile_shape_mn[0] = cta_tile_shape_mn[0] // 2
        return tuple(x * y for x, y in zip(cta_tile_shape_mn, cluster_shape_mn))

    cluster_tile_shape_mn = compute_cluster_tile_shape(
        (TILE_M, TILE_N), (CLUSTER_M, CLUSTER_N), bool(USE_2_CTA)
    )

    total_num_clusters = int(compute_total_num_clusters(probs_t, cluster_tile_shape_mn))

    gemm_cache_key = get_gemm_cache_key(
        prep_cache_key, max_active_clusters, total_num_clusters
    )
    gemm_executor = gemm_cache.get(gemm_cache_key)

    if gemm_executor is None:
        grouped_gemm = GroupedGemmKernel(
            acc_dtype=ACC_DTYPE,
            use_2cta_instrs=USE_2_CTA,
            mma_tiler_mn=(TILE_M, TILE_N),
            cluster_shape_mn=(CLUSTER_M, CLUSTER_N),
            tensormap_update_mode=TENSORMAP_UPDATE_MODE,
        )

        gemm_executor = cute.compile(
            grouped_gemm,
            from_dlpack(input_a.unsqueeze(-1), assumed_align=16),
            from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16),
            from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16),
            G,
            probs,
            strides,
            ptrs,
            total_num_clusters,
            tensormap_workspace,
            max_active_clusters,
            stream,
        )

        gemm_cache[gemm_cache_key] = gemm_executor

    gemm_executor(
        from_dlpack(input_a.unsqueeze(-1), assumed_align=16),
        from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16),
        from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16),
        probs,
        strides,
        ptrs,
        tensormap_workspace,
        stream,
    )
torch/_inductor/template_heuristics/cutedsl.py (new file, 141 lines)
@@ -0,0 +1,141 @@
from dataclasses import dataclass
from enum import auto, Enum
from itertools import product

import torch._inductor.config as config


class TensorMapUpdateMode(Enum):
    """Enum mirroring cutlass.utils.TensorMapUpdateMode to decouple this file from a cutlass dependency."""

    SMEM = auto()
    GMEM = auto()


@dataclass(frozen=True)
class CuTeGemmConfig:
    TILE_M: int = 128
    TILE_N: int = 192
    CLUSTER_M: int = 2
    CLUSTER_N: int = 1
    USE_2_CTA: bool = False
    TENSORMAP_UPDATE_MODE: TensorMapUpdateMode = TensorMapUpdateMode.SMEM


def get_exhaustive_groupgemm_configs() -> list[CuTeGemmConfig]:
    """
    Returns the exhaustive configuration set for the Blackwell CuTeDSL Grouped GEMM kernel.
    For information regarding valid config sets, see:
    https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/grouped_gemm.py
    """

    # The TILE_N candidates are the same with and without 2-CTA mode
    tile_n_vals = [32, 64, 96, 128, 160, 192, 224, 256]

    # Valid clusters
    clusters_no_2cta = [
        (1, 1),
        (1, 2),
        (1, 4),
        (1, 8),
        (1, 16),
        (2, 1),
        (2, 2),
        (2, 4),
        (2, 8),
        (4, 1),
        (4, 2),
        (4, 4),
        (8, 1),
        (8, 2),
        (16, 1),
    ]
    clusters_2cta = [
        (2, 1),
        (2, 2),
        (2, 4),
        (2, 8),
        (4, 1),
        (4, 2),
        (4, 4),
        (8, 1),
        (8, 2),
        (16, 1),
    ]

    configs: list[CuTeGemmConfig] = []

    for use_2cta, cluster_set, tile_m_range in [
        (False, clusters_no_2cta, [64, 128]),
        (True, clusters_2cta, [128, 256]),
    ]:
        for tensormap_update_mode, tile_m, tile_n, (cluster_m, cluster_n) in product(
            [TensorMapUpdateMode.SMEM, TensorMapUpdateMode.GMEM],
            tile_m_range,
            tile_n_vals,
            cluster_set,
        ):
            configs.append(
                CuTeGemmConfig(
                    tile_m,
                    tile_n,
                    cluster_m,
                    cluster_n,
                    USE_2_CTA=use_2cta,
                    TENSORMAP_UPDATE_MODE=tensormap_update_mode,
                )
            )

    return configs


def get_default_groupgemm_configs() -> list[CuTeGemmConfig]:
    """
    Returns the default configuration set for the Blackwell CuTeDSL Grouped GEMM kernel.
    """

    config_tuples = [
        (128, 256, 2, 1, False, TensorMapUpdateMode.SMEM),
        (256, 160, 2, 1, True, TensorMapUpdateMode.GMEM),
        (256, 256, 2, 1, True, TensorMapUpdateMode.GMEM),
        (64, 32, 1, 1, False, TensorMapUpdateMode.GMEM),
        (64, 256, 1, 2, False, TensorMapUpdateMode.SMEM),
        (128, 256, 1, 2, False, TensorMapUpdateMode.SMEM),
        (256, 256, 2, 2, True, TensorMapUpdateMode.GMEM),
        (128, 256, 1, 2, False, TensorMapUpdateMode.GMEM),
        (64, 32, 1, 1, False, TensorMapUpdateMode.SMEM),
        (256, 256, 2, 1, True, TensorMapUpdateMode.SMEM),
        (128, 256, 1, 1, False, TensorMapUpdateMode.GMEM),
        (256, 256, 8, 1, True, TensorMapUpdateMode.GMEM),
        (64, 32, 1, 2, False, TensorMapUpdateMode.SMEM),
        (256, 192, 2, 1, True, TensorMapUpdateMode.GMEM),
        (256, 256, 2, 2, True, TensorMapUpdateMode.SMEM),
        (128, 96, 1, 2, False, TensorMapUpdateMode.SMEM),
        (64, 192, 1, 1, False, TensorMapUpdateMode.SMEM),
        (64, 64, 1, 1, False, TensorMapUpdateMode.GMEM),
        (64, 192, 1, 1, False, TensorMapUpdateMode.GMEM),
        (128, 64, 1, 1, False, TensorMapUpdateMode.GMEM),
        (64, 160, 1, 1, False, TensorMapUpdateMode.GMEM),
        (64, 256, 1, 1, False, TensorMapUpdateMode.GMEM),
    ]

    return [CuTeGemmConfig(*args) for args in config_tuples]


def get_groupgemm_configs() -> list[CuTeGemmConfig]:
    """
    Returns the configuration set for the Blackwell CuTeDSL Grouped GEMM kernel.

    Note: CuTeDSL autotuning is still experimental; enabling it may trigger kernel launch failures
    or unstable results. By default, autotuning is disabled and we return only
    a single baseline config.
    """
    if (
        config.cutedsl_enable_autotuning
        and config.max_autotune_gemm_search_space == "EXHAUSTIVE"
    ):
        return get_exhaustive_groupgemm_configs()
    elif config.cutedsl_enable_autotuning:
        return get_default_groupgemm_configs()
    else:
        return [get_default_groupgemm_configs()[0]]
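Back-of-the-envelope size of the search spaces above: the exhaustive set enumerates 2 tensormap modes x 2 TILE_M values x 8 TILE_N values x 15 clusters = 480 non-2CTA configs, plus 2 x 2 x 8 x 10 = 320 2-CTA configs, for 800 candidates in total; the default list holds 22 hand-picked configs, and with autotuning disabled only its first entry (128, 256, 2, 1, False, SMEM) is used.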
@@ -1911,6 +1911,84 @@ def use_triton_blackwell_tma_template(
     return has_triton_tensor_descriptor_host_tma() and is_datacenter_blackwell_arch()


+@functools.lru_cache(maxsize=1)
+def ensure_cute_available() -> bool:
+    """Check if CuTeDSL is importable; cache the result for reuse.
+
+    Call ensure_cute_available.cache_clear() after installing CuTeDSL
+    in the same interpreter to retry the import.
+    """
+    try:
+        return importlib.util.find_spec("cutlass.cute") is not None
+    except ImportError:
+        return False
+
+
+def use_blackwell_cutedsl_grouped_mm(
+    mat_a: Any,
+    mat_b: Any,
+    layout: Layout,
+    a_is_2d: bool,
+    b_is_2d: bool,
+    offs: Optional[Any],
+    bias: Optional[Any],
+    scale_result: Optional[Any],
+) -> bool:
+    """
+    Returns True if we can use the Blackwell kernel for grouped mm.
+    Required conditions:
+
+    1. CuTeDSL backend is enabled
+    2. CuTeDSL is available
+    3. We are on a Blackwell arch
+    4. The dtype is bf16
+    5. Max autotune or max autotune gemm is enabled
+    6. A, B, and the output are 16B aligned
+    7. We are not using dynamic shapes
+    8. A is 2d
+    9. B is 3d
+    10. Offsets are provided
+    11. Bias and scale_result are not provided
+    """
+    if not ensure_cute_available():
+        return False
+
+    if not _use_autotune_backend("CUTEDSL"):
+        return False
+
+    from .codegen.cuda.cuda_env import is_datacenter_blackwell_arch
+
+    if not is_gpu(layout.device.type):
+        return False
+
+    if not is_datacenter_blackwell_arch():
+        return False
+
+    layout_dtypes = [torch.bfloat16]
+    if not _use_template_for_gpu(layout, layout_dtypes):
+        return False
+
+    if not (config.max_autotune or config.max_autotune_gemm):
+        return False
+
+    # Checks for 16B ptr and stride alignment
+    if not can_use_tma(mat_a, mat_b, output_layout=layout):
+        return False
+
+    if any(is_dynamic(x) for x in [mat_a, mat_b]):
+        return False
+
+    if not a_is_2d or b_is_2d:
+        return False
+
+    if offs is None:
+        return False
+
+    if bias is not None or scale_result is not None:
+        return False
+
+    return True
+
+
 def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool:
     from .virtualized import V
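A minimal sketch of how this gate might be exercised (assumes a datacenter Blackwell GPU with CuTeDSL installed; the flag names are existing inductor config options, the combination is illustrative):

    import torch._inductor.config as inductor_config

    inductor_config.max_autotune = True
    # "CUTEDSL" must appear in the backend list for _use_autotune_backend("CUTEDSL")
    inductor_config.max_autotune_gemm_backends = "ATEN,TRITON,CUTEDSL"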
@@ -2651,7 +2729,6 @@ def pass_execution_and_save(
     with tempfile.NamedTemporaryFile(
         mode="w",
         encoding="utf-8",
-        delete=False,
     ) as f:
         before_io = io.StringIO()
         after_io = io.StringIO()
@@ -10,6 +10,7 @@ import torch
 from ._utils import _device_t, _get_device_index
 from .memory import (
     empty_cache,
+    get_memory_info,
     max_memory_allocated,
     max_memory_reserved,
     memory_allocated,

@@ -25,9 +26,10 @@ __all__ = [
     "current_device_idx",  # deprecated
     "current_device_index",
     "current_stream",
-    "empty_cache",
     "device_count",
     "device_index",
+    "empty_cache",
+    "get_memory_info",
     "is_available",
     "max_memory_allocated",
     "max_memory_reserved",
@@ -8,6 +8,7 @@ from ._utils import _device_t, _get_device_index

 __all__ = [
     "empty_cache",
+    "get_memory_info",
     "max_memory_allocated",
     "max_memory_reserved",
     "memory_allocated",
@ -87,6 +88,9 @@ def memory_stats(device_index: _device_t = None, /) -> OrderedDict[str, Any]:
            If not given, use :func:`torch.accelerator.current_device_index` by default.
            If a :class:`torch.device` or str is provided, its type must match the current
            :ref:`accelerator<accelerators>` device type.

    Returns:
        OrderedDict[str, Any]: an ordered dictionary mapping statistic names to their values.
    """
    if not torch._C._accelerator_isAllocatorInitialized():
        return OrderedDict()

@ -117,6 +121,9 @@ def memory_allocated(device_index: _device_t = None, /) -> int:
            If not given, use :func:`torch.accelerator.current_device_index` by default.
            If a :class:`torch.device` or str is provided, its type must match the current
            :ref:`accelerator<accelerators>` device type.

    Returns:
        int: the current memory occupied by live tensors (in bytes) within the current process.
    """
    return memory_stats(device_index).get("allocated_bytes.all.current", 0)

@ -134,6 +141,9 @@ def max_memory_allocated(device_index: _device_t = None, /) -> int:
            If not given, use :func:`torch.accelerator.current_device_index` by default.
            If a :class:`torch.device` or str is provided, its type must match the current
            :ref:`accelerator<accelerators>` device type.

    Returns:
        int: the peak memory occupied by live tensors (in bytes) within the current process.
    """
    return memory_stats(device_index).get("allocated_bytes.all.peak", 0)

@ -147,6 +157,9 @@ def memory_reserved(device_index: _device_t = None, /) -> int:
            If not given, use :func:`torch.accelerator.current_device_index` by default.
            If a :class:`torch.device` or str is provided, its type must match the current
            :ref:`accelerator<accelerators>` device type.

    Returns:
        int: the current memory reserved by PyTorch (in bytes) within the current process.
    """
    return memory_stats(device_index).get("reserved_bytes.all.current", 0)

@ -164,6 +177,9 @@ def max_memory_reserved(device_index: _device_t = None, /) -> int:
            If not given, use :func:`torch.accelerator.current_device_index` by default.
            If a :class:`torch.device` or str is provided, its type must match the current
            :ref:`accelerator<accelerators>` device type.

    Returns:
        int: the peak memory reserved by PyTorch (in bytes) within the current process.
    """
    return memory_stats(device_index).get("reserved_bytes.all.peak", 0)

@ -200,3 +216,21 @@ def reset_peak_memory_stats(device_index: _device_t = None, /) -> None:
    """
    device_index = _get_device_index(device_index, optional=True)
    return torch._C._accelerator_resetPeakStats(device_index)


def get_memory_info(device_index: _device_t = None, /) -> tuple[int, int]:
    r"""Return the current device memory information for a given device index.

    Args:
        device_index (:class:`torch.device`, str, int, optional): the index of the device to target.
            If not given, use :func:`torch.accelerator.current_device_index` by default.
            If a :class:`torch.device` or str is provided, its type must match the current
            :ref:`accelerator<accelerators>` device type.

    Returns:
        tuple[int, int]: a tuple of two integers (free_memory, total_memory) in bytes.
            The first value is the free memory on the device (available across all processes
            and applications); the second value is the device's total hardware memory capacity.
    """
    device_index = _get_device_index(device_index, optional=True)
    return torch._C._accelerator_getMemoryInfo(device_index)
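
A minimal usage sketch of the memory-inspection API documented above, assuming a host with at least one visible accelerator:

    import torch

    if torch.accelerator.is_available():
        free, total = torch.accelerator.get_memory_info()  # both in bytes
        print(f"free={free / 1e9:.2f} GB of total={total / 1e9:.2f} GB")
        print("allocated:", torch.accelerator.memory_allocated())
        torch.accelerator.reset_peak_memory_stats()
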
@ -195,10 +195,12 @@ def get_new_attr_name_with_prefix(prefix: str) -> Callable:
def collect_producer_nodes(node: Node) -> Optional[list[Node]]:
    r"""Starting from a target node, trace back until we hit input or
    getattr node. This is used to extract the chain of operators
    starting from getattr to the target node, for example
    def forward(self, x):
        observed = self.observer(self.weight)
        return F.linear(x, observed)
    starting from getattr to the target node, for example::

        def forward(self, x):
            observed = self.observer(self.weight)
            return F.linear(x, observed)

    collect_producer_nodes(observed) will either return a list of nodes that
    produces the observed node or None if we can't extract a self-contained
    graph without free variables (inputs of the forward function).

@ -138,6 +138,13 @@ void initModule(PyObject* module) {
    at::accelerator::resetPeakStats(device_index);
  });

  m.def("_accelerator_getMemoryInfo", [](c10::DeviceIndex device_index) {
    const auto device_type = at::accelerator::getAccelerator(true).value();
    torch::utils::maybe_initialize_device(device_type);
    py::gil_scoped_release no_gil;
    return at::accelerator::getMemoryInfo(device_index);
  });

  m.def("_accelerator_setAllocatorSettings", [](std::string env) {
    c10::CachingAllocator::setAllocatorSettings(env);
  });

@ -122,29 +122,47 @@ struct TORCH_API ExecutionTraceObserver { // NOLINT
  ID get_tensor_storage_ID(const c10::Storage& t_storage) {
    const std::lock_guard<std::recursive_mutex> lock(gMutex);

    const void* raw_data_ptr = t_storage.data();
    auto iter = data_ptr_to_weak_storage_ptr.find(raw_data_ptr);
    if (iter == data_ptr_to_weak_storage_ptr.end()) {
    const void* raw_data_ptr = nullptr;
    bool should_track_liveness = false;
    // FakeTensor/FunctionalTensor may clear the Storage handle entirely or use
    // a nullptr data pointer. Treat both cases as a shared cache key but avoid
    // touching the weak-ref table so they can reuse the same ID without
    // tripping the liveness check.
    if (t_storage.unsafeGetStorageImpl()) {
      raw_data_ptr = t_storage.data();
      should_track_liveness = raw_data_ptr != nullptr;
    }

    auto id_iter = data_ptr_to_storage_id.find(raw_data_ptr);
    if (!should_track_liveness) {
      if (id_iter != data_ptr_to_storage_id.end()) {
        return id_iter->second;
      }
      ID id = storage_id_++;
      data_ptr_to_storage_id.emplace(raw_data_ptr, id);
      return id;
    }

    auto weak_iter = data_ptr_to_weak_storage_ptr.find(raw_data_ptr);
    if (weak_iter == data_ptr_to_weak_storage_ptr.end()) {
      ID id = storage_id_++;
      data_ptr_to_storage_id.insert_or_assign(raw_data_ptr, id);
      data_ptr_to_weak_storage_ptr.emplace(
          raw_data_ptr, t_storage.getWeakStorageImpl());
      return id;
    } else {
      // check if the storage is still alive
      if (iter->second.expired()) {
        ID id = storage_id_++;
        // std::unordered_map does not change if the key is already in the map.
        // So we need to remove the key and insert the key with the new value.
        data_ptr_to_storage_id.erase(raw_data_ptr);
        data_ptr_to_storage_id[raw_data_ptr] = id;
        data_ptr_to_weak_storage_ptr.insert_or_assign(
            raw_data_ptr, t_storage.getWeakStorageImpl());
        return id;
      } else {
        return data_ptr_to_storage_id[raw_data_ptr];
      }
    }

    if (weak_iter->second.expired()) {
      ID id = storage_id_++;
      data_ptr_to_storage_id.insert_or_assign(raw_data_ptr, id);
      data_ptr_to_weak_storage_ptr.insert_or_assign(
          raw_data_ptr, t_storage.getWeakStorageImpl());
      return id;
    }

    id_iter = data_ptr_to_storage_id.find(raw_data_ptr);
    TORCH_INTERNAL_ASSERT(id_iter != data_ptr_to_storage_id.end());
    return id_iter->second;
  }

  // Observer run state.
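
A minimal Python sketch (not the C++ implementation above) of the weak-reference liveness pattern this method uses: reuse a cached ID while the storage at a given address is still alive, and mint a fresh ID once it has expired:

    import weakref

    ids_by_addr: dict[int, int] = {}
    weak_by_addr: dict[int, weakref.ref] = {}
    next_id = 0

    def storage_id(obj) -> int:
        global next_id
        addr = id(obj)  # stand-in for the raw data pointer
        ref = weak_by_addr.get(addr)
        if ref is not None and ref() is not None:
            return ids_by_addr[addr]  # storage still alive: reuse the ID
        next_id += 1                  # unseen or expired: assign a new ID
        ids_by_addr[addr] = next_id
        weak_by_addr[addr] = weakref.ref(obj)
        return next_id
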
@ -386,23 +386,8 @@ static void bindGetDeviceProperties(PyObject* module) {
static void initXpuMethodBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();
  m.def("_xpu_getMemoryInfo", [](c10::DeviceIndex device_index) {
#if SYCL_COMPILER_VERSION >= 20250000
    auto total = at::xpu::getDeviceProperties(device_index)->global_mem_size;
    auto& device = c10::xpu::get_raw_device(device_index);
    TORCH_CHECK(
        device.has(sycl::aspect::ext_intel_free_memory),
        "The device (",
        at::xpu::getDeviceProperties(device_index)->name,
        ") doesn't support querying the available free memory. ",
        "You can file an issue at https://github.com/pytorch/pytorch/issues ",
        "to help us prioritize its implementation.");
    auto free = device.get_info<sycl::ext::intel::info::device::free_memory>();
    return std::make_tuple(free, total);
#else
    TORCH_CHECK_NOT_IMPLEMENTED(
        false,
        "torch.xpu.mem_get_info requires PyTorch to be built with SYCL compiler version 2025.0.0 or newer.");
#endif
    py::gil_scoped_release no_gil;
    return at::getDeviceAllocator(at::kXPU)->getMemoryInfo(device_index);
  });
  m.def(
      "_xpu_getStreamFromExternal",

@ -1,4 +1,5 @@
import itertools
import math
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, cast, NamedTuple, Optional

@ -7,6 +8,7 @@ import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.placement_types import (
    _StridedShard,
    MaskPartial,
    Partial,
    Placement,
    Replicate,
@ -127,6 +129,185 @@ class DTensorSpec:
        )
        return default_shard_order

    @staticmethod
    def _convert_shard_order_to_StridedShard(
        shard_order: ShardOrder, placements: tuple[Placement, ...], mesh: DeviceMesh
    ) -> tuple[Placement, ...]:
        """
        Convert ShardOrder to placements with _StridedShard.

        This function converts a ShardOrder specification into a tuple of Placement objects,
        using _StridedShard when a tensor dimension is sharded across multiple mesh dimensions
        in a non-default order. The split_factor of each _StridedShard is determined by the
        product of mesh dimension sizes that appear earlier in the shard order but later in
        the placement tuple.

        Args:
            shard_order: ShardOrder specification indicating which tensor dimensions are
                sharded on which mesh dimensions and in what execution order.
            placements: Tuple of Placement objects that does not contain _StridedShard.
            mesh: DeviceMesh containing the size information for each mesh dimension.

        Returns:
            Updated tuple of Placement objects with Shard or _StridedShard placements.

        Algorithm:
            For each ShardOrderEntry in shard_order:
            - For each mesh dimension in the entry's mesh_dims (in order):
                - Calculate split_factor as the product of mesh sizes for all mesh dimensions
                  that appear:
                    1. Earlier in the shard order (lower index in mesh_dims), and
                    2. Later in the placement tuple (higher mesh dimension index)
                - If split_factor == 1: use normal Shard
                - Otherwise: use _StridedShard with the calculated split_factor

        Example:
            >>> # xdoctest: +SKIP("Requires DeviceMesh")
            >>> # Tensor dimension 0 sharded on mesh dims [2, 0, 1] in that order
            >>> # mesh = DeviceMesh([4, 3, 2])  # sizes: mesh[0]=4, mesh[1]=3, mesh[2]=2
            >>> shard_order = (ShardOrderEntry(tensor_dim=0, mesh_dims=(2, 0, 1)),)
            >>> placements = (Shard(0), Shard(0), Shard(0))
            >>> # For mesh_dim=2 (index 0 in mesh_dims): no earlier dims, split_factor=1
            >>> #   -> placements[2] = Shard(0)
            >>> # For mesh_dim=0 (index 1 in mesh_dims): mesh_dim=2 is earlier and has index 2>0
            >>> #   -> split_factor = mesh.size(2) = 2
            >>> #   -> placements[0] = _StridedShard(0, split_factor=2)
            >>> # For mesh_dim=1 (index 2 in mesh_dims): mesh_dim=2 is earlier and has index 2>1
            >>> #   -> split_factor = mesh.size(2) = 2
            >>> #   -> placements[1] = _StridedShard(0, split_factor=2)
            >>> # Result: (_StridedShard(0, sf=2), _StridedShard(0, sf=2), Shard(0))
        """
        placements_list = list(placements)
        for entry in shard_order:
            tensor_dim = entry.tensor_dim
            mesh_dims = entry.mesh_dims
            for idx in range(len(mesh_dims)):
                # TODO(zpcore): split_factor from `view` and `shard order`
                # should be able to be multiplied into one. Need to loosen the
                # condition here.
                mesh_dim = mesh_dims[idx]
                if type(placements[mesh_dim]) is not Shard:
                    raise ValueError(
                        f"Only Shard placement can be converted to _StridedShard, "
                        f"found {placements[mesh_dim]} in {placements=}."
                    )
                split_factor = math.prod(
                    mesh.size(i) for i in mesh_dims[:idx] if i > mesh_dim
                )
                if split_factor == 1:
                    # use normal Shard
                    placements_list[mesh_dim] = Shard(tensor_dim)
                else:
                    placements_list[mesh_dim] = _StridedShard(
                        tensor_dim, split_factor=split_factor
                    )
        return tuple(placements_list)
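
The split_factor rule from the docstring, checked as a standalone sketch with plain ints standing in for a DeviceMesh (mesh sizes assumed to be [4, 3, 2]):

    import math

    mesh_size = {0: 4, 1: 3, 2: 2}
    mesh_dims = (2, 0, 1)  # shard order for tensor dim 0

    for idx, mesh_dim in enumerate(mesh_dims):
        sf = math.prod(mesh_size[i] for i in mesh_dims[:idx] if i > mesh_dim)
        print(mesh_dim, sf)  # prints (2, 1), then (0, 2), then (1, 2)
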
    @staticmethod
    def _maybe_convert_StridedShard_to_shard_order(
        placements: tuple[Placement, ...], mesh: DeviceMesh
    ) -> Optional[ShardOrder]:
        """
        Try to convert _StridedShard placements to ShardOrder.

        This is the inverse of `_convert_shard_order_to_StridedShard`. It reconstructs the shard
        order by examining the split_factor of each _StridedShard and determining its position
        in the execution order. If the _StridedShard configuration cannot be represented as a
        valid ShardOrder (i.e., there's no shard order that produces the observed split_factors),
        this function returns None.

        Args:
            placements: Tuple of Placement objects that may contain _StridedShard.
            mesh: DeviceMesh containing the size information for each mesh dimension.

        Returns:
            ShardOrder if conversion is possible, None otherwise. For placements without
            _StridedShard, returns the default shard order.

        Algorithm:
            1. If no _StridedShard in placements, return default shard order
            2. Create an empty list for each tensor dimension to represent mesh dim ordering
            3. Iterate through placements in reverse order (right to left):
               - For each Shard/_StridedShard on a tensor dimension:
                 - Extract its split_factor (1 for Shard, split_factor for _StridedShard)
                 - Find the position in mesh_dims_order where accumulated_sf equals split_factor
                   - accumulated_sf is the product of mesh sizes of mesh dimensions that appear
                     earlier in mesh_dims_order (lower indices)
                 - Insert mesh_dim at the found position
            4. If no valid position found for any split_factor, return None (unable to convert)
            5. Construct ShardOrderEntry for each tensor dimension from mesh_dims_order

        Example:
            >>> # xdoctest: +SKIP("Requires DeviceMesh")
            >>> # mesh = DeviceMesh([4, 3, 2])  # sizes: mesh[0]=4, mesh[1]=3, mesh[2]=2
            >>> # placements = (_StridedShard(0, sf=2), _StridedShard(0, sf=2), Shard(0))
            >>> # Process tensor_dim=0 from right to left:
            >>> # - mesh_dim=2: Shard(0) with sf=1
            >>> #     Try position 0: accumulated_sf=1, matches! Insert at position 0
            >>> #     Current mesh_dims_order: [2]
            >>> # - mesh_dim=1: _StridedShard(0, sf=2) with sf=2
            >>> #     Try position 0: accumulated_sf=1, no match
            >>> #     Try position 1: accumulated_sf=1*mesh.size(2)=2, matches! Insert at position 1
            >>> #     Current mesh_dims_order: [2, 1]
            >>> # - mesh_dim=0: _StridedShard(0, sf=2) with sf=2
            >>> #     Try position 0: accumulated_sf=1, no match
            >>> #     Try position 1: accumulated_sf=1*mesh.size(2)=2, matches! Insert at position 1
            >>> # Final mesh_dims_order: [2, 0, 1]
            >>> # Result: ShardOrder((ShardOrderEntry(tensor_dim=0, mesh_dims=(2, 0, 1)),))
            >>> # This means: first shard on mesh_dim=2, then mesh_dim=0, then mesh_dim=1

        Note:
            This function validates that _StridedShard can be represented as a ShardOrder.
            Not all _StridedShard configurations are valid - the split_factor must match
            the product of mesh sizes in some execution order.
        """
        if not any(isinstance(p, _StridedShard) for p in placements):
            return DTensorSpec.compute_default_shard_order(placements)
        max_tensor_dim = (
            max([i.dim for i in placements if isinstance(i, Shard | _StridedShard)]) + 1
        )
        shard_order = []

        tensor_dim_to_mesh_dims_order: list[list[int]] = [
            [] for i in range(max_tensor_dim)
        ]
        for mesh_dim in reversed(range(len(placements))):
            cur_placement = placements[mesh_dim]
            # _StridedShard may not be a subclass of Shard in the future, so write in this way:
            if isinstance(cur_placement, Shard | _StridedShard):
                tensor_dim = cur_placement.dim
                mesh_dims_order = tensor_dim_to_mesh_dims_order[tensor_dim]
                cur_sf = 1
                if isinstance(cur_placement, _StridedShard):
                    cur_sf = cur_placement.split_factor
                accumulated_sf = 1
                find_order = False
                for i in range(len(mesh_dims_order) + 1):
                    if accumulated_sf == cur_sf:
                        mesh_dims_order.insert(i, mesh_dim)
                        find_order = True
                        break
                    if i < len(mesh_dims_order):
                        accumulated_sf *= mesh.size(mesh_dims_order[i])
                if not find_order:
                    # _StridedShard is not convertible to ShardOrder
                    return None
            else:
                if not isinstance(cur_placement, Replicate | Partial | MaskPartial):
                    raise ValueError(
                        f"Unsupported placement type {type(cur_placement)} encountered in "
                        f"{placements}; expected Replicate, Partial, or MaskPartial."
                    )
        for tensor_dim in range(max_tensor_dim):
            if len(tensor_dim_to_mesh_dims_order[tensor_dim]) > 0:
                shard_order.append(
                    ShardOrderEntry(
                        tensor_dim=tensor_dim,
                        mesh_dims=tuple(tensor_dim_to_mesh_dims_order[tensor_dim]),
                    )
                )
        return tuple(shard_order)
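
The position search at the heart of the inverse conversion, as a self-contained sketch: find the slot where the accumulated product of mesh sizes equals the observed split factor (names here are illustrative):

    def find_insert_pos(mesh_dims_order, cur_sf, mesh_size):
        accumulated_sf = 1
        for i in range(len(mesh_dims_order) + 1):
            if accumulated_sf == cur_sf:
                return i
            if i < len(mesh_dims_order):
                accumulated_sf *= mesh_size[mesh_dims_order[i]]
        return None  # split factor not representable as a shard order

    assert find_insert_pos([2], 2, {0: 4, 1: 3, 2: 2}) == 1
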
    def _verify_shard_order(self, shard_order: ShardOrder) -> None:
        """Verify that the shard_order is valid and matches the placements."""
        total_shard = 0

@ -11,7 +11,10 @@ import torch.distributed.tensor._api as dtensor
aten = torch.ops.aten


def _requires_data_exchange(padding):
def _requires_data_exchange(padding, dim_map) -> bool:
    # Data exchange is not needed if the tensor is sharded only across the batch dim
    if all(x == -1 for x in dim_map[1:]):
        return False
    # TODO: whether data exchange is required is currently determined only by padding
    return padding[-1] != 0
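
A sketch of the dim_map convention assumed by the new check: dim_map[t] is the mesh dimension that tensor dim t is sharded on, or -1 if that dim is replicated, with dim 0 as the batch dim:

    def sharded_only_on_batch(dim_map: list[int]) -> bool:
        return all(d == -1 for d in dim_map[1:])

    assert sharded_only_on_batch([0, -1, -1, -1]) is True   # batch-sharded NCHW
    assert sharded_only_on_batch([-1, -1, -1, 0]) is False  # width-sharded
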
@ -107,6 +110,7 @@ def tp_convolution(
    op_call: torch._ops.OpOverload,
    local_tensor_args: tuple[object, ...],
    local_tensor_kwargs: dict[str, object],
    dim_map: list[int],
) -> object:
    assert op_call == aten.convolution.default
    assert len(local_tensor_args) == 9

@ -120,7 +124,7 @@ def tp_convolution(
    assert _is_supported(in_tensor.shape, weight.shape, stride, padding, dilation)
    assert isinstance(padding, list)

    if not _requires_data_exchange(padding):
    if not _requires_data_exchange(padding, dim_map):
        local_results = op_call(*local_tensor_args, **local_tensor_kwargs)
        return local_results
    else:

@ -160,6 +164,7 @@ def tp_convolution_backward(
    op_call: torch._ops.OpOverload,
    local_tensor_args: tuple[object, ...],
    local_tensor_kwargs: dict[str, object],
    dim_map: list[int],
) -> object:
    assert op_call == aten.convolution_backward.default
    assert len(local_tensor_args) == 11

@ -174,7 +179,7 @@ def tp_convolution_backward(
    assert _is_supported(in_tensor.shape, weight.shape, stride, padding, dilation)
    assert isinstance(padding, list)

    if not _requires_data_exchange(padding):
    if not _requires_data_exchange(padding, dim_map):
        local_results = op_call(*local_tensor_args, **local_tensor_kwargs)
        return local_results
    else:

@ -239,15 +244,18 @@ def convolution_handler(
    dtensor.DTensor._op_dispatcher.sharding_propagator.propagate(op_info)
    output_sharding = op_info.output_sharding
    assert output_sharding is not None, "output sharding should not be None"
    output_spec = output_sharding.output_spec
    assert isinstance(output_spec, dtensor.DTensorSpec)

    # local propagation
    local_results = tp_convolution(
        op_call, tuple(op_info.local_args), op_info.local_kwargs
        op_call,
        tuple(op_info.local_args),
        op_info.local_kwargs,
        output_spec.dim_map,
    )

    return dtensor.DTensor._op_dispatcher.wrap(
        local_results, output_sharding.output_spec
    )
    return dtensor.DTensor._op_dispatcher.wrap(local_results, output_spec)


def convolution_backward_handler(

@ -270,10 +278,14 @@ def convolution_backward_handler(
    dtensor.DTensor._op_dispatcher.sharding_propagator.propagate(op_info)
    output_sharding = op_info.output_sharding
    assert output_sharding is not None, "output sharding should not be None"
    assert isinstance(op_info.flat_args_schema[0], dtensor.DTensorSpec)

    # local propagation
    local_results = tp_convolution_backward(
        op_call, tuple(op_info.local_args), op_info.local_kwargs
        op_call,
        tuple(op_info.local_args),
        op_info.local_kwargs,
        op_info.flat_args_schema[0].dim_map,
    )

    return dtensor.DTensor._op_dispatcher.wrap(

@ -1,6 +1,7 @@
# mypy: allow-untyped-defs
import math
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from collections.abc import Callable
from typing import Any, Optional, Union


number = Union[int, float]

@ -1,9 +1,8 @@
# ${generated_comment}
# mypy: allow-untyped-defs

from collections.abc import Sequence
from typing import Any, Callable, Literal, overload
from typing_extensions import TypeAlias
from collections.abc import Callable, Sequence
from typing import Any, Literal, overload, TypeAlias

from torch import Tensor
from torch.types import _dtype, _int, _size

@ -6,7 +6,7 @@

from __future__ import annotations

from typing import Optional, Sequence, TYPE_CHECKING
from typing import Optional, Sequence, TYPE_CHECKING  # noqa: UP035

from onnxscript.onnx_opset import (  # type: ignore[attr-defined]
    opset20 as op20,

@ -271,13 +271,11 @@ class _KinetoProfile:
                "Profiler must be initialized before exporting chrome trace"
            )
        if path.endswith(".gz"):
            with tempfile.NamedTemporaryFile("w+b", suffix=".json", delete=False) as fp:
                fp.close()
            with tempfile.NamedTemporaryFile("w+b", suffix=".json") as fp:
                retvalue = self.profiler.export_chrome_trace(fp.name)
                with open(fp.name, "rb") as fin:
                    with gzip.open(path, "wb") as fout:
                        fout.writelines(fin)
                os.remove(fp.name)
                fp.seek(0)
                with gzip.open(path, "wb") as fout:
                    fout.writelines(fp)
            return retvalue
        else:
            return self.profiler.export_chrome_trace(path)
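
The new export path in isolation, as a runnable sketch: write the JSON trace into a context-managed temporary file, rewind it, and stream it into the .gz target (the payload below is a stand-in for export_chrome_trace):

    import gzip
    import tempfile

    with tempfile.NamedTemporaryFile("w+b", suffix=".json") as fp:
        fp.write(b'{"traceEvents": []}')  # stand-in payload
        fp.seek(0)
        with gzip.open("trace.json.gz", "wb") as fout:
            fout.writelines(fp)
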
@ -448,15 +446,14 @@ class _KinetoProfile:
        if path.endswith(".html"):
            self.mem_tl.export_memory_timeline_html(path, device)
        elif path.endswith(".gz"):
            fp = tempfile.NamedTemporaryFile("w+t", suffix=".json", delete=False)
            fp.close()
            if path.endswith("raw.json.gz"):
                self.mem_tl.export_memory_timeline_raw(fp.name, device)
            else:
                self.mem_tl.export_memory_timeline(fp.name, device)
            with open(fp.name) as fin, gzip.open(path, "wt") as fout:
                fout.writelines(fin)
            os.remove(fp.name)
            with tempfile.NamedTemporaryFile("w+t", suffix=".json") as fp:
                fp.close()
                if path.endswith("raw.json.gz"):
                    self.mem_tl.export_memory_timeline_raw(fp.name, device)
                else:
                    self.mem_tl.export_memory_timeline(fp.name, device)
                with open(fp.name) as fin, gzip.open(path, "wt") as fout:
                    fout.writelines(fin)
        else:
            self.mem_tl.export_memory_timeline(path, device)

@ -946,7 +943,7 @@ class ExecutionTraceObserver(_ITraceObserver):
        """
        if os.environ.get("ENABLE_PYTORCH_EXECUTION_TRACE", "0") == "1":
            try:
                fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False)
                fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False)  # noqa:SIM115
            except Exception as e:
                warn(
                    f"Execution trace will not be recorded. Exception on creating default temporary file: {e}",

@ -20320,6 +20320,7 @@ op_db: list[OpInfo] = [
                        torch.float32: 1e-4}),),
           dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
           dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
           supports_sparse=True,
           supports_sparse_csr=True,
           supports_sparse_csc=True,
           supports_sparse_bsr=True,

@ -3,6 +3,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates

import contextlib
import copy
import functools
import itertools
import sys

@ -32,6 +33,8 @@ from torch.distributed.tensor import (
    Replicate,
    Shard,
)
from torch.distributed.tensor._dtensor_spec import ShardOrderEntry
from torch.distributed.tensor._redistribute import redistribute_local_tensor
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
@ -818,3 +821,125 @@ def map_local_for_rank(rank, func):

def reduce_local_int(val, func):
    return func(val.node._local_ints)


def _convert_shard_order_dict_to_ShardOrder(shard_order):
    """Convert shard_order dict to ShardOrder"""
    return tuple(
        ShardOrderEntry(tensor_dim=tensor_dim, mesh_dims=tuple(mesh_dims))
        for tensor_dim, mesh_dims in shard_order.items()
    )


# TODO(zpcore): remove once the native redistribute supports shard_order arg
def redistribute(
    dtensor_input,
    device_mesh,
    placements,
    shard_order,
    use_graph_based_transform=True,
):
    """
    Wrapper function to support shard_order for redistribution.
    This is a simpler version of Redistribute that only considers the forward pass.
    """
    if placements is None:
        placements = shard_order_to_placement(shard_order, device_mesh)
    placements = tuple(placements)
    old_spec = dtensor_input._spec
    new_spec = copy.deepcopy(old_spec)
    new_spec.placements = placements
    if shard_order is not None:
        new_spec.shard_order = shard_order
    else:
        new_spec.shard_order = ()
    if old_spec == new_spec:
        return dtensor_input
    dtensor_input = DTensor.from_local(
        redistribute_local_tensor(
            dtensor_input.to_local(),
            old_spec,
            new_spec,
            use_graph_based_transform=use_graph_based_transform,
        ),
        device_mesh,
    )
    dtensor_input._spec = copy.deepcopy(new_spec)
    return dtensor_input  # returns DTensor


# TODO(zpcore): remove once the native distribute_tensor supports
# shard_order arg
def patched_distribute_tensor(
    input_tensor,
    device_mesh,
    placements,
    shard_order,
    use_graph_based_transform=True,
):
    """wrapper function to support shard_order for tensor distribution"""
    if placements is None:
        placements = shard_order_to_placement(shard_order, device_mesh)
    placements = tuple(placements)
    tensor_dt = distribute_tensor(input_tensor, device_mesh, placements)
    # fix the shard order
    return redistribute(
        tensor_dt, device_mesh, placements, shard_order, use_graph_based_transform
    )


# TODO(zpcore): remove once the native redistribute supports shard_order arg
def make_full_tensor(dtensor_input):
    """wrapper function to support DTensor.full_tensor"""
    return redistribute(
        dtensor_input, dtensor_input.device_mesh, placements=None, shard_order=()
    ).to_local()


def shard_order_to_placement(shard_order, mesh):
    """convert shard_order to placement with only Replicate() and Shard()"""
    placements: list[Any] = [Replicate() for _ in range(mesh.ndim)]
    if shard_order is not None:
        for entry in shard_order:
            tensor_dim = entry.tensor_dim
            mesh_dims = entry.mesh_dims
            for mesh_dim in mesh_dims:
                placements[mesh_dim] = Shard(tensor_dim)
    return tuple(placements)


def generate_shard_orders(mesh, tensor_rank):
    # Generate all possible sharding placements of a tensor with rank
    # `tensor_rank` over the mesh.
    def _split_list(lst: list, N: int):
        def compositions(n: int, k: int):
            # yields lists of length k, positive ints summing to n
            for cuts in itertools.combinations(range(1, n), k - 1):
                # add 0 and n as sentinels, then take consecutive differences
                yield [b - a for a, b in itertools.pairwise((0, *cuts, n))]

        length = len(lst)
        for comp in compositions(length, N):
            result = []
            start = 0
            for size in comp:
                result.append(lst[start : start + size])
                start += size
            yield result

    all_mesh = list(range(mesh.ndim))
    all_device_order = list(itertools.permutations(all_mesh))
    for device_order in all_device_order:
        # split the device order, and assign each device-order segment to a tensor dim
        for num_split in range(1, mesh.ndim + 1):
            for splitted_list in _split_list(list(range(mesh.ndim)), num_split):
                for tensor_dims in itertools.combinations(
                    range(tensor_rank), len(splitted_list)
                ):
                    shard_order = {}
                    assert len(tensor_dims) == len(splitted_list)
                    for tensor_dim, mesh_dims in zip(tensor_dims, splitted_list):
                        shard_order[tensor_dim] = device_order[
                            mesh_dims[0] : mesh_dims[-1] + 1
                        ]
                    yield _convert_shard_order_dict_to_ShardOrder(shard_order)
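
A self-contained demo of the compositions helper that drives _split_list above: every way to write n as an ordered sum of k positive integers:

    import itertools

    def compositions(n: int, k: int):
        for cuts in itertools.combinations(range(1, n), k - 1):
            yield [b - a for a, b in itertools.pairwise((0, *cuts, n))]

    print(list(compositions(4, 2)))  # [[1, 3], [2, 2], [3, 1]]
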
@ -215,19 +215,16 @@ def get_profiling_event(event_name, profiler, dedup_gpu_user_annotation=False):
def get_profiler_nccl_meta(prof):
    """The torch profiler includes nccl metadata in an inserted operator called
    "record_param_comms". We will need to test the metadata obtained from the profiler here."""
    tf = tempfile.NamedTemporaryFile(mode="w+t", suffix=".json", delete=False)
    tf.close()
    trace_file = tf.name
    with tempfile.NamedTemporaryFile(mode="w+t", suffix=".json") as tf:
        tf.close()
        trace_file = tf.name

    prof.export_chrome_trace(trace_file)
    with open(trace_file) as f:
        events = json.load(f)["traceEvents"]
    print(f"Trace saved to {trace_file}")

        prof.export_chrome_trace(trace_file)
        with open(trace_file) as f:
            events = json.load(f)["traceEvents"]
        print(f"Trace saved to {trace_file}")

        # Comment to debug
        os.remove(trace_file)

    return [e for e in events if e.get("name") == "record_param_comms"]
        return [e for e in events if e.get("name") == "record_param_comms"]


# Base error message substring on unfinished reductions.

@ -1,5 +1,5 @@
import sys
from typing import Callable, Optional
from typing import Callable, Optional  # noqa: UP035

from torch.utils._config_module import install_config_module

@ -11,7 +11,7 @@ import unittest
from collections.abc import Callable
from dataclasses import dataclass
from types import FunctionType, ModuleType
from typing import Any, Generic, NoReturn, Optional, TYPE_CHECKING, TypeVar, Union
from typing import Any, Generic, NoReturn, Optional, TYPE_CHECKING, TypeVar
from typing_extensions import deprecated
from unittest import mock

@ -23,7 +23,7 @@ CONFIG_TYPES = (int, float, bool, type(None), str, list, set, tuple, dict)


# Duplicated, because mypy needs these types statically
T = TypeVar("T", bound=Union[int, float, bool, None, str, list, set, tuple, dict])
T = TypeVar("T", bound=int | float | bool | None | str | list | set | tuple | dict)


_UNSET_SENTINEL = object()

@ -69,12 +69,12 @@ class _Config(Generic[T]):
    default behaviour. I.e. user overrides take preference.
    """

    default: Union[T, object]
    justknob: Optional[str] = None
    env_name_default: Optional[list[str]] = None
    env_name_force: Optional[list[str]] = None
    value_type: Optional[type] = None
    alias: Optional[str] = None
    default: T | object
    justknob: str | None = None
    env_name_default: list[str] | None = None
    env_name_force: list[str] | None = None
    value_type: type | None = None
    alias: str | None = None

    def __post_init__(self) -> None:
        self.env_name_default = _Config.string_or_list_of_string_to_list(

@ -98,8 +98,8 @@ class _Config(Generic[T]):

    @staticmethod
    def string_or_list_of_string_to_list(
        val: Optional[Union[str, list[str]]],
    ) -> Optional[list[str]]:
        val: str | list[str] | None,
    ) -> list[str] | None:
        if val is None:
            return None
        if isinstance(val, str):

@ -116,23 +116,23 @@ class _Config(Generic[T]):
if TYPE_CHECKING:

    def Config(
        default: Union[T, object] = _UNSET_SENTINEL,
        justknob: Optional[str] = None,
        env_name_default: Optional[Union[str, list[str]]] = None,
        env_name_force: Optional[Union[str, list[str]]] = None,
        value_type: Optional[type] = None,
        alias: Optional[str] = None,
        default: T | object = _UNSET_SENTINEL,
        justknob: str | None = None,
        env_name_default: str | list[str] | None = None,
        env_name_force: str | list[str] | None = None,
        value_type: type | None = None,
        alias: str | None = None,
    ) -> T: ...

else:

    def Config(
        default: Union[T, object] = _UNSET_SENTINEL,
        justknob: Optional[str] = None,
        env_name_default: Optional[Union[str, list[str]]] = None,
        env_name_force: Optional[Union[str, list[str]]] = None,
        value_type: Optional[type] = None,
        alias: Optional[str] = None,
        default: T | object = _UNSET_SENTINEL,
        justknob: str | None = None,
        env_name_default: str | list[str] | None = None,
        env_name_force: str | list[str] | None = None,
        value_type: type | None = None,
        alias: str | None = None,
    ) -> _Config[T]:
        return _Config(
            default=default,

@ -144,7 +144,7 @@ else:
)


def _read_env_variable(name: str) -> Optional[Union[bool, str]]:
def _read_env_variable(name: str) -> bool | str | None:
    value = os.environ.get(name)
    if value == "1":
        return True

@ -165,8 +165,8 @@ def install_config_module(module: ModuleType) -> None:
    _bypass_keys = set({"_is_dirty", "_hash_digest", "__annotations__"})

    def visit(
        source: Union[ModuleType, type],
        dest: Union[ModuleType, SubConfigProxy],
        source: ModuleType | type,
        dest: ModuleType | SubConfigProxy,
        prefix: str,
    ) -> None:
        """Walk the module structure and move everything to module._config"""

@ -281,7 +281,7 @@ class _ConfigEntry:
    # _UNSET_SENTINEL indicates the value is not set.
    user_override: Any = _UNSET_SENTINEL
    # The justknob to check for this config
    justknob: Optional[str] = None
    justknob: str | None = None
    # environment variables are read at install time
    env_value_force: Any = _UNSET_SENTINEL
    env_value_default: Any = _UNSET_SENTINEL

@ -297,7 +297,7 @@ class _ConfigEntry:
    # call so the final state is correct. It's just very unintuitive.
    # upstream bug - python/cpython#126886
    hide: bool = False
    alias: Optional[str] = None
    alias: str | None = None

    def __init__(self, config: _Config) -> None:
        self.default = config.default

@ -347,7 +347,7 @@ class ConfigModule(ModuleType):
    _bypass_keys: set[str]
    _compile_ignored_keys: set[str]
    _is_dirty: bool
    _hash_digest: Optional[bytes]
    _hash_digest: bytes | None

    def __init__(self) -> None:
        raise NotImplementedError(

@ -411,7 +411,7 @@ class ConfigModule(ModuleType):

    def _get_alias_module_and_name(
        self, entry: _ConfigEntry
    ) -> Optional[tuple[ModuleType, str]]:
    ) -> tuple[ModuleType, str] | None:
        alias = entry.alias
        if alias is None:
            return None

@ -465,8 +465,8 @@ class ConfigModule(ModuleType):

    def _get_dict(
        self,
        ignored_keys: Optional[list[str]] = None,
        ignored_prefixes: Optional[list[str]] = None,
        ignored_keys: list[str] | None = None,
        ignored_prefixes: list[str] | None = None,
        skip_default: bool = False,
    ) -> dict[str, Any]:
        """Export a dictionary of current configuration keys and values.

@ -542,7 +542,7 @@ class ConfigModule(ModuleType):
            if module_name:
                imports.add(module_name)

        def list_of_callables_to_string(v: Union[list, set]) -> list[str]:
        def list_of_callables_to_string(v: list | set) -> list[str]:
            return [f"{get_module_name(item, True)}{item.__name__}" for item in v]

        def importable_callable(v: Any) -> bool:

@ -615,7 +615,7 @@ class ConfigModule(ModuleType):
    def shallow_copy_dict(self) -> dict[str, Any]:
        return self.get_config_copy()

    def load_config(self, maybe_pickled_config: Union[bytes, dict[str, Any]]) -> None:
    def load_config(self, maybe_pickled_config: bytes | dict[str, Any]) -> None:
        """Restore from a prior call to save_config() or shallow_copy_dict()"""
        if not isinstance(maybe_pickled_config, dict):
            config = pickle.loads(maybe_pickled_config)

@ -637,7 +637,7 @@ class ConfigModule(ModuleType):

    def patch(
        self,
        arg1: Optional[Union[str, dict[str, Any]]] = None,
        arg1: str | dict[str, Any] | None = None,
        arg2: Any = None,
        **kwargs: dict[str, Any],
    ) -> "ContextDecorator":

@ -816,7 +816,7 @@ def patch_object(obj: object, name: str, value: object) -> object:
    return mock.patch.object(obj, name, value)


def get_tristate_env(name: str, default: Any = None) -> Optional[bool]:
def get_tristate_env(name: str, default: Any = None) -> bool | None:
    value = os.environ.get(name)
    if value == "1":
        return True
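
A restatement of the tri-state convention used by get_tristate_env above, as a self-contained sketch; the "0" branch and the default fallthrough are cut off in the hunk and assumed here:

    import os

    def tristate_env_sketch(name: str, default=None):
        value = os.environ.get(name)
        if value == "1":
            return True
        if value == "0":
            return False
        return default  # unset or any other value
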
@ -34,7 +34,6 @@ import hashlib
import os.path
import struct
from collections import defaultdict
from typing import Optional

import torch
import torch._prims as prims

@ -193,9 +192,9 @@ class ContentStoreWriter:
class ContentStoreReader:
    def __init__(self, loc: str, *, cache=True) -> None:
        self.loc = loc
        self.storage_cache: Optional[
            dict[Optional[torch.device], dict[str, StorageWeakRef]]
        ] = None
        self.storage_cache: (
            dict[torch.device | None, dict[str, StorageWeakRef]] | None
        ) = None
        if cache:
            self.storage_cache = defaultdict(dict)

@ -207,7 +206,7 @@ class ContentStoreReader:
            if self.storage_cache is not None
            else None
        )
        s: Optional[torch.UntypedStorage]
        s: torch.UntypedStorage | None
        if ws is not None:
            s = torch.UntypedStorage._new_with_weak_ptr(ws.cdata)
            if s is not None:

@ -1,10 +1,9 @@
from collections.abc import Sequence
from pathlib import Path
from re import match as _match
from typing import Optional, Union


def read_file(fname: Union[Path, str]) -> list[str]:
def read_file(fname: Path | str) -> list[str]:
    with open(fname, encoding="utf-8") as f:
        return f.readlines()

@ -36,7 +35,7 @@ def _embed_headers(


def embed_headers(
    fname: str, include_dirs: Optional[Union[Sequence[str], Sequence[Path], str]] = None
    fname: str, include_dirs: Sequence[str] | Sequence[Path] | str | None = None
) -> str:
    if include_dirs is None:
        base_dir = Path(__file__).parent.parent.parent

@ -15,7 +15,7 @@ collection support for PyTorch APIs.
import functools
import types
from collections.abc import Callable, Iterable, Mapping
from typing import Any, Optional, overload, TypeAlias, TypeVar, Union
from typing import Any, overload, TypeAlias, TypeVar, Union
from typing_extensions import deprecated, Self, TypeIs

import torch.utils._pytree as python_pytree

@ -128,10 +128,10 @@ def register_pytree_node(
    flatten_fn: FlattenFunc,
    unflatten_fn: UnflattenFunc,
    *,
    serialized_type_name: Optional[str] = None,
    to_dumpable_context: Optional[ToDumpableContextFn] = None,
    from_dumpable_context: Optional[FromDumpableContextFn] = None,
    flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None,
    serialized_type_name: str | None = None,
    to_dumpable_context: ToDumpableContextFn | None = None,
    from_dumpable_context: FromDumpableContextFn | None = None,
    flatten_with_keys_fn: FlattenWithKeysFunc | None = None,
) -> None:
    """Register a container-like type as pytree node.

@ -196,9 +196,9 @@ def _register_pytree_node(
    flatten_fn: FlattenFunc,
    unflatten_fn: UnflattenFunc,
    *,
    serialized_type_name: Optional[str] = None,
    to_dumpable_context: Optional[ToDumpableContextFn] = None,
    from_dumpable_context: Optional[FromDumpableContextFn] = None,
    serialized_type_name: str | None = None,
    to_dumpable_context: ToDumpableContextFn | None = None,
    from_dumpable_context: FromDumpableContextFn | None = None,
) -> None:
    """Register a container-like type as pytree node for the C++ pytree only.

@ -247,9 +247,9 @@ def _private_register_pytree_node(
    flatten_fn: FlattenFunc,
    unflatten_fn: UnflattenFunc,
    *,
    serialized_type_name: Optional[str] = None,
    to_dumpable_context: Optional[ToDumpableContextFn] = None,
    from_dumpable_context: Optional[FromDumpableContextFn] = None,
    serialized_type_name: str | None = None,
    to_dumpable_context: ToDumpableContextFn | None = None,
    from_dumpable_context: FromDumpableContextFn | None = None,
) -> None:
    """This is an internal function that is used to register a pytree node type
    for the C++ pytree only. End-users should use :func:`register_pytree_node`

@ -281,7 +281,7 @@ def treespec_tuple(iterable: Iterable[TreeSpec] = (), /) -> TreeSpec:


def treespec_dict(
    mapping: Union[Mapping[Any, TreeSpec], Iterable[tuple[Any, TreeSpec]]] = (),
    mapping: Mapping[Any, TreeSpec] | Iterable[tuple[Any, TreeSpec]] = (),
    /,
    **kwargs: TreeSpec,
) -> TreeSpec:

@ -296,7 +296,7 @@ def treespec_dict(

def tree_is_leaf(
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> bool:
    """Check if a pytree is a leaf.

@ -334,7 +334,7 @@ def tree_is_leaf(

def tree_flatten(
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> tuple[list[Any], TreeSpec]:
    """Flatten a pytree.

@ -399,7 +399,7 @@ def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:

def tree_iter(
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> Iterable[Any]:
    """Get an iterator over the leaves of a pytree.

@ -434,7 +434,7 @@ def tree_iter(

def tree_leaves(
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> list[Any]:
    """Get the leaves of a pytree.

@ -469,7 +469,7 @@ def tree_leaves(

def tree_structure(
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> TreeSpec:
    """Get the treespec for a pytree.

@ -506,7 +506,7 @@ def tree_map(
    func: Callable[..., Any],
    tree: PyTree,
    *rests: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree:
    """Map a multi-input function over pytree args to produce a new pytree.

@ -555,7 +555,7 @@ def tree_map_(
    func: Callable[..., Any],
    tree: PyTree,
    *rests: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree:
    """Like :func:`tree_map`, but do an inplace call on each leaf and return the original tree.

@ -593,8 +593,8 @@ Type2 = tuple[type[T], type[S]]
Type3 = tuple[type[T], type[S], type[U]]
TypeAny = Union[type[Any], tuple[type[Any], ...], types.UnionType]

Fn2 = Callable[[Union[T, S]], R]
Fn3 = Callable[[Union[T, S, U]], R]
Fn2 = Callable[[T | S], R]
Fn3 = Callable[[T | S | U], R]
Fn = Callable[[T], R]
FnAny = Callable[[Any], R]

@ -629,7 +629,7 @@ def map_only(


def map_only(
    type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]], /
    type_or_types_or_pred: TypeAny | Callable[[Any], bool], /
) -> MapOnlyFn[FnAny[Any]]:
    """
    Suppose you are writing a tree_map over tensors, leaving everything

@ -677,7 +677,7 @@ def tree_map_only(
    /,
    func: Fn[T, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree: ...

@ -687,7 +687,7 @@ def tree_map_only(
    /,
    func: Fn2[T, S, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree: ...

@ -697,7 +697,7 @@ def tree_map_only(
    /,
    func: Fn3[T, S, U, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree: ...

@ -707,7 +707,7 @@ def tree_map_only(
    /,
    func: FnAny[Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree: ...

@ -717,16 +717,16 @@ def tree_map_only(
    /,
    func: FnAny[Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree: ...


def tree_map_only(
    type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]],
    type_or_types_or_pred: TypeAny | Callable[[Any], bool],
    /,
    func: FnAny[Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree:
    return tree_map(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf)

@ -737,7 +737,7 @@ def tree_map_only_(
    /,
    func: Fn[T, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree: ...

@ -747,7 +747,7 @@ def tree_map_only_(
    /,
    func: Fn2[T, S, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree: ...

@ -757,7 +757,7 @@ def tree_map_only_(
    /,
    func: Fn3[T, S, U, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree: ...

@ -767,7 +767,7 @@ def tree_map_only_(
    /,
    func: FnAny[Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree: ...

@ -777,16 +777,16 @@ def tree_map_only_(
    /,
    func: FnAny[Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree: ...


def tree_map_only_(
    type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]],
    type_or_types_or_pred: TypeAny | Callable[[Any], bool],
    /,
    func: FnAny[Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree:
    return tree_map_(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf)

@ -794,7 +794,7 @@ def tree_map_only_(
def tree_all(
    pred: Callable[[Any], bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> bool:
    flat_args = tree_iter(tree, is_leaf=is_leaf)
    return all(map(pred, flat_args))

@ -803,7 +803,7 @@ def tree_all(
def tree_any(
    pred: Callable[[Any], bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> bool:
    flat_args = tree_iter(tree, is_leaf=is_leaf)
    return any(map(pred, flat_args))

@ -815,7 +815,7 @@ def tree_all_only(
    /,
    pred: Fn[T, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> bool: ...

@ -825,7 +825,7 @@ def tree_all_only(
    /,
    pred: Fn2[T, S, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> bool: ...

@ -835,7 +835,7 @@ def tree_all_only(
    /,
    pred: Fn3[T, S, U, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> bool: ...

@ -844,7 +844,7 @@ def tree_all_only(
    /,
    pred: FnAny[bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> bool:
    flat_args = tree_iter(tree, is_leaf=is_leaf)
    return all(pred(x) for x in flat_args if isinstance(x, type_or_types))

@ -856,7 +856,7 @@ def tree_any_only(
    /,
    pred: Fn[T, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> bool: ...

@ -866,7 +866,7 @@ def tree_any_only(
    /,
    pred: Fn2[T, S, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> bool: ...

@ -876,7 +876,7 @@ def tree_any_only(
    /,
    pred: Fn3[T, S, U, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> bool: ...

@ -885,7 +885,7 @@ def tree_any_only(
    /,
    pred: FnAny[bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> bool:
    flat_args = tree_iter(tree, is_leaf=is_leaf)
    return any(pred(x) for x in flat_args if isinstance(x, type_or_types))

@ -894,7 +894,7 @@ def tree_any_only(
def broadcast_prefix(
    prefix_tree: PyTree,
    full_tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> list[Any]:
    """Return a list of broadcasted leaves in ``prefix_tree`` to match the number of leaves in ``full_tree``.

@ -956,8 +956,8 @@ def broadcast_prefix(
def _broadcast_to_and_flatten(
    tree: PyTree,
    treespec: TreeSpec,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> Optional[list[Any]]:
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> list[Any] | None:
    if not _is_pytreespec_instance(treespec):
        raise AssertionError(
            f"_broadcast_to_and_flatten: Expected `treespec` to be instance of PyTreeSpec but got {type(treespec)}"

@ -969,7 +969,7 @@ def _broadcast_to_and_flatten(
    return None


def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str:
def treespec_dumps(treespec: TreeSpec, protocol: int | None = None) -> str:
    """Serialize a treespec to a JSON string."""
    if not _is_pytreespec_instance(treespec):
        raise TypeError(

@ -1024,7 +1024,7 @@ class LeafSpec(TreeSpec, metaclass=LeafSpecMeta):  # type: ignore[misc,final]

def tree_flatten_with_path(
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> tuple[list[tuple[KeyPath, Any]], TreeSpec]:
    """Flattens a pytree like :func:`tree_flatten`, but also returns each leaf's key path.

@ -1047,7 +1047,7 @@ def tree_flatten_with_path(

def tree_leaves_with_path(
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> list[tuple[KeyPath, Any]]:
    """Gets the leaves of a pytree like ``tree_leaves`` and returns each leaf's key path.

@ -1070,7 +1070,7 @@ def tree_map_with_path(
    func: Callable[..., Any],
    tree: PyTree,
    *rests: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
    is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree:
    """Like :func:`tree_map`, but the provided callable takes an additional key path argument.
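
A hedged usage sketch for the map-only helpers whose signatures were modernized above, assuming this file is importable as torch.utils._cxx_pytree:

    import torch.utils._cxx_pytree as pytree

    tree = {"a": 1, "b": ("x", 2.0)}
    doubled = pytree.tree_map_only(int, lambda x: x * 2, tree)
    assert doubled == {"a": 2, "b": ("x", 2.0)}
    assert pytree.tree_any_only(float, lambda x: x > 1.0, tree)
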
@ -4,7 +4,7 @@ import functools
import traceback
import weakref
from collections.abc import Callable
from typing import Any, Optional, TYPE_CHECKING
from typing import Any, TYPE_CHECKING

import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
@ -140,7 +140,7 @@ def _get_stack_trace() -> str:
return "".join(summary.format())


def _maybe_get_autograd_trace() -> Optional[str]:
def _maybe_get_autograd_trace() -> str | None:
if torch._C._current_autograd_node() is not None:
tb = torch._C._current_autograd_node().metadata.get("traceback_")  # type: ignore[attr-defined]
if tb:
@ -154,8 +154,8 @@ class _DebugCall:
def __init__(
self,
call_depth: int,
record: Optional[dict[str, Any]] = None,
log: Optional[dict[str, Any]] = None,
record: dict[str, Any] | None = None,
log: dict[str, Any] | None = None,
stack: bool = False,
) -> None:
self.call_depth = call_depth
@ -166,10 +166,10 @@ class _DebugCall:
# results from dispatch hooks
self.record = record
self.log = log
self.output_str: Optional[str] = None
self.output_str: str | None = None

def stringify_args(
self, attributes: list[str], tensor_memo: Optional[TensorIdTracker] = None
self, attributes: list[str], tensor_memo: TensorIdTracker | None = None
) -> None:
"""
To reduce memory consumption, this method stringifies args/kwargs, stores the result, and deletes original args/kwargs.
@ -182,7 +182,7 @@ class _DebugCall:
self,
output: Any,
attributes: list[str],
tensor_memo: Optional[TensorIdTracker] = None,
tensor_memo: TensorIdTracker | None = None,
) -> None:
"""Store stringified version of call output in self.output_str"""
if tree_all(lambda x: x is None, output):
@ -213,11 +213,11 @@ class _OpCall(_DebugCall):
self.args = args
self.kwargs = kwargs

self.args_str: Optional[str] = None
self.kwargs_str: Optional[str] = None
self.args_str: str | None = None
self.kwargs_str: str | None = None

def stringify_args(
self, attributes: list[str], tensor_memo: Optional[TensorIdTracker] = None
self, attributes: list[str], tensor_memo: TensorIdTracker | None = None
) -> None:
self.args_str = ", ".join(
_arg_to_str(arg, attributes, tensor_memo) for arg in self.args
@ -289,10 +289,10 @@ class _RedistributeCall(_DebugCall):
self.dst_placement = dst_placement
self.transform_info_str = transform_info_str

self.arg_str: Optional[str] = None
self.arg_str: str | None = None

def stringify_args(
self, attributes: list[str], tensor_memo: Optional[TensorIdTracker] = None
self, attributes: list[str], tensor_memo: TensorIdTracker | None = None
) -> None:
self.arg_str = f"{_arg_to_str(self.arg, attributes, tensor_memo)}"
del self.arg
@ -339,7 +339,7 @@ class _NNModuleCall(_DebugCall):
self.module_name = module_name

def stringify_args(
self, attributes: list[str], tensor_memo: Optional[TensorIdTracker] = None
self, attributes: list[str], tensor_memo: TensorIdTracker | None = None
) -> None:
pass  # nothing to stringify

@ -418,7 +418,7 @@ class DebugMode(TorchDispatchMode):
# This flag currently has no effect on torch.compiled-regions.
self.record_nn_module = record_nn_module

self.module_tracker: Optional[ModTracker] = None
self.module_tracker: ModTracker | None = None
if self.record_nn_module:
self.module_tracker_setup()

@ -585,7 +585,7 @@ class DebugMode(TorchDispatchMode):
arg,
src_placement,
dst_placement,
transform_info_str: Optional[str] = None,
transform_info_str: str | None = None,
):
try:
self._record_call(
@ -615,8 +615,8 @@ class DebugMode(TorchDispatchMode):
@staticmethod
@contextlib.contextmanager
def dispatch_hooks(
record_hook: Optional[Callable] = None,
log_hook: Optional[Callable] = None,
record_hook: Callable | None = None,
log_hook: Callable | None = None,
):
"""
Allows installing post-hooks on arguments to intercepted __torch_dispatch__ calls;
@ -660,9 +660,7 @@ class DebugMode(TorchDispatchMode):

@staticmethod
@contextlib.contextmanager
def log_tensor_hashes(
hash_fn: Optional[Callable] = None, hash_inputs: bool = False
):
def log_tensor_hashes(hash_fn: Callable | None = None, hash_inputs: bool = False):
"""
Installs hook for tensor hash logging.

@ -696,7 +694,7 @@ class DebugMode(TorchDispatchMode):
yield


def get_active_debug_mode() -> Optional[DebugMode]:
def get_active_debug_mode() -> DebugMode | None:
debug_mode = None
for mode in _get_current_dispatch_mode_stack():
if isinstance(mode, DebugMode):

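A rough usage sketch of the hooks being retyped here, hedged: the module path `torch.utils._debug_mode` is inferred from context, and the hash-function signature (tensor in, scalar out) is an assumption:

    import torch
    from torch.utils._debug_mode import DebugMode  # private module, path assumed

    def cheap_hash(t):  # assumed signature: tensor in, scalar out
        return float(t.float().abs().sum())

    with DebugMode() as dm, DebugMode.log_tensor_hashes(hash_fn=cheap_hash):
        torch.ones(4) + torch.ones(4)
    # dm now holds the recorded calls with per-output hashes attached
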
@ -1,6 +1,5 @@
# mypy: allow-untyped-defs
import functools
from typing import Optional

import torch
from torch._C import _len_torch_function_stack
@ -8,7 +7,7 @@ from torch.overrides import _pop_mode, _push_mode, TorchFunctionMode
from torch.utils._contextlib import context_decorator


CURRENT_DEVICE: Optional[torch.device] = None
CURRENT_DEVICE: torch.device | None = None


@functools.lru_cache(1)

@ -1,5 +1,4 @@
from types import TracebackType
from typing import Optional
from typing_extensions import Self

from filelock import FileLock as base_FileLock
@ -28,9 +27,9 @@ class FileLock(base_FileLock):

def __exit__(
self,
exc_type: Optional[type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
self.region_counter.__exit__()
with _WaitCounter("pytorch.filelock.exit").guard():

@ -1,4 +1,4 @@
from typing import Optional, TypeAlias
from typing import TypeAlias

import torch
from torch import Tensor
@ -23,7 +23,7 @@ def _get_fused_kernels_supported_devices() -> list[str]:
]


TensorListList: TypeAlias = list[list[Optional[Tensor]]]
TensorListList: TypeAlias = list[list[Tensor | None]]
Indices: TypeAlias = list[int]
_foreach_supported_types = [torch.Tensor]

@ -1,7 +1,6 @@
import functools
import importlib.util
from types import ModuleType
from typing import Optional


def _check_module_exists(name: str) -> bool:
@ -24,7 +23,7 @@ def dill_available() -> bool:


@functools.lru_cache
def import_dill() -> Optional[ModuleType]:
def import_dill() -> ModuleType | None:
if not dill_available():
return None

@ -8,7 +8,7 @@ from collections.abc import (
Reversible,
Set as AbstractSet,
)
from typing import Any, cast, Optional, TypeVar
from typing import Any, cast, TypeVar


T = TypeVar("T", bound=Hashable)
@ -24,7 +24,7 @@ class OrderedSet(MutableSet[T], Reversible[T]):

__slots__ = ("_dict",)

def __init__(self, iterable: Optional[Iterable[T]] = None) -> None:
def __init__(self, iterable: Iterable[T] | None = None) -> None:
self._dict = dict.fromkeys(iterable, None) if iterable is not None else {}

@staticmethod

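The `__init__` retyped above also shows the whole trick behind this class: a `dict` preserves insertion order, so `dict.fromkeys` gives an ordered, deduplicated set for free. A standalone sketch of the same idea:

    items = [3, 1, 2, 1, 3]
    ordered_unique = dict.fromkeys(items)        # {3: None, 1: None, 2: None}
    assert list(ordered_unique) == [3, 1, 2]     # insertion order kept, duplicates dropped
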
@ -6,7 +6,7 @@ import functools
import warnings
from collections import deque
from dataclasses import dataclass
from typing import cast, Optional, overload, Protocol, TYPE_CHECKING, Union
from typing import cast, overload, Protocol, TYPE_CHECKING
from typing_extensions import TypeIs

import torch
@ -207,7 +207,7 @@ class TorchDispatchMode:
return False


def _get_current_dispatch_mode() -> Optional[TorchDispatchMode]:
def _get_current_dispatch_mode() -> TorchDispatchMode | None:
"""
Return the top user mode on the stack (the next one that would be
executed) if there are any.
@ -308,7 +308,7 @@ def _push_mode(mode: TorchDispatchMode) -> None:
_set_mode_pre_dispatch(mode)


def _pop_mode(k: Optional[Union[DispatchKey, torch._C._TorchDispatchModeKey]] = None):
def _pop_mode(k: DispatchKey | torch._C._TorchDispatchModeKey | None = None):
if k == torch._C.DispatchKey.PreDispatch:  # type: ignore[attr-defined]
from torch._ops import _pop_mode_from_pre_dispatch

@ -319,7 +319,7 @@ def _pop_mode(k: Optional[Union[DispatchKey, torch._C._TorchDispatchModeKey]] =


@contextlib.contextmanager
def _pop_mode_temporarily(k: Optional[DispatchKey] = None):
def _pop_mode_temporarily(k: DispatchKey | None = None):
old = _pop_mode(k)
try:
yield old
@ -429,18 +429,18 @@ class TensorWithFlatten(Protocol):
non_blocking: bool = False,
copy: bool = False,
*,
memory_format: Optional[torch.memory_format] = None,
memory_format: torch.memory_format | None = None,
) -> torch.Tensor: ...

@overload
def to(
self,
device: Optional[torch._prims_common.DeviceLikeType] = None,
dtype: Optional[torch.types._dtype] = None,
device: torch._prims_common.DeviceLikeType | None = None,
dtype: torch.types._dtype | None = None,
non_blocking: bool = False,
copy: bool = False,
*,
memory_format: Optional[torch.memory_format] = None,
memory_format: torch.memory_format | None = None,
) -> torch.Tensor: ...

@overload
@ -450,7 +450,7 @@ class TensorWithFlatten(Protocol):
non_blocking: bool = False,
copy: bool = False,
*,
memory_format: Optional[torch.memory_format] = None,
memory_format: torch.memory_format | None = None,
) -> torch.Tensor: ...


@ -610,7 +610,7 @@ def _correct_storage_aliasing(func, schema_info, args, outs) -> None:
alias_non_inplace_storage(args[arg_idx], outs[return_idx])


def _get_write_alias(x) -> Optional[str]:
def _get_write_alias(x) -> str | None:
alias_set = x.alias_set
if not alias_set or not x.is_write:
return None
@ -629,7 +629,7 @@ def _get_write_alias(x) -> Optional[str]:
class AliasInfo:
alias_set: set[str]
is_write: bool
name: Optional[str]
name: str | None


@dataclass
@ -642,7 +642,7 @@ class SchemaInfo:
# [_get_write_alias(x) for x in outs]. Guaranteed to contain no Nones; we coerce
# all-Nones result to empty list instead, and we don't support
# some-but-not-all-Nones.
outs_write_aliases: Optional[list[str]]
outs_write_aliases: list[str] | None

# List of (arg_idx, return_idx) where args[arg_idx].alias_set &
# outs[out_idx].alias_set is not empty, and not args[arg_idx].is_write.
@ -726,12 +726,12 @@ def get_alias_info(func) -> SchemaInfo:
if is_read_only_alias_match:
read_only_alias_match_indexes.append((arg_idx, return_idx))

outs_write_aliases_list: list[Optional[str]] = [
outs_write_aliases_list: list[str | None] = [
_get_write_alias(r) for r in out_schemas
]
non_nones = sum(x is not None for x in outs_write_aliases_list)
if non_nones == 0:
outs_write_aliases: Optional[list[str]] = None
outs_write_aliases: list[str] | None = None
elif non_nones != len(outs_write_aliases_list):
# simplifying assumption: we don't have **any** ops with return types like "-> (Tensor(a!), Tensor)"
raise RuntimeError("Unsupported schema: " + str(func._schema))
@ -751,7 +751,7 @@ def get_alias_info(func) -> SchemaInfo:


def autograd_would_have_decomposed(
func: torch._ops.OpOverload, flat_args: Sequence[Union[torch.Tensor, object]]
func: torch._ops.OpOverload, flat_args: Sequence[torch.Tensor | object]
) -> bool:
"""
Suppose that an operator has CompositeImplicitAutograd decomp registered.

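For context, `TorchDispatchMode` (whose helpers are retyped above) is the hook for intercepting ATen ops below the Python tensor layer; the standard minimal mode looks like this (a common pattern, not taken from this diff):

    import torch
    from torch.utils._python_dispatch import TorchDispatchMode

    class LoggingMode(TorchDispatchMode):
        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            print(f"dispatched: {func}")
            return func(*args, **(kwargs or {}))

    with LoggingMode():
        torch.ones(2) * 3   # prints aten ops as they dispatch
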
@ -33,7 +33,6 @@ from typing import (
Final,
Generic,
NoReturn,
Optional,
overload,
Protocol,
TypeAlias,
@ -109,7 +108,7 @@ class KeyEntry(Protocol):


class EnumEncoder(json.JSONEncoder):
def default(self, obj: object) -> Union[str, dict[str, Any]]:
def default(self, obj: object) -> str | dict[str, Any]:
if isinstance(obj, Enum):
return {
"__enum__": True,
@ -127,7 +126,7 @@ DumpableContext = Any  # Any json dumpable text
ToDumpableContextFn = Callable[[Context], DumpableContext]
FromDumpableContextFn = Callable[[DumpableContext], Context]
ToStrFunc = Callable[["TreeSpec", list[str]], str]
MaybeFromStrFunc = Callable[[str], Optional[tuple[Any, Context, str]]]
MaybeFromStrFunc = Callable[[str], tuple[Any, Context, str] | None]
KeyPath = tuple[KeyEntry, ...]
FlattenWithKeysFunc = Callable[[PyTree], tuple[list[tuple[KeyEntry, Any]], Any]]

@ -145,7 +144,7 @@ class NodeDef(NamedTuple):
type: type[Any]
flatten_fn: FlattenFunc
unflatten_fn: UnflattenFunc
flatten_with_keys_fn: Optional[FlattenWithKeysFunc]
flatten_with_keys_fn: FlattenWithKeysFunc | None


_NODE_REGISTRY_LOCK = threading.RLock()
@ -162,8 +161,8 @@ SUPPORTED_NODES: dict[type[Any], NodeDef] = {}
class _SerializeNodeDef(NamedTuple):
typ: type[Any]
serialized_type_name: str
to_dumpable_context: Optional[ToDumpableContextFn]
from_dumpable_context: Optional[FromDumpableContextFn]
to_dumpable_context: ToDumpableContextFn | None
from_dumpable_context: FromDumpableContextFn | None


SUPPORTED_SERIALIZED_TYPES: dict[type[Any], _SerializeNodeDef] = {}
@ -199,10 +198,10 @@ def register_pytree_node(
flatten_fn: FlattenFunc,
unflatten_fn: UnflattenFunc,
*,
serialized_type_name: Optional[str] = None,
to_dumpable_context: Optional[ToDumpableContextFn] = None,
from_dumpable_context: Optional[FromDumpableContextFn] = None,
flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None,
serialized_type_name: str | None = None,
to_dumpable_context: ToDumpableContextFn | None = None,
from_dumpable_context: FromDumpableContextFn | None = None,
flatten_with_keys_fn: FlattenWithKeysFunc | None = None,
) -> None:
"""Register a container-like type as pytree node.

@ -273,9 +272,9 @@ def register_pytree_node(
def register_dataclass(
cls: type[Any],
*,
field_names: Optional[list[str]] = None,
drop_field_names: Optional[list[str]] = None,
serialized_type_name: Optional[str] = None,
field_names: list[str] | None = None,
drop_field_names: list[str] | None = None,
serialized_type_name: str | None = None,
) -> None:
"""
Registers a type that has the semantics of a ``dataclasses.dataclass`` type
@ -524,13 +523,13 @@ def _register_pytree_node(
cls: type[Any],
flatten_fn: FlattenFunc,
unflatten_fn: UnflattenFunc,
to_str_fn: Optional[ToStrFunc] = None,  # deprecated
maybe_from_str_fn: Optional[MaybeFromStrFunc] = None,  # deprecated
to_str_fn: ToStrFunc | None = None,  # deprecated
maybe_from_str_fn: MaybeFromStrFunc | None = None,  # deprecated
*,
serialized_type_name: Optional[str] = None,
to_dumpable_context: Optional[ToDumpableContextFn] = None,
from_dumpable_context: Optional[FromDumpableContextFn] = None,
flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None,
serialized_type_name: str | None = None,
to_dumpable_context: ToDumpableContextFn | None = None,
from_dumpable_context: FromDumpableContextFn | None = None,
flatten_with_keys_fn: FlattenWithKeysFunc | None = None,
) -> None:
"""Register a container-like type as pytree node for the Python pytree only.

@ -594,10 +593,10 @@ def _private_register_pytree_node(
flatten_fn: FlattenFunc,
unflatten_fn: UnflattenFunc,
*,
serialized_type_name: Optional[str] = None,
to_dumpable_context: Optional[ToDumpableContextFn] = None,
from_dumpable_context: Optional[FromDumpableContextFn] = None,
flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None,
serialized_type_name: str | None = None,
to_dumpable_context: ToDumpableContextFn | None = None,
from_dumpable_context: FromDumpableContextFn | None = None,
flatten_with_keys_fn: FlattenWithKeysFunc | None = None,
) -> None:
"""This is an internal function that is used to register a pytree node type
for the Python pytree only. End-users should use :func:`register_pytree_node`
@ -671,7 +670,7 @@ class GetAttrKey:


# Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py
def is_namedtuple(obj: Union[object, type]) -> bool:
def is_namedtuple(obj: object | type) -> bool:
"""Return whether the object is an instance of namedtuple or a subclass of namedtuple."""
cls = obj if isinstance(obj, type) else type(obj)
return is_namedtuple_class(cls)
@ -723,7 +722,7 @@ class structseq(tuple[_T_co, ...]):


# Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py
def is_structseq(obj: Union[object, type]) -> bool:
def is_structseq(obj: object | type) -> bool:
"""Return whether the object is an instance of PyStructSequence or a class of PyStructSequence."""
cls = obj if isinstance(obj, type) else type(obj)
return is_structseq_class(cls)
@ -1046,7 +1045,7 @@ def _get_node_type(tree: Any) -> Any:
# A leaf is defined as anything that is not a Node.
def tree_is_leaf(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> bool:
"""Check if a pytree is a leaf.

@ -1073,7 +1072,7 @@ def tree_is_leaf(
"Please use torch.utils._pytree.tree_is_leaf instead.",
category=FutureWarning,
)
def _is_leaf(tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None) -> bool:
def _is_leaf(tree: PyTree, is_leaf: Callable[[PyTree], bool] | None = None) -> bool:
return tree_is_leaf(tree, is_leaf=is_leaf)


@ -1353,7 +1352,7 @@ def treespec_tuple(iterable: Iterable[TreeSpec] = (), /) -> TreeSpec:


def treespec_dict(
mapping: Union[Mapping[Any, TreeSpec], Iterable[tuple[Any, TreeSpec]]] = (),
mapping: Mapping[Any, TreeSpec] | Iterable[tuple[Any, TreeSpec]] = (),
/,
**kwargs: TreeSpec,
) -> TreeSpec:
@ -1366,7 +1365,7 @@ def treespec_dict(

def tree_flatten(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> tuple[list[Any], TreeSpec]:
"""Flattens a pytree into a list of values and a TreeSpec that can be used
to reconstruct the pytree.
@ -1404,7 +1403,7 @@ def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:

def tree_iter(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> Iterable[Any]:
"""Get an iterator over the leaves of a pytree."""
if tree_is_leaf(tree, is_leaf=is_leaf):
@ -1421,7 +1420,7 @@ def tree_iter(

def tree_leaves(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> list[Any]:
"""Get a list of leaves of a pytree."""
return list(tree_iter(tree, is_leaf=is_leaf))
@ -1429,7 +1428,7 @@ def tree_leaves(

def tree_structure(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> TreeSpec:
"""Get the TreeSpec for a pytree."""
return tree_flatten(tree, is_leaf=is_leaf)[1]
@ -1439,7 +1438,7 @@ def tree_map(
func: Callable[..., Any],
tree: PyTree,
*rests: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree:
"""Map a multi-input function over pytree args to produce a new pytree.

@ -1483,7 +1482,7 @@ def tree_map_(
func: Callable[..., Any],
tree: PyTree,
*rests: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree:
"""Like :func:`tree_map`, but do an inplace call on each leaf and return the original tree.

@ -1517,8 +1516,8 @@ Type2 = tuple[type[T], type[S]]
Type3 = tuple[type[T], type[S], type[U]]
TypeAny = Union[type[Any], tuple[type[Any], ...], types.UnionType]

Fn2 = Callable[[Union[T, S]], R]
Fn3 = Callable[[Union[T, S, U]], R]
Fn2 = Callable[[T | S], R]
Fn3 = Callable[[T | S | U], R]
Fn = Callable[[T], R]
FnAny = Callable[[Any], R]

@ -1553,7 +1552,7 @@ def map_only(


def map_only(
type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]], /
type_or_types_or_pred: TypeAny | Callable[[Any], bool], /
) -> MapOnlyFn[FnAny[Any]]:
"""
Suppose you are writing a tree_map over tensors, leaving everything
@ -1601,7 +1600,7 @@ def tree_map_only(
/,
func: Fn[T, Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree: ...


@ -1611,7 +1610,7 @@ def tree_map_only(
/,
func: Fn2[T, S, Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree: ...


@ -1621,7 +1620,7 @@ def tree_map_only(
/,
func: Fn3[T, S, U, Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree: ...


@ -1631,7 +1630,7 @@ def tree_map_only(
/,
func: FnAny[Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree: ...


@ -1641,16 +1640,16 @@ def tree_map_only(
/,
func: FnAny[Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree: ...


def tree_map_only(
type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]],
type_or_types_or_pred: TypeAny | Callable[[Any], bool],
/,
func: FnAny[Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree:
return tree_map(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf)

@ -1661,7 +1660,7 @@ def tree_map_only_(
/,
func: Fn[T, Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree: ...


@ -1671,7 +1670,7 @@ def tree_map_only_(
/,
func: Fn2[T, S, Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree: ...


@ -1681,7 +1680,7 @@ def tree_map_only_(
/,
func: Fn3[T, S, U, Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree: ...


@ -1691,7 +1690,7 @@ def tree_map_only_(
/,
func: FnAny[Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree: ...


@ -1701,16 +1700,16 @@ def tree_map_only_(
/,
func: FnAny[Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree: ...


def tree_map_only_(
type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]],
type_or_types_or_pred: TypeAny | Callable[[Any], bool],
/,
func: FnAny[Any],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree:
return tree_map_(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf)

@ -1718,7 +1717,7 @@ def tree_map_only_(
def tree_all(
pred: Callable[[Any], bool],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> bool:
flat_args = tree_iter(tree, is_leaf=is_leaf)
return all(map(pred, flat_args))
@ -1727,7 +1726,7 @@ def tree_all(
def tree_any(
pred: Callable[[Any], bool],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> bool:
flat_args = tree_iter(tree, is_leaf=is_leaf)
return any(map(pred, flat_args))
@ -1739,7 +1738,7 @@ def tree_all_only(
/,
pred: Fn[T, bool],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> bool: ...


@ -1749,7 +1748,7 @@ def tree_all_only(
/,
pred: Fn2[T, S, bool],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> bool: ...


@ -1759,7 +1758,7 @@ def tree_all_only(
/,
pred: Fn3[T, S, U, bool],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> bool: ...


@ -1768,7 +1767,7 @@ def tree_all_only(
/,
pred: FnAny[bool],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> bool:
flat_args = tree_iter(tree, is_leaf=is_leaf)
return all(pred(x) for x in flat_args if isinstance(x, type_or_types))
@ -1780,7 +1779,7 @@ def tree_any_only(
/,
pred: Fn[T, bool],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> bool: ...


@ -1790,7 +1789,7 @@ def tree_any_only(
/,
pred: Fn2[T, S, bool],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> bool: ...


@ -1800,7 +1799,7 @@ def tree_any_only(
/,
pred: Fn3[T, S, U, bool],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> bool: ...


@ -1809,7 +1808,7 @@ def tree_any_only(
/,
pred: FnAny[bool],
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> bool:
flat_args = tree_iter(tree, is_leaf=is_leaf)
return any(pred(x) for x in flat_args if isinstance(x, type_or_types))
@ -1826,8 +1825,8 @@ def tree_any_only(
def _broadcast_to_and_flatten(
tree: PyTree,
treespec: TreeSpec,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> Optional[list[Any]]:
is_leaf: Callable[[PyTree], bool] | None = None,
) -> list[Any] | None:
if not isinstance(treespec, TreeSpec):
raise AssertionError("treespec must be a TreeSpec")

@ -1868,7 +1867,7 @@ class _TreeSpecSchema:
- children_spec: A list of children serialized specs.
"""

type: Optional[str]
type: str | None
context: DumpableContext
children_spec: list["_TreeSpecSchema"]

@ -1917,7 +1916,7 @@ def _treespec_to_json(treespec: TreeSpec) -> _TreeSpecSchema:
return _TreeSpecSchema(serialized_type_name, serialized_context, child_schemas)


def enum_object_hook(obj: dict[str, Any]) -> Union[Enum, dict[str, Any]]:
def enum_object_hook(obj: dict[str, Any]) -> Enum | dict[str, Any]:
if "__enum__" in obj:
modname, _, classname = obj["fqn"].partition(":")
mod = importlib.import_module(modname)
@ -1968,7 +1967,7 @@ def _json_to_treespec(json_schema: DumpableContext) -> TreeSpec:
_SUPPORTED_PROTOCOLS[1] = _ProtocolFn(_treespec_to_json, _json_to_treespec)


def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str:
def treespec_dumps(treespec: TreeSpec, protocol: int | None = None) -> str:
if not isinstance(treespec, TreeSpec):
raise TypeError(
f"treespec_dumps(treespec, protocol): Expected `treespec` to be instance of "
@ -2048,7 +2047,7 @@ def arg_tree_leaves(*args: PyTree, **kwargs: PyTree) -> list[Any]:

def tree_flatten_with_path(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> tuple[list[tuple[KeyPath, Any]], TreeSpec]:
"""Flattens a pytree like :func:`tree_flatten`, but also returns each leaf's key path.

@ -2072,7 +2071,7 @@ def tree_flatten_with_path(

def tree_leaves_with_path(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> list[tuple[KeyPath, Any]]:
"""Gets the leaves of a pytree like ``tree_leaves`` and returns each leaf's key path.

@ -2094,7 +2093,7 @@ def tree_leaves_with_path(
def _generate_key_paths(
key_path: KeyPath,
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> Iterable[tuple[KeyPath, Any]]:
if is_leaf and is_leaf(tree):
yield key_path, tree
@ -2124,7 +2123,7 @@ def tree_map_with_path(
func: Callable[..., Any],
tree: PyTree,
*rests: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
is_leaf: Callable[[PyTree], bool] | None = None,
) -> PyTree:
"""Like :func:`tree_map`, but the provided callable takes an additional key path argument.

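A small usage sketch of the key-path helpers whose signatures change above (the sample tree is made up, and the exact key classes printed may differ by version):

    import torch.utils._pytree as pytree

    tree = {"w": [1, 2], "b": 3}                      # made-up sample
    for path, leaf in pytree.tree_leaves_with_path(tree):
        print(path, leaf)                             # e.g. (MappingKey('w'), SequenceKey(0)) 1

    scaled = pytree.tree_map_with_path(               # func gets the key path first
        lambda path, x: x * 10 if len(path) == 2 else x, tree
    )
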
@ -8,7 +8,7 @@ import subprocess
import time
from collections.abc import Callable, Sequence
from threading import Lock
from typing import Any, Optional, TypeVar
from typing import Any, TypeVar
from typing_extensions import ParamSpec


@ -34,14 +34,14 @@ class StrobelightCLIProfilerError(Exception):
"""


def _pid_namespace_link(pid: Optional[int] = None) -> str:
def _pid_namespace_link(pid: int | None = None) -> str:
"""Returns the link to the process's namespace, example: pid:[4026531836]"""
PID_NAMESPACE_PATH = "/proc/{}/ns/pid"
pid = pid or os.getpid()
return os.readlink(PID_NAMESPACE_PATH.format(pid))


def _pid_namespace(pid: Optional[int] = None) -> int:
def _pid_namespace(pid: int | None = None) -> int:
"""Returns the process's namespace id"""
pid = pid or os.getpid()
link = _pid_namespace_link(pid)
@ -77,8 +77,8 @@ class StrobelightCLIFunctionProfiler:
run_user_name: str = "pytorch-strobelight-ondemand",
timeout_wait_for_running_sec: int = 60,
timeout_wait_for_finished_sec: int = 60,
recorded_env_variables: Optional[list[str]] = None,
sample_tags: Optional[list[str]] = None,
recorded_env_variables: list[str] | None = None,
sample_tags: list[str] | None = None,
stack_max_len: int = 127,
async_stack_max_len: int = 127,
) -> None:
@ -90,7 +90,7 @@ class StrobelightCLIFunctionProfiler:
self.timeout_wait_for_finished_sec = timeout_wait_for_finished_sec
# Results of the most recent run.
# Tracks the strobelight run id of the most recent run
self.current_run_id: Optional[int] = None
self.current_run_id: int | None = None
self.sample_tags = sample_tags

def _run_async(self) -> None:
@ -253,7 +253,7 @@ class StrobelightCLIFunctionProfiler:

def profile(
self, work_function: Callable[_P, _R], *args: _P.args, **kwargs: _P.kwargs
) -> Optional[_R]:
) -> _R | None:
self.current_run_id = None

if locked := StrobelightCLIFunctionProfiler._lock.acquire(False):
@ -295,16 +295,16 @@ class StrobelightCLIFunctionProfiler:
# @strobelight(profiler = StrobelightFunctionProfiler(stop_at_error=True,..))
# @strobelight(stop_at_error=True,...)
def strobelight(
profiler: Optional[StrobelightCLIFunctionProfiler] = None, **kwargs: Any
) -> Callable[[Callable[_P, _R]], Callable[_P, Optional[_R]]]:
profiler: StrobelightCLIFunctionProfiler | None = None, **kwargs: Any
) -> Callable[[Callable[_P, _R]], Callable[_P, _R | None]]:
if not profiler:
profiler = StrobelightCLIFunctionProfiler(**kwargs)

def strobelight_inner(
work_function: Callable[_P, _R],
) -> Callable[_P, Optional[_R]]:
) -> Callable[_P, _R | None]:
@functools.wraps(work_function)
def wrapper_function(*args: _P.args, **kwargs: _P.kwargs) -> Optional[_R]:
def wrapper_function(*args: _P.args, **kwargs: _P.kwargs) -> _R | None:
# pyrefly: ignore [bad-argument-type]
return profiler.profile(work_function, *args, **kwargs)

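The `strobelight()` wrapper above shows how the `_R | None` spelling reads in a ParamSpec-based decorator; here is a self-contained sketch of the same typing shape (illustrative only, not the actual profiler):

    import functools
    from collections.abc import Callable
    from typing import ParamSpec, TypeVar

    _P = ParamSpec("_P")
    _R = TypeVar("_R")

    def may_skip(f: Callable[_P, _R]) -> Callable[_P, _R | None]:
        # like profile(): the wrapped call can yield None when the work is skipped
        @functools.wraps(f)
        def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R | None:
            try:
                return f(*args, **kwargs)
            except RuntimeError:
                return None
        return wrapper
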
@ -4,7 +4,7 @@ import math
import operator
import sys
from collections.abc import Callable
from typing import Optional, SupportsFloat, TYPE_CHECKING, TypeVar, Union
from typing import SupportsFloat, TYPE_CHECKING, TypeVar
from typing_extensions import TypeVarTuple, Unpack

import sympy
@ -102,11 +102,11 @@ def _is_symbols_binary_summation(expr: sympy.Expr) -> bool:

def _keep_float(
f: Callable[[Unpack[_Ts]], _T],
) -> Callable[[Unpack[_Ts]], Union[_T, sympy.Float]]:
) -> Callable[[Unpack[_Ts]], _T | sympy.Float]:
@functools.wraps(f)
def inner(*args: Unpack[_Ts]) -> Union[_T, sympy.Float]:
def inner(*args: Unpack[_Ts]) -> _T | sympy.Float:
# pyrefly: ignore [bad-argument-type]
r: Union[_T, sympy.Float] = f(*args)
r: _T | sympy.Float = f(*args)
if any(isinstance(a, sympy.Float) for a in args) and not isinstance(
r, sympy.Float
):
@ -117,7 +117,7 @@ def _keep_float(
return inner


def fuzzy_eq(x: Optional[bool], y: Optional[bool]) -> Optional[bool]:
def fuzzy_eq(x: bool | None, y: bool | None) -> bool | None:
if None in (x, y):
return None
return x == y
@ -216,9 +216,7 @@ class FloorDiv(sympy.Function):
# Automatic evaluation.
# https://docs.sympy.org/latest/guides/custom-functions.html#best-practices-for-eval
@classmethod
def eval(
cls, base: sympy.Integer, divisor: sympy.Integer
) -> Union[sympy.Basic, None]:
def eval(cls, base: sympy.Integer, divisor: sympy.Integer) -> sympy.Basic | None:
# python test/test_dynamic_shapes.py -k TestDimConstraints.test_dim_constraints_solve_full
# Assert triggered by inequality solver
# assert base.is_integer, base
@ -324,7 +322,7 @@ class ModularIndexing(sympy.Function):
@classmethod
def eval(
cls, base: sympy.Integer, divisor: sympy.Integer, modulus: sympy.Integer
) -> Optional[sympy.Basic]:
) -> sympy.Basic | None:
if base == 0 or modulus == 1:
return sympy.S.Zero
if (
@ -373,7 +371,7 @@ class ModularIndexing(sympy.Function):

return None

def _eval_is_nonnegative(self) -> Optional[bool]:
def _eval_is_nonnegative(self) -> bool | None:
# pyrefly: ignore [missing-attribute]
p, q = self.args[:2]
return fuzzy_eq(p.is_nonnegative, q.is_nonnegative)  # type: ignore[attr-defined]
@ -387,23 +385,21 @@ class Where(sympy.Function):
nargs: tuple[int, ...] = (3,)
precedence: int = 35  # lower precedence than add

def _eval_is_integer(self) -> Optional[bool]:
def _eval_is_integer(self) -> bool | None:
return True if self.args[1].is_integer and self.args[2].is_integer else None  # type: ignore[attr-defined]

def _eval_is_nonnegative(self) -> Optional[bool]:
def _eval_is_nonnegative(self) -> bool | None:
return (
True
if self.args[1].is_nonnegative and self.args[2].is_nonnegative  # type: ignore[attr-defined]
else None
)

def _eval_is_positive(self) -> Optional[bool]:
def _eval_is_positive(self) -> bool | None:
return True if self.args[1].is_positive and self.args[2].is_positive else None  # type: ignore[attr-defined]

@classmethod
def eval(
cls, c: sympy.Basic, p: sympy.Basic, q: sympy.Basic
) -> Optional[sympy.Basic]:
def eval(cls, c: sympy.Basic, p: sympy.Basic, q: sympy.Basic) -> sympy.Basic | None:
if c == sympy.true:
return p
elif c == sympy.false:
@ -419,7 +415,7 @@ class PythonMod(sympy.Function):
is_integer: bool = True

@classmethod
def eval(cls, p: sympy.Expr, q: sympy.Expr) -> Optional[sympy.Expr]:
def eval(cls, p: sympy.Expr, q: sympy.Expr) -> sympy.Expr | None:
# python test/dynamo/test_export.py -k ExportTests.test_trivial_constraint
# Triggered by sympy.solvers.inequalities.reduce_inequalities
# assert p.is_integer, p
@ -465,10 +461,10 @@ class PythonMod(sympy.Function):
return None

# NB: args[1] for PythonMod
def _eval_is_nonnegative(self) -> Optional[bool]:
def _eval_is_nonnegative(self) -> bool | None:
return True if self.args[1].is_positive else None  # type: ignore[attr-defined]

def _eval_is_nonpositive(self) -> Optional[bool]:
def _eval_is_nonpositive(self) -> bool | None:
return True if self.args[1].is_negative else None  # type: ignore[attr-defined]

def _ccode(self, printer) -> str:
@ -664,7 +660,7 @@ class MinMaxBase(Expr, LatticeOp):  # type: ignore[misc]
@classmethod
def _satisfy_unique_summations_symbols(
cls, args
) -> Optional[set[sympy.core.symbol.Symbol]]:
) -> set[sympy.core.symbol.Symbol] | None:
"""
One common case in some models is building expressions of the form
max(max(max(a+b...), c+d), e+f) which is simplified to max(a+b, c+d, e+f, ...).
@ -719,8 +715,8 @@ class MinMaxBase(Expr, LatticeOp):  # type: ignore[misc]

@classmethod
def _unique_symbols(
cls, args, initial_set: Optional[set[sympy.core.symbol.Symbol]] = None
) -> Optional[set[sympy.core.symbol.Symbol]]:
cls, args, initial_set: set[sympy.core.symbol.Symbol] | None = None
) -> set[sympy.core.symbol.Symbol] | None:
"""
Return seen_symbols if all atoms in all args are all unique symbols,
else returns None. initial_set can be used to represent initial value for seen_symbols

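`fuzzy_eq` above implements three-valued logic where `None` means "unknown"; restating it standalone with a couple of worked cases:

    def fuzzy_eq(x: bool | None, y: bool | None) -> bool | None:
        if None in (x, y):
            return None        # unknown in, unknown out
        return x == y

    assert fuzzy_eq(True, True) is True
    assert fuzzy_eq(True, False) is False
    assert fuzzy_eq(None, False) is None   # can't decide equality against unknown
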
@ -10,7 +10,7 @@ of a full handler, see torch.utils._sympy.value_ranges.ValueRangeAnalysis.

import functools
import logging
from typing import Any, Union
from typing import Any

import sympy
from sympy.logic.boolalg import Boolean as SympyBoolean, BooleanAtom
@ -184,7 +184,7 @@ _nil = object()
def sympy_interp(
analysis,
env: dict[sympy.Symbol, Any],
expr: Union[sympy.Expr, SympyBoolean],
expr: sympy.Expr | SympyBoolean,
*,
index_dtype=torch.int64,
missing_handler=None,

@ -1,5 +1,4 @@
import sys
from typing import Optional

import sympy
from sympy.printing.precedence import PRECEDENCE, precedence
@ -23,7 +22,7 @@ class ExprPrinter(StrPrinter):
def _print_Not(self, expr: sympy.Expr) -> str:
return f"not ({self._print(expr.args[0])})"

def _print_Add(self, expr: sympy.Expr, order: Optional[str] = None) -> str:
def _print_Add(self, expr: sympy.Expr, order: str | None = None) -> str:
return self.stringify(expr.args, " + ", precedence(expr))

def _print_Relational(self, expr: sympy.Expr) -> str:
@ -310,7 +309,7 @@ class PythonPrinter(ExprPrinter):
# Convert Piecewise(expr_cond_pairs) to nested ternary expressions
# Piecewise((e1, c1), (e2, c2), ..., (eN, cN))
# becomes: e1 if c1 else (e2 if c2 else (... else eN))
result: Optional[str] = None
result: str | None = None
for expr_i, cond_i in reversed(expr.args):
expr_str = self._print(expr_i)
if cond_i == True:  # noqa: E712
@ -349,7 +348,7 @@ class CppPrinter(ExprPrinter):
# Convert Piecewise(expr_cond_pairs) to nested ternary operators
# Piecewise((e1, c1), (e2, c2), ..., (eN, cN))
# becomes: c1 ? e1 : (c2 ? e2 : (... : eN))
result: Optional[str] = None
result: str | None = None
for expr_i, cond_i in reversed(expr.args):
expr_str = self.parenthesize(expr_i, PRECEDENCE["Atom"] - 0.5)
if cond_i == True:  # noqa: E712

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import math
import operator
from typing import NoReturn, Union
from typing import NoReturn

import sympy

@ -359,7 +359,7 @@ class TensorReferenceAnalysis:
# function isn't traced correctly. Here for completeness.
@staticmethod
def constant(c, dtype):
d: Union[int, float, bool]
d: int | float | bool
if dtype is torch.int64:
d = int(c)
elif dtype is torch.double:

@ -1,5 +1,4 @@
import logging
from typing import Optional

import sympy

@ -20,7 +19,7 @@ _MIRROR_REL_OP: dict[type[sympy.Basic], type[sympy.Rel]] = {
INEQUALITY_TYPES = (sympy.Gt, sympy.Ge, sympy.Lt, sympy.Le)


def mirror_rel_op(type: type) -> Optional[type[sympy.Rel]]:
def mirror_rel_op(type: type) -> type[sympy.Rel] | None:
return _MIRROR_REL_OP.get(type)


@ -43,7 +42,7 @@ def try_solve(
thing: sympy.Basic,
trials: int = 5,
floordiv_inequality: bool = True,
) -> Optional[tuple[sympy.Rel, sympy.Expr]]:
) -> tuple[sympy.Rel, sympy.Expr] | None:
mirror = mirror_rel_op(type(expr))

# Ignore unsupported expressions:

@ -14,7 +14,6 @@ in this file and seeing what breaks.

from collections.abc import Iterable
from enum import auto, Enum
from typing import Union

import sympy

@ -88,7 +87,7 @@ def make_symbol(prefix: SymT, idx: int, **kwargs) -> sympy.Symbol:

# This type is a little wider than it should be, because free_symbols says
# that it contains Basic, rather than Symbol
def symbol_is_type(sym: sympy.Basic, prefix: Union[SymT, Iterable[SymT]]) -> bool:
def symbol_is_type(sym: sympy.Basic, prefix: SymT | Iterable[SymT]) -> bool:
if not isinstance(sym, sympy.Symbol):
raise AssertionError("expected sympy.Symbol")
name_str = sym.name.lower()  # Match capitalized names like XBLOCK, RBLOCK
@ -98,5 +97,5 @@ def symbol_is_type(sym: sympy.Basic, prefix: Union[SymT, Iterable[SymT]]) -> boo
return name_str.startswith(tuple(prefix_str[p] for p in prefix))


def free_symbol_is_type(e: sympy.Expr, prefix: Union[SymT, Iterable[SymT]]) -> bool:
def free_symbol_is_type(e: sympy.Expr, prefix: SymT | Iterable[SymT]) -> bool:
return any(symbol_is_type(v, prefix) for v in e.free_symbols)

@ -10,7 +10,6 @@ import operator
from collections.abc import Callable
from typing import (
Generic,
Optional,
overload,
SupportsFloat,
TYPE_CHECKING,
@ -325,16 +324,16 @@ class ValueRanges(Generic[_T]):
@overload
@staticmethod
# work around the fact that bool and int overlap
def wrap(arg: Union[ExprIn, ExprVR]) -> ExprVR:  # type: ignore[overload-overlap]
def wrap(arg: ExprIn | ExprVR) -> ExprVR:  # type: ignore[overload-overlap]
...

@overload
@staticmethod
def wrap(arg: Union[BoolIn, BoolVR]) -> BoolVR:  # type: ignore[misc]
def wrap(arg: BoolIn | BoolVR) -> BoolVR:  # type: ignore[misc]
...

@staticmethod
def wrap(arg: Union[AllIn, AllVR]) -> AllVR:
def wrap(arg: AllIn | AllVR) -> AllVR:
if isinstance(arg, ValueRanges):
return arg
if isinstance(arg, float) and math.isnan(arg):
@ -343,29 +342,29 @@ class ValueRanges(Generic[_T]):
return ValueRanges(arg, arg)  # type: ignore[arg-type]

@staticmethod
def increasing_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR:
def increasing_map(x: ExprIn | ExprVR, fn: ExprFn) -> ExprVR:
"""Increasing: x <= y => f(x) <= f(y)."""
x = ValueRanges.wrap(x)
return ValueRanges(fn(x.lower), fn(x.upper))

@overload
@staticmethod
def decreasing_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR: ...
def decreasing_map(x: ExprIn | ExprVR, fn: ExprFn) -> ExprVR: ...

@overload
@staticmethod
def decreasing_map(x: Union[BoolIn, BoolVR], fn: BoolFn) -> BoolVR:  # type: ignore[misc]
def decreasing_map(x: BoolIn | BoolVR, fn: BoolFn) -> BoolVR:  # type: ignore[misc]
...

@staticmethod
def decreasing_map(x: Union[AllIn, AllVR], fn: AllFn) -> AllVR:
def decreasing_map(x: AllIn | AllVR, fn: AllFn) -> AllVR:
"""Decreasing: x <= y => f(x) >= f(y)."""
x = ValueRanges.wrap(x)
# consistently either Expr or Bool, but we don't know it here
return ValueRanges(fn(x.upper), fn(x.lower))  # type: ignore[arg-type]

@staticmethod
def monotone_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR:
def monotone_map(x: ExprIn | ExprVR, fn: ExprFn) -> ExprVR:
"""It's increasing or decreasing."""
x = ValueRanges.wrap(x)
l = fn(x.lower)
@ -373,7 +372,7 @@ class ValueRanges(Generic[_T]):
return ValueRanges(min(l, u), max(l, u))

@staticmethod
def convex_min_zero_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR:
def convex_min_zero_map(x: ExprIn | ExprVR, fn: ExprFn) -> ExprVR:
"""Fn is convex and has a minimum at 0."""
x = ValueRanges.wrap(x)
if 0 in x:
@ -387,23 +386,23 @@ class ValueRanges(Generic[_T]):
@overload
@staticmethod
def coordinatewise_increasing_map(
x: Union[ExprIn, ExprVR],
y: Union[ExprIn, ExprVR],
x: ExprIn | ExprVR,
y: ExprIn | ExprVR,
fn: ExprFn2,
) -> ExprVR: ...

@overload
@staticmethod
def coordinatewise_increasing_map(  # type: ignore[misc]
x: Union[BoolIn, BoolVR],
y: Union[BoolIn, BoolVR],
x: BoolIn | BoolVR,
y: BoolIn | BoolVR,
fn: BoolFn2,
) -> BoolVR: ...

@staticmethod
def coordinatewise_increasing_map(
x: Union[AllIn, AllVR],
y: Union[AllIn, AllVR],
x: AllIn | AllVR,
y: AllIn | AllVR,
fn: AllFn2,
) -> AllVR:
"""
@ -1037,7 +1036,7 @@ class SymPyValueRangeAnalysis:


def bound_sympy(
expr: sympy.Expr, ranges: Optional[dict[sympy.Symbol, ValueRanges]] = None
expr: sympy.Expr, ranges: dict[sympy.Symbol, ValueRanges] | None = None
) -> ValueRanges:
log.debug(
"bound_sympy(%s)%s",

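The monotone helpers above are plain interval arithmetic; a standalone restatement of the endpoint rule they encode (no sympy, purely illustrative):

    def increasing_map(lo, hi, fn):
        # x <= y implies fn(x) <= fn(y): the image of [lo, hi] is [fn(lo), fn(hi)]
        return fn(lo), fn(hi)

    def decreasing_map(lo, hi, fn):
        # x <= y implies fn(x) >= fn(y): the endpoints swap
        return fn(hi), fn(lo)

    assert increasing_map(2, 5, lambda x: x + 1) == (3, 6)
    assert decreasing_map(2, 5, lambda x: -x) == (-5, -2)
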
@ -1,5 +1,5 @@
from collections.abc import Callable
from typing import Generic, Optional, TypeVar
from typing import Generic, TypeVar


R = TypeVar("R")
@ -12,8 +12,8 @@ class Thunk(Generic[R]):
function once it is forced.
"""

f: Optional[Callable[[], R]]
r: Optional[R]
f: Callable[[], R] | None
r: R | None

__slots__ = ["f", "r"]

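A minimal sketch of the lazy cell those two fields describe; only the `f`/`r` attributes come from the diff, while the constructor and `force` method shapes are assumptions:

    from collections.abc import Callable
    from typing import Generic, TypeVar

    R = TypeVar("R")

    class LazyThunk(Generic[R]):
        f: Callable[[], R] | None
        r: R | None
        __slots__ = ["f", "r"]

        def __init__(self, f: Callable[[], R]) -> None:
            self.f, self.r = f, None

        def force(self) -> R:
            if self.f is not None:
                self.r = self.f()
                self.f = None          # drop the closure once computed
            return self.r              # type: ignore[return-value]

    t = LazyThunk(lambda: 40 + 2)
    assert t.force() == 42 and t.force() == 42   # computed once, then cached
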
@ -5,7 +5,6 @@ import os.path
import tempfile
import traceback
from types import TracebackType
from typing import Optional


# This file contains utilities for ensuring dynamically compile()'d
@ -234,7 +233,7 @@ class CapturedTraceback:
import torch._C._profiler

# Directly populate tracebacks that already have cached summaries
rs: list[Optional[list[str]]] = []
rs: list[list[str] | None] = []
delayed_idxs = []
for i, tb in enumerate(tbs):
if tb.tb is None:

@ -1,6 +1,6 @@
"""Miscellaneous utilities to aid with typing."""

from typing import Optional, TypeVar
from typing import TypeVar


# Helper to turn Optional[T] into T when we know None either isn't
@ -8,7 +8,7 @@ from typing import Optional, TypeVar
T = TypeVar("T")


def not_none(obj: Optional[T]) -> T:
def not_none(obj: T | None) -> T:
if obj is None:
raise TypeError("Invariant encountered: value was None when it should not be")
return obj

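`not_none` is the runtime counterpart of a static type narrow; a quick usage sketch (the import path matches the file apparently being edited here, torch/utils/_typing_utils.py, which is inferred from context):

    from torch.utils._typing_utils import not_none

    cfg: dict[str, int] | None = {"workers": 4}
    n = not_none(cfg)["workers"]    # the checker now sees dict[str, int], not ... | None
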
@ -1,5 +1,4 @@
# mypy: allow-untyped-defs
from typing import Optional, Union

import torch
from torch._C import _get_privateuse1_backend_name, _rename_privateuse1_backend
@ -90,7 +89,7 @@ def _check_register_once(module, attr) -> None:


def _normalization_device(
custom_backend_name: str, device: Optional[Union[int, str, torch.device]] = None
custom_backend_name: str, device: int | str | torch.device | None = None
) -> int:
def _get_current_device_index():
_get_device_index = "current_device"
@ -137,7 +136,7 @@ def _generate_tensor_methods_for_privateuse1_backend(custom_backend_name: str) -

def wrap_tensor_to(
self: torch.Tensor,
device: Optional[Union[int, torch.device]] = None,
device: int | torch.device | None = None,
non_blocking=False,
**kwargs,
) -> torch.Tensor:
@ -188,7 +187,7 @@ def _generate_module_methods_for_privateuse1_backend(custom_backend_name: str) -

def wrap_module_to(
self: torch.nn.modules.module.T,
device: Optional[Union[int, torch.device]] = None,
device: int | torch.device | None = None,
) -> torch.nn.modules.module.T:
r"""Move all model parameters and buffers to the custom device.

@ -268,7 +267,7 @@ def _generate_packed_sequence_methods_for_privateuse1_backend(


def _generate_storage_methods_for_privateuse1_backend(
custom_backend_name: str, unsupported_dtype: Optional[list[torch.dtype]] = None
custom_backend_name: str, unsupported_dtype: list[torch.dtype] | None = None
) -> None:
# Attribute is registered in the _StorageBase class
# and UntypedStorage obtains through inheritance.
@ -355,7 +354,7 @@ def generate_methods_for_privateuse1_backend(
for_module: bool = True,
for_packed_sequence: bool = True,
for_storage: bool = False,
unsupported_dtype: Optional[list[torch.dtype]] = None,
unsupported_dtype: list[torch.dtype] | None = None,
) -> None:
r"""
Automatically generate attributes and methods for the custom backend after rename privateuse1 backend.

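For orientation, the public entry points around these helpers are used roughly like this (a sketch: the backend name "foo" is made up, and the calls only take full effect once a real PrivateUse1 C++ backend is registered):

    import torch

    torch.utils.rename_privateuse1_backend("foo")
    torch.utils.generate_methods_for_privateuse1_backend(for_storage=True)
    # tensors then grow backend-named conveniences such as .is_foo and .foo()
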
@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING

import numpy as np
import torch
@ -20,7 +20,7 @@ _POW_TWO_SIZES = tuple(2 ** i for i in range(
))

class UnaryOpSparseFuzzer(Fuzzer):
def __init__(self, seed: Optional[int], dtype: _dtype | None = None, cuda: bool = False) -> None:
def __init__(self, seed: int | None, dtype: _dtype | None = None, cuda: bool = False) -> None:
if dtype is None:
dtype = getattr(torch, 'float32', None)
super().__init__(

@ -8,7 +8,7 @@ import shutil
import tempfile
import textwrap
import time
from typing import cast, Any, Optional
from typing import cast, Any
from collections.abc import Iterable, Iterator
import uuid

@ -34,10 +34,10 @@ class TaskSpec:
stmt: str
setup: str
global_setup: str = ""
label: Optional[str] = None
sub_label: Optional[str] = None
description: Optional[str] = None
env: Optional[str] = None
label: str | None = None
sub_label: str | None = None
description: str | None = None
env: str | None = None
num_threads: int = 1

@property
@ -82,7 +82,7 @@ class Measurement:
number_per_run: int
raw_times: list[float]
task_spec: TaskSpec
metadata: Optional[dict[Any, Any]] = None  # Reserved for user payloads.
metadata: dict[Any, Any] | None = None  # Reserved for user payloads.

def __post_init__(self) -> None:
self._sorted_times: tuple[float, ...] = ()
@ -297,7 +297,7 @@ def set_torch_threads(n: int) -> Iterator[None]:
torch.set_num_threads(prior_num_threads)


def _make_temp_dir(prefix: Optional[str] = None, gc_dev_shm: bool = False) -> str:
def _make_temp_dir(prefix: str | None = None, gc_dev_shm: bool = False) -> str:
"""Create a temporary directory. The caller is responsible for cleanup.

This function is conceptually similar to `tempfile.mkdtemp`, but with

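`TaskSpec` and `Measurement` above are what the public benchmark `Timer` produces; a rough usage sketch with the documented API:

    import torch.utils.benchmark as benchmark

    t = benchmark.Timer(
        stmt="x * x",
        setup="x = torch.ones(1024)",   # torch is available in the Timer namespace
        label="mul",                    # populates TaskSpec.label above
        num_threads=1,
    )
    m = t.timeit(100)                   # returns a Measurement
    print(m.median)
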
Some files were not shown because too many files have changed in this diff.