Compare commits

..

5 Commits

Author SHA1 Message Date
f3d17e14ec Update base for Update on "[DTensor] Support convert StridedShard to shard order and vice versa"
We plan to use `StridedShard` to express `shard_order`. This PR adds functions to convert between `StridedShard` and `shard_order`.

I moved some test-related functions into torch/testing/_internal/common_utils.py. For review, only **_dtensor_spec.py** and **test_utils.py** need attention in this PR.

### How to convert shard order to StridedShard:
Consider the example:
- placements = $[x_0, x_1, x_2, x_3, x_4]$, where all $x_i$ shard the same tensor dim.

Let's see how the shard order impacts the split_factor (sf). We loop from right to left over the placements and construct each split_factor under the assumed shard order. Starting from $x_4$: this is a normal shard.

Then $x_3$. There are two possibilities: $x_3$'s order is before $x_4$'s, in which case $x_3$'s sf = 1 because $x_3$ already precedes $x_4$ in the placements; otherwise $x_3$'s order is after $x_4$'s, and $x_3$'s sf is the mesh dim size of $x_4$, denoted $T(x_4)$:
<img width="820" height="431" alt="image" src="https://github.com/user-attachments/assets/f53b4b24-2523-42cc-ad6f-41f3c280db70" />


We can apply the same method to decide the split factor for $x_2$, $x_1$, and so on.
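
A minimal sketch of this right-to-left rule, assuming the simplified setting above where every placement shards the same tensor dim (the function and variable names here are illustrative, not the actual API added in _dtensor_spec.py):

```python
# Illustrative sketch only, not the PR's implementation.
# mesh_sizes[i] = T(x_i), the mesh dim size behind placements[i].
# shard_order[i] = position of x_i in the shard order (0 = shards first).
def split_factors(mesh_sizes, shard_order):
    n = len(mesh_sizes)
    sfs = [1] * n
    # Loop right to left; the rightmost placement x_{n-1} is a normal shard.
    for k in range(n - 2, -1, -1):
        sf = 1
        for j in range(k + 1, n):
            # A placement to the right of x_k that shards *earlier* in the
            # order has already split the tensor dim, so it contributes T(x_j).
            if shard_order[j] < shard_order[k]:
                sf *= mesh_sizes[j]
        sfs[k] = sf
    return sfs

# Two mesh dims of sizes [2, 4] sharding the same tensor dim:
print(split_factors([2, 4], shard_order=[0, 1]))  # [1, 1] -> Shard, Shard
print(split_factors([2, 4], shard_order=[1, 0]))  # [4, 1] -> _StridedShard(split_factor=4), Shard
```

In this sketch, a split factor of 1 corresponds to a plain `Shard`, and anything larger to a `_StridedShard` with that `split_factor`.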

### How to convert StridedShard to shard order:
This follows the same method as above. We enumerate all possible shard orders and compare the split_factor each would produce against the actual split_factor. If none matches, the StridedShard cannot be converted to a shard order.
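
A rough sketch of the matching idea, brute-forcing over candidate shard orders and reusing the hypothetical `split_factors` helper above (the real implementation may walk the possible paths more efficiently):

```python
from itertools import permutations

def infer_shard_order(mesh_sizes, observed_sfs):
    """Return shard-order positions that reproduce observed_sfs, or None."""
    n = len(mesh_sizes)
    for perm in permutations(range(n)):
        if split_factors(mesh_sizes, list(perm)) == list(observed_sfs):
            return perm
    return None  # no order reproduces these split factors -> not convertible

print(infer_shard_order([2, 4], [4, 1]))  # (1, 0): x_1 shards before x_0
print(infer_shard_order([2, 4], [3, 1]))  # None: not convertible to a shard order
```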

---



cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-11-09 23:17:16 -08:00
d5bc6fdce6 Update base for Update on "[DTensor] Support convert StridedShard to shard order and vice versa"
We plan to use `StridedShard` to express `shard_order`. This PR adds functions to convert between `StridedShard` and `shard_order`.

I moved some test-related functions into torch/testing/_internal/common_utils.py. For review, only **_dtensor_spec.py** and **test_utils.py** need attention in this PR.

### How to convert shard order to StridedShard:
Consider the example:
- placements = $[x_0, x_1, x_2, x_3, x_4]$, where all $x_i$ shard the same tensor dim.

Let's see how the shard order impacts the split_factor (sf). We loop from right to left over the placements and construct each split_factor under the assumed shard order. Starting from $x_4$: this is a normal shard.

Then $x_3$. There are two possibilities: $x_3$'s order is before $x_4$'s, in which case $x_3$'s sf = 1 because $x_3$ already precedes $x_4$ in the placements; otherwise $x_3$'s order is after $x_4$'s, and $x_3$'s sf is the mesh dim size of $x_4$, denoted $T(x_4)$:
<img width="820" height="431" alt="image" src="https://github.com/user-attachments/assets/f53b4b24-2523-42cc-ad6f-41f3c280db70" />


We can apply the same method to decide the split factor for $x_2$, $x_1$, and so on.

### How to convert StridedShard to shard order:
This follows the same method as above. We enumerate all possible shard orders and compare the split_factor each would produce against the actual split_factor. If none matches, the StridedShard cannot be converted to a shard order.

---



cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-11-09 16:53:40 -08:00
b80c03d695 Update base for Update on "[DTensor] Support convert StridedShard to shard order and vice versa"
We plan to use `StridedShard` to express `shard_order`. This PR adds functions to convert between `StridedShard` and `shard_order`.

I moved some test-related functions into torch/testing/_internal/common_utils.py. For review, only **_dtensor_spec.py** and **test_utils.py** need attention in this PR.

### How to convert shard order to StridedShard:
Consider the example:
- placements = $[x_0, x_1, x_2, x_3, x_4]$, where all $x_i$ shard the same tensor dim.

Let's see how the shard order impacts the split_factor (sf). We loop from right to left over the placements and construct each split_factor under the assumed shard order. Starting from $x_4$: this is a normal shard.

Then $x_3$. There are two possibilities: $x_3$'s order is before $x_4$'s, in which case $x_3$'s sf = 1 because $x_3$ already precedes $x_4$ in the placements; otherwise $x_3$'s order is after $x_4$'s, and $x_3$'s sf is the mesh dim size of $x_4$, denoted $T(x_4)$:
<img width="820" height="431" alt="image" src="https://github.com/user-attachments/assets/f53b4b24-2523-42cc-ad6f-41f3c280db70" />


We can apply the same method to decide the split factor for $x_2$, $x_1$, and so on.

### How to convert StridedShard to shard order:
This follows the same method as above. We enumerate all possible shard orders and compare the split_factor each would produce against the actual split_factor. If none matches, the StridedShard cannot be converted to a shard order.

---



cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-11-07 16:10:36 -08:00
f82985890f Update base for Update on "[DTensor] Support convert StridedShard to shard order and vice versa"
We plan to use `StridedShard` to express `shard_order`. This PR adds functions to convert between `StridedShard` and `shard_order`.

I moved some test-related functions into torch/testing/_internal/common_utils.py. For review, only **_dtensor_spec.py** and **test_utils.py** need attention in this PR.

### How to convert shard order to StridedShard:
Consider the example:
- placements = $[x_0, x_1, x_2, x_3, x_4]$, where all $x_i$ shard the same tensor dim.

Let's see how the shard order impacts the split_factor (sf). We loop from right to left over the placements and construct each split_factor under the assumed shard order. Starting from $x_4$: this is a normal shard.

Then $x_3$. There are two possibilities: $x_3$'s order is before $x_4$'s, in which case $x_3$'s sf = 1 because $x_3$ already precedes $x_4$ in the placements; otherwise $x_3$'s order is after $x_4$'s, and $x_3$'s sf is the mesh dim size of $x_4$, denoted $T(x_4)$:
<img width="820" height="431" alt="image" src="https://github.com/user-attachments/assets/f53b4b24-2523-42cc-ad6f-41f3c280db70" />


We can apply the same method to decide the split factor for $x_2$, $x_1$, and so on.

### How to convert StridedShard to shard order:
This follows the same method as above. We enumerate all possible shard orders and compare the split_factor each would produce against the actual split_factor. If none matches, the StridedShard cannot be converted to a shard order.

---



cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-11-07 09:22:58 -08:00
4ac82f37b9 Update base for Update on "[DTensor] Support convert StridedShard to shard order and vice versa"
We plan to use `StridedShard` to express `shard_order`. This PR adds functions to convert between `StridedShard` and `shard_order`.

I moved some test-related functions into torch/testing/_internal/common_utils.py. For review, only **_dtensor_spec.py** and **test_utils.py** need attention in this PR.

### How to convert shard order to StridedShard:
Consider the example:
- placements = $[x_0, x_1, x_2, x_3, x_4]$, where all $x_i$ shard the same tensor dim.

Let's see how the shard order impacts the split_factor (sf). We loop from right to left over the placements and construct each split_factor under the assumed shard order. Starting from $x_4$: this is a normal shard.

Then $x_3$. There are two possibilities: $x_3$'s order is before $x_4$'s, in which case $x_3$'s sf = 1 because $x_3$ already precedes $x_4$ in the placements; otherwise $x_3$'s order is after $x_4$'s, and $x_3$'s sf is the mesh dim size of $x_4$, denoted $T(x_4)$:
<img width="820" height="431" alt="image" src="https://github.com/user-attachments/assets/f53b4b24-2523-42cc-ad6f-41f3c280db70" />


We can apply the same method to decide the split factor for $x_2$, $x_1$, and so on.

### How to convert StridedShard to shard order:
This follows the same method as above. We enumerate all possible shard orders and compare the split_factor each would produce against the actual split_factor. If none matches, the StridedShard cannot be converted to a shard order.

---



cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-11-06 23:09:28 -08:00
51 changed files with 1016 additions and 1956 deletions

View File

@ -260,8 +260,8 @@ case "$tag" in
HALIDE=yes
TRITON=yes
;;
pytorch-linux-jammy-cuda12.8-py3.12-pallas)
CUDA_VERSION=12.8.1
pytorch-linux-jammy-cuda13.0-py3.12-pallas)
CUDA_VERSION=13.0.0
ANACONDA_PYTHON_VERSION=3.12
GCC_VERSION=11
PALLAS=yes

View File

@ -168,16 +168,14 @@ if [[ "$BUILD_ENVIRONMENT" == *xpu* ]]; then
# shellcheck disable=SC1091
source /opt/intel/oneapi/compiler/latest/env/vars.sh
# shellcheck disable=SC1091
source /opt/intel/oneapi/umf/latest/env/vars.sh
# shellcheck disable=SC1091
source /opt/intel/oneapi/ccl/latest/env/vars.sh
# shellcheck disable=SC1091
source /opt/intel/oneapi/mpi/latest/env/vars.sh
# shellcheck disable=SC1091
source /opt/intel/oneapi/pti/latest/env/vars.sh
# Enable XCCL build
export USE_XCCL=1
export USE_MPI=0
# XPU kineto feature dependencies are not fully ready, disable kineto build as temp WA
export USE_KINETO=0
export TORCH_XPU_ARCH_LIST=pvc
fi

View File

@ -208,8 +208,6 @@ if [[ "$BUILD_ENVIRONMENT" == *xpu* ]]; then
source /opt/intel/oneapi/ccl/latest/env/vars.sh
# shellcheck disable=SC1091
source /opt/intel/oneapi/mpi/latest/env/vars.sh
# shellcheck disable=SC1091
source /opt/intel/oneapi/pti/latest/env/vars.sh
# Check XPU status before testing
timeout 30 xpu-smi discovery || true
fi
@ -339,7 +337,7 @@ test_python() {
test_python_smoke() {
# Smoke tests for H100/B200
time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune inductor/test_cutedsl_grouped_mm $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
assert_git_not_dirty
}

View File

@ -1 +1 @@
e4d25697f9dc5eedaf8f0a5bf085c62c5455a53a
c8b09f5f77d6bf6fb7ed7a9aa83e5d8156b3a5e9

View File

@ -67,7 +67,7 @@ jobs:
pytorch-linux-jammy-py3.10-gcc11,
pytorch-linux-jammy-py3-gcc11-inductor-benchmarks,
pytorch-linux-jammy-py3.12-halide,
pytorch-linux-jammy-cuda12.8-py3.12-pallas,
pytorch-linux-jammy-cuda13.0-py3.12-pallas,
pytorch-linux-jammy-xpu-n-1-py3,
pytorch-linux-noble-xpu-n-py3,
pytorch-linux-noble-xpu-n-py3-inductor-benchmarks,

View File

@ -86,8 +86,8 @@ jobs:
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
build-environment: linux-jammy-cuda12.8-py3.12-gcc11
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-py3.12-pallas
build-environment: linux-jammy-py3.12-gcc11
docker-image-name: ci-image:pytorch-linux-jammy-cuda13.0-py3.12-pallas
cuda-arch-list: '8.9'
runner: linux.8xlarge.memory
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"

1
.gitignore vendored
View File

@ -127,7 +127,6 @@ torch/test/
torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h
torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h
torch/version.py
torch/_inductor/kernel/vendored_templates/*
minifier_launcher.py
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_fwd_d*
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd_d*

View File

@ -94,11 +94,6 @@ TORCH_API inline void resetPeakStats(c10::DeviceIndex device_index) {
at::getDeviceAllocator(device_type)->resetPeakStats(device_index);
}
TORCH_API inline std::pair<size_t, size_t> getMemoryInfo(
c10::DeviceIndex device_index) {
const auto device_type = getAccelerator(true).value();
return at::getDeviceAllocator(device_type)->getMemoryInfo(device_index);
}
} // namespace at::accelerator
namespace at {

View File

@ -157,8 +157,6 @@ constexpr DispatchKeySet kKeysToPropagateToWrapper({
DispatchKey::Negative,
DispatchKey::Conjugate,
DispatchKey::XLA,
DispatchKey::XPU,
DispatchKey::HPU,
DispatchKey::CUDA,
DispatchKey::CPU,
DispatchKey::PrivateUse1,

View File

@ -4292,7 +4292,6 @@
dispatch:
SparseCPU: sparse_sparse_matmul_cpu
SparseCUDA: sparse_sparse_matmul_cuda
SparseMPS: sparse_sparse_matmul_mps
autogen: _sparse_sparse_matmul.out
- func: mode(Tensor self, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices)
@ -9833,7 +9832,7 @@
structured_delegate: erfinv.out
variants: method, function
dispatch:
SparseCPU, SparseCUDA, SparseMPS: erfinv_sparse
SparseCPU, SparseCUDA: erfinv_sparse
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: erfinv_sparse_csr
tags: pointwise
@ -9842,7 +9841,7 @@
structured_delegate: erfinv.out
variants: method
dispatch:
SparseCPU, SparseCUDA, SparseMPS: erfinv_sparse_
SparseCPU, SparseCUDA: erfinv_sparse_
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: erfinv_sparse_csr_
tags: pointwise
@ -9852,7 +9851,7 @@
structured_inherits: TensorIteratorBase
dispatch:
CPU, CUDA, MPS: erfinv_out
SparseCPU, SparseCUDA, SparseMPS: erfinv_sparse_out
SparseCPU, SparseCUDA: erfinv_sparse_out
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: erfinv_sparse_csr_out
tags: pointwise

View File

@ -10,10 +10,6 @@
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_coalesce_native.h>
#include <ATen/ops/repeat_interleave_native.h>
#include <ATen/ops/cumsum.h>
#include <ATen/ops/_sparse_sparse_matmul_native.h>
#include <ATen/ops/_sparse_coo_tensor_unsafe.h>
#include <ATen/ops/_sparse_coo_tensor_unsafe_native.h>
#include <ATen/ops/cat.h>
#include <ATen/ops/add_native.h>
@ -892,114 +888,5 @@ static void sparse_mask_intersection_out_mps_kernel(
/*coalesce_mask=*/false);
}
Tensor sparse_sparse_matmul_mps(const Tensor& mat1_, const Tensor& mat2_) {
TORCH_CHECK(mat1_.is_sparse() && mat2_.is_sparse(),
"sparse_sparse_matmul_mps: both inputs must be sparse COO tensors");
TORCH_CHECK(mat1_.is_mps() && mat2_.is_mps(),
"sparse_sparse_matmul_mps: both inputs must be on MPS device");
TORCH_CHECK(mat1_.dim() == 2 && mat2_.dim() == 2,
"sparse_sparse_matmul_mps: both inputs must be 2D matrices");
TORCH_CHECK(mat1_.dense_dim() == 0 && mat2_.dense_dim() == 0,
"sparse_sparse_matmul_mps: only scalar values supported (dense_dim == 0)");
TORCH_CHECK(mat1_.size(1) == mat2_.size(0),
"mat1 and mat2 shapes cannot be multiplied (", mat1_.size(0), "x", mat1_.size(1), " and ", mat2_.size(0), "x", mat2_.size(1), ")");
TORCH_CHECK(mat1_.scalar_type() == mat2_.scalar_type(),
"sparse_sparse_matmul_mps: mat1 dtype ", mat1_.scalar_type(),
" does not match mat2 dtype ", mat2_.scalar_type());
const auto device = mat1_.device();
auto A = mat1_.coalesce();
auto B = mat2_.coalesce();
const auto I = A.size(0);
const auto K = A.size(1);
const auto N = B.size(1);
const auto nnzA = A._nnz();
const auto nnzB = B._nnz();
// Early empty result, return an empty, coalesced tensor
if (I == 0 || N == 0 || K == 0 || nnzA == 0 || nnzB == 0) {
auto empty_idx = at::empty({2, 0}, at::device(device).dtype(at::kLong));
auto empty_val = at::empty({0}, at::device(device).dtype(mat1_.scalar_type()));
auto out = _sparse_coo_tensor_unsafe(empty_idx, empty_val, {I, N}, mat1_.options());
out._coalesced_(true);
return out;
}
const auto computeDtype = at::result_type(mat1_, mat2_);
auto A_idx = A._indices().contiguous();
auto A_val = A._values().to(computeDtype).contiguous();
auto A_i = A_idx.select(0, 0).contiguous();
auto A_k = A_idx.select(0, 1).contiguous();
auto B_idx = B._indices().contiguous();
auto B_val = B._values().to(computeDtype).contiguous();
auto B_k = B_idx.select(0, 0).contiguous();
auto B_j = B_idx.select(0, 1).contiguous();
// csr-style row pointers for B by k (the shared dimension)
Tensor row_ptr_B;
{
auto batch_ptr = at::tensor({0LL, nnzB}, at::device(device).dtype(at::kLong));
row_ptr_B = at::empty({K + 1}, at::device(device).dtype(at::kLong));
build_row_ptr_per_batch_mps(B_k, batch_ptr, /*B=*/1, /*I=*/K, row_ptr_B);
}
auto row_ptr_B_lo = row_ptr_B.narrow(0, 0, K);
auto row_ptr_B_hi = row_ptr_B.narrow(0, 1, K);
auto deg_B = row_ptr_B_hi.sub(row_ptr_B_lo);
auto counts = deg_B.index_select(0, A_k);
const int64_t P = counts.sum().item<int64_t>();
if (P == 0) {
auto empty_idx = at::empty({2, 0}, at::device(device).dtype(at::kLong));
auto empty_val = at::empty({0}, at::device(device).dtype(mat1_.scalar_type()));
auto out = _sparse_coo_tensor_unsafe(empty_idx, empty_val, {I, N}, mat1_.options());
out._coalesced_(true);
return out;
}
auto group_ids = repeat_interleave_mps(counts);
// exclusive cumsum of counts
auto offsets = cumsum(counts, /*dim=*/0).sub(counts);
auto offsets_gather = offsets.index_select(0, group_ids);
auto within = at::arange(P, at::device(device).dtype(at::kLong)).sub(offsets_gather);
// Map each output element to its source B row and position
auto k_per_out = A_k.index_select(0, group_ids);
auto start_in_B = row_ptr_B.index_select(0, k_per_out);
auto seg_index = start_in_B.add(within);
// Assemble candidate coo pairs and values
auto i_out = A_i.index_select(0, group_ids).contiguous();
auto j_out = B_j.index_select(0, seg_index).contiguous();
auto vA_out = A_val.index_select(0, group_ids).contiguous();
auto vB_out = B_val.index_select(0, seg_index).contiguous();
auto v_out = vA_out.mul(vB_out);
// build (2, P) indices
auto out_indices = at::empty({2, P}, at::device(device).dtype(at::kLong)).contiguous();
out_indices.select(0, 0).copy_(i_out);
out_indices.select(0, 1).copy_(j_out);
auto result = _sparse_coo_tensor_unsafe(
out_indices, v_out, {I, N}, mat1_.options().dtype(computeDtype));
result = result.coalesce();
if (result.scalar_type() != mat1_.scalar_type()) {
auto cast_vals = result._values().to(mat1_.scalar_type());
auto out = _sparse_coo_tensor_unsafe(result._indices(), cast_vals, {I, N}, mat1_.options());
out._coalesced_(true);
return out;
}
return result;
}
REGISTER_MPS_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_mps_kernel);
} // namespace at::native

View File

@ -96,10 +96,6 @@ struct C10_API DeviceAllocator : public c10::Allocator {
// Resets peak memory usage statistics for the specified device
virtual void resetPeakStats(c10::DeviceIndex device) = 0;
// Return the free memory size and total memory size in bytes for the
// specified device.
virtual std::pair<size_t, size_t> getMemoryInfo(c10::DeviceIndex device) = 0;
};
// This function is used to get the DeviceAllocator for a specific device type

View File

@ -345,13 +345,6 @@ class CUDAAllocator : public DeviceAllocator {
c10::DeviceIndex device,
std::shared_ptr<AllocatorState> pps) = 0;
virtual std::string name() = 0;
std::pair<size_t, size_t> getMemoryInfo(c10::DeviceIndex device) override {
c10::DeviceGuard device_guard({at::kCUDA, device});
size_t free = 0;
size_t total = 0;
C10_CUDA_CHECK(cudaMemGetInfo(&free, &total));
return {free, total};
}
};
// Allocator object, statically initialized

View File

@ -926,14 +926,15 @@ class DeviceCachingAllocator {
(release_cached_blocks() && alloc_block(params, true));
}
if (!block_found) {
const auto& raw_device = c10::xpu::get_raw_device(device);
const auto device_total =
raw_device.get_info<sycl::info::device::global_mem_size>();
c10::xpu::DeviceProp device_prop;
c10::xpu::get_device_properties(&device_prop, device);
auto device_total = device_prop.global_mem_size;
// Estimate the available device memory when the SYCL runtime does not
// support the corresponding aspect (ext_intel_free_memory).
size_t device_free = device_total -
size_t device_free = device_prop.global_mem_size -
stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)]
.current;
auto& raw_device = c10::xpu::get_raw_device(device);
// TODO: Remove the aspect check once the SYCL runtime bug is fixed on
// affected devices.
if (raw_device.has(sycl::aspect::ext_intel_free_memory)) {
@ -1051,37 +1052,21 @@ class DeviceCachingAllocator {
}
}
std::pair<size_t, size_t> getMemoryInfo() {
const auto& device = c10::xpu::get_raw_device(device_index);
const size_t total = device.get_info<sycl::info::device::global_mem_size>();
TORCH_CHECK(
device.has(sycl::aspect::ext_intel_free_memory),
"The device (",
device.get_info<sycl::info::device::name>(),
") doesn't support querying the available free memory. ",
"You can file an issue at https://github.com/pytorch/pytorch/issues ",
"to help us prioritize its implementation.");
const size_t free =
device.get_info<sycl::ext::intel::info::device::free_memory>();
return {free, total};
}
double getMemoryFraction() {
if (!set_fraction) {
return 1.0;
}
const auto device_total =
xpu::get_raw_device(device_index)
.get_info<sycl::info::device::global_mem_size>();
c10::xpu::DeviceProp device_prop;
c10::xpu::get_device_properties(&device_prop, device_index);
return static_cast<double>(allowed_memory_maximum) /
static_cast<double>(device_total);
static_cast<double>(device_prop.global_mem_size);
}
void setMemoryFraction(double fraction) {
const auto device_total =
xpu::get_raw_device(device_index)
.get_info<sycl::info::device::global_mem_size>();
c10::xpu::DeviceProp device_prop;
c10::xpu::get_device_properties(&device_prop, device_index);
auto device_total = device_prop.global_mem_size;
allowed_memory_maximum = static_cast<size_t>(fraction * device_total);
set_fraction = true;
}
@ -1255,11 +1240,6 @@ class XPUAllocator : public DeviceAllocator {
c10::xpu::get_raw_device(dev_to_access));
}
std::pair<size_t, size_t> getMemoryInfo(DeviceIndex device) override {
assertValidDevice(device);
return device_allocators[device]->getMemoryInfo();
}
double getMemoryFraction(DeviceIndex device) {
assertValidDevice(device);
return device_allocators[device]->getMemoryFraction();

View File

@ -40,7 +40,6 @@
:nosignatures:
empty_cache
get_memory_info
max_memory_allocated
max_memory_reserved
memory_allocated

View File

@ -382,6 +382,20 @@ coverage_ignore_functions = [
# torch.ao.quantization.backend_config.tensorrt
"get_tensorrt_backend_config",
"get_tensorrt_backend_config_dict",
# torch.ao.quantization.backend_config.utils
"entry_to_pretty_str",
"get_fused_module_classes",
"get_fuser_method_mapping",
"get_fusion_pattern_to_extra_inputs_getter",
"get_fusion_pattern_to_root_node_getter",
"get_module_to_qat_module",
"get_pattern_to_dtype_configs",
"get_pattern_to_input_type_to_index",
"get_qat_module_classes",
"get_root_module_to_quantized_reference_module",
"pattern_to_human_readable",
"remove_boolean_dispatch_from_name",
# torch.ao.quantization.backend_config.x86
"get_x86_backend_config",
# torch.ao.quantization.fuse_modules
"fuse_known_modules",
@ -412,6 +426,25 @@ coverage_ignore_functions = [
"insert_observers_for_model",
"prepare",
"propagate_dtypes_for_known_nodes",
# torch.ao.quantization.fx.utils
"all_node_args_except_first",
"all_node_args_have_no_tensors",
"assert_and_get_unique_device",
"collect_producer_nodes",
"create_getattr_from_value",
"create_node_from_old_node_preserve_meta",
"get_custom_module_class_keys",
"get_linear_prepack_op_for_dtype",
"get_new_attr_name_with_prefix",
"get_non_observable_arg_indexes_and_types",
"get_qconv_prepack_op",
"get_skipped_module_name_and_classes",
"graph_module_from_producer_nodes",
"maybe_get_next_module",
"node_arg_is_bias",
"node_arg_is_weight",
"return_arg_list",
# torch.ao.quantization.pt2e.graph_utils
"bfs_trace_with_node_process",
"find_sequential_partitions",
"get_equivalent_types",
@ -827,10 +860,80 @@ coverage_ignore_functions = [
"get_latency_of_one_partition",
"get_latency_of_partitioned_graph",
"get_partition_to_latency_mapping",
# torch.fx.experimental.proxy_tensor
"decompose",
"disable_autocast_cache",
"disable_proxy_modes_tracing",
"dispatch_trace",
"extract_val",
"fake_signature",
"fetch_sym_proxy",
"fetch_object_proxy",
"get_innermost_proxy_mode",
"get_isolated_graphmodule",
"get_proxy_slot",
"get_torch_dispatch_modes",
"has_proxy_slot",
"is_sym_node",
"maybe_handle_decomp",
"proxy_call",
"set_meta",
"set_original_aten_op",
"set_proxy_slot",
"snapshot_fake",
"thunkify",
"track_tensor",
"track_tensor_tree",
"wrap_key",
"wrapper_and_args_for_make_fx",
# torch.fx.experimental.recording
"record_shapeenv_event",
"replay_shape_env_events",
"shape_env_check_state_equal",
# torch.fx.experimental.sym_node
"ceil_impl",
"floor_ceil_helper",
"floor_impl",
"method_to_operator",
"sympy_is_channels_last_contiguous_2d",
"sympy_is_channels_last_contiguous_3d",
"sympy_is_channels_last_strides_2d",
"sympy_is_channels_last_strides_3d",
"sympy_is_channels_last_strides_generic",
"sympy_is_contiguous",
"sympy_is_contiguous_generic",
"to_node",
"wrap_node",
"sym_sqrt",
# torch.fx.experimental.symbolic_shapes
"bind_symbols",
"cast_symbool_to_symint_guardless",
"create_contiguous",
"error",
"eval_guards",
"eval_is_non_overlapping_and_dense",
"expect_true",
"find_symbol_binding_fx_nodes",
"free_symbols",
"free_unbacked_symbols",
"fx_placeholder_targets",
"fx_placeholder_vals",
"guard_bool",
"guard_float",
"guard_int",
"guard_scalar",
"has_hint",
"has_symbolic_sizes_strides",
"is_channels_last_contiguous_2d",
"is_channels_last_contiguous_3d",
"is_channels_last_strides_2d",
"is_channels_last_strides_3d",
"is_contiguous",
"is_non_overlapping_and_dense_indicator",
"is_nested_int",
"is_symbol_binding_fx_node",
"is_symbolic",
# torch.fx.experimental.unification.core
"reify",
# torch.fx.experimental.unification.match
"edge",
@ -868,6 +971,24 @@ coverage_ignore_functions = [
"reverse_dict",
# torch.fx.experimental.unification.multipledispatch.variadic
"isvariadic",
# torch.fx.experimental.unification.unification_tools
"assoc",
"assoc_in",
"dissoc",
"first",
"get_in",
"getter",
"groupby",
"itemfilter",
"itemmap",
"keyfilter",
"keymap",
"merge",
"merge_with",
"update_in",
"valfilter",
"valmap",
# torch.fx.experimental.unification.utils
"freeze",
"hashable",
"raises",
@ -1308,8 +1429,319 @@ coverage_ignore_functions = [
# torch.onnx.symbolic_opset7
"max",
"min",
# torch.onnx.symbolic_opset8
"addmm",
"bmm",
"empty",
"empty_like",
"flatten",
"full",
"full_like",
"gt",
"lt",
"matmul",
"mm",
"ones",
"ones_like",
"prelu",
"repeat",
"zeros",
"zeros_like",
# torch.onnx.symbolic_opset9
"abs",
"acos",
"adaptive_avg_pool1d",
"adaptive_avg_pool2d",
"adaptive_avg_pool3d",
"adaptive_max_pool1d",
"adaptive_max_pool2d",
"adaptive_max_pool3d",
"add",
"addcmul",
"addmm",
"alias",
"amax",
"amin",
"aminmax",
"arange",
"argmax",
"argmin",
"as_strided",
"as_tensor",
"asin",
"atan",
"atan2",
"avg_pool1d",
"avg_pool2d",
"avg_pool3d",
"baddbmm",
"batch_norm",
"bernoulli",
"bitwise_not",
"bitwise_or",
"bmm",
"broadcast_tensors",
"broadcast_to",
"bucketize",
"cat",
"cdist",
"ceil",
"clamp",
"clamp_max",
"clamp_min",
"clone",
"constant_pad_nd",
"contiguous",
"conv1d",
"conv2d",
"conv3d",
"conv_tbc",
"conv_transpose1d",
"conv_transpose2d",
"conv_transpose3d",
"convert_element_type",
"convolution",
"cos",
"cosine_similarity",
"cross",
"cumsum",
"detach",
"dim",
"div",
"dot",
"dropout",
"elu",
"embedding",
"embedding_bag",
"empty",
"empty_like",
"eq",
"erf",
"exp",
"expand",
"expand_as",
"eye",
"fill",
"flatten",
"floor",
"floor_divide",
"floordiv",
"frobenius_norm",
"full",
"full_like",
"gather",
"ge",
"gelu",
"get_pool_ceil_padding",
"glu",
"group_norm",
"gru",
"gt",
"hann_window",
"hardshrink",
"hardsigmoid",
"hardswish",
"hardtanh",
"index",
"index_add",
"index_copy",
"index_fill",
"index_put",
"index_select",
"instance_norm",
"is_floating_point",
"is_pinned",
"isnan",
"item",
"kl_div",
"layer_norm",
"le",
"leaky_relu",
"lerp",
"lift",
"linalg_cross",
"linalg_matrix_norm",
"linalg_norm",
"linalg_vector_norm",
"linear",
"linspace",
"log",
"log10",
"log1p",
"log2",
"log_sigmoid",
"log_softmax",
"logical_and",
"logical_not",
"logical_or",
"logical_xor",
"logit",
"logsumexp",
"lstm",
"lstm_cell",
"lt",
"masked_fill",
"masked_fill_",
"matmul",
"max",
"max_pool1d",
"max_pool1d_with_indices",
"max_pool2d",
"max_pool2d_with_indices",
"max_pool3d",
"max_pool3d_with_indices",
"maximum",
"meshgrid",
"min",
"minimum",
"mish",
"mm",
"movedim",
"mse_loss",
"mul",
"multinomial",
"mv",
"narrow",
"native_layer_norm",
"ne",
"neg",
"new_empty",
"new_full",
"new_ones",
"new_zeros",
"nonzero",
"nonzero_numpy",
"noop_complex_operators",
"norm",
"numel",
"numpy_T",
"one_hot",
"ones",
"ones_like",
"onnx_placeholder",
"overload_by_arg_count",
"pad",
"pairwise_distance",
"permute",
"pixel_shuffle",
"pixel_unshuffle",
"pow",
"prelu",
"prim_constant",
"prim_constant_chunk",
"prim_constant_split",
"prim_data",
"prim_device",
"prim_dtype",
"prim_if",
"prim_layout",
"prim_list_construct",
"prim_list_unpack",
"prim_loop",
"prim_max",
"prim_min",
"prim_shape",
"prim_tolist",
"prim_tuple_construct",
"prim_type",
"prim_unchecked_cast",
"prim_uninitialized",
"rand",
"rand_like",
"randint",
"randint_like",
"randn",
"randn_like",
"reciprocal",
"reflection_pad",
"relu",
"relu6",
"remainder",
"repeat",
"repeat_interleave",
"replication_pad",
"reshape",
"reshape_as",
"rnn_relu",
"rnn_tanh",
"roll",
"rrelu",
"rsqrt",
"rsub",
"scalar_tensor",
"scatter",
"scatter_add",
"select",
"selu",
"sigmoid",
"sign",
"silu",
"sin",
"size",
"slice",
"softmax",
"softplus",
"softshrink",
"sort",
"split",
"split_with_sizes",
"sqrt",
"square",
"squeeze",
"stack",
"std",
"std_mean",
"sub",
"t",
"take",
"tan",
"tanh",
"tanhshrink",
"tensor",
"threshold",
"to",
"topk",
"transpose",
"true_divide",
"type_as",
"unbind",
"unfold",
"unsafe_chunk",
"unsafe_split",
"unsafe_split_with_sizes",
"unsqueeze",
"unsupported_complex_operators",
"unused",
"upsample_bilinear2d",
"upsample_linear1d",
"upsample_nearest1d",
"upsample_nearest2d",
"upsample_nearest3d",
"upsample_trilinear3d",
"var",
"var_mean",
"view",
"view_as",
"where",
"wrap_logical_op_with_cast_to",
"wrap_logical_op_with_negation",
"zero",
"zeros",
"zeros_like",
# torch.onnx.utils
"disable_apex_o2_state_dict_hook",
"export",
"export_to_pretty_string",
"exporter_context",
"is_in_onnx_export",
"model_signature",
"register_custom_op_symbolic",
"select_model_mode_for_export",
"setup_onnx_logging",
"unconvertible_ops",
"unpack_quantized_tensor",
"warn_on_static_input_change",
# torch.onnx.verification
"check_export_model_diff",
"verify",
"verify_aten_graph",
@ -1400,6 +1832,32 @@ coverage_ignore_functions = [
"noop_context_fn",
"set_checkpoint_early_stop",
"set_device_states",
# torch.utils.collect_env
"check_release_file",
"get_cachingallocator_config",
"get_clang_version",
"get_cmake_version",
"get_conda_packages",
"get_cpu_info",
"get_cuda_module_loading_config",
"get_cudnn_version",
"get_env_info",
"get_gcc_version",
"get_gpu_info",
"get_libc_version",
"get_lsb_version",
"get_mac_version",
"get_nvidia_driver_version",
"get_nvidia_smi",
"get_os",
"get_pip_packages",
"get_platform",
"get_pretty_env_info",
"get_python_platform",
"get_running_cuda_version",
"get_windows_version",
"is_xnnpack_available",
"pretty_str",
# torch.utils.cpp_backtrace
"get_cpp_backtrace",
# torch.utils.cpp_extension
@ -1463,6 +1921,52 @@ coverage_ignore_functions = [
"apply_shuffle_seed",
"apply_shuffle_settings",
"get_all_graph_pipes",
# torch.utils.flop_counter
"addmm_flop",
"baddbmm_flop",
"bmm_flop",
"conv_backward_flop",
"conv_flop",
"conv_flop_count",
"convert_num_with_suffix",
"get_shape",
"get_suffix_str",
"mm_flop",
"normalize_tuple",
"register_flop_formula",
"sdpa_backward_flop",
"sdpa_backward_flop_count",
"sdpa_flop",
"sdpa_flop_count",
"shape_wrapper",
"transpose_shape",
# torch.utils.hipify.hipify_python
"add_dim3",
"compute_stats",
"extract_arguments",
"file_add_header",
"file_specific_replacement",
"find_bracket_group",
"find_closure_group",
"find_parentheses_group",
"fix_static_global_kernels",
"get_hip_file_path",
"hip_header_magic",
"hipify",
"is_caffe2_gpu_file",
"is_cusparse_file",
"is_out_of_place",
"is_pytorch_file",
"is_special_file",
"match_extensions",
"matched_files_iter",
"openf",
"preprocess_file_and_save_result",
"preprocessor",
"processKernelLaunches",
"replace_extern_shared",
"replace_math_functions",
"str2bool",
# torch.utils.hooks
"unserializable_hook",
"warn_if_has_hooks",

View File

@ -12,37 +12,6 @@ These APIs are experimental and subject to change without notice.
.. autoclass:: torch.fx.experimental.sym_node.DynamicInt
```
## torch.fx.experimental.sym_node
```{eval-rst}
.. currentmodule:: torch.fx.experimental.sym_node
```
```{eval-rst}
.. automodule:: torch.fx.experimental.sym_node
```
```{eval-rst}
.. autosummary::
:toctree: generated
:nosignatures:
is_channels_last_contiguous_2d
is_channels_last_contiguous_3d
is_channels_last_strides_2d
is_channels_last_strides_3d
is_contiguous
is_non_overlapping_and_dense_indicator
method_to_operator
sympy_is_channels_last_contiguous_2d
sympy_is_channels_last_contiguous_3d
sympy_is_channels_last_strides_2d
sympy_is_channels_last_strides_3d
sympy_is_channels_last_strides_generic
sympy_is_contiguous
sympy_is_contiguous_generic
```
## torch.fx.experimental.symbolic_shapes
```{eval-rst}
@ -100,25 +69,6 @@ These APIs are experimental and subject to change without notice.
rebind_unbacked
resolve_unbacked_bindings
is_accessor_node
cast_symbool_to_symint_guardless
create_contiguous
error
eval_guards
eval_is_non_overlapping_and_dense
find_symbol_binding_fx_nodes
free_symbols
free_unbacked_symbols
fx_placeholder_targets
fx_placeholder_vals
guard_bool
guard_float
guard_int
guard_scalar
has_hint
has_symbolic_sizes_strides
is_nested_int
is_symbol_binding_fx_node
is_symbolic
```
## torch.fx.experimental.proxy_tensor
@ -141,46 +91,4 @@ These APIs are experimental and subject to change without notice.
get_proxy_mode
maybe_enable_thunkify
maybe_disable_thunkify
decompose
disable_autocast_cache
disable_proxy_modes_tracing
extract_val
fake_signature
fetch_object_proxy
fetch_sym_proxy
has_proxy_slot
is_sym_node
maybe_handle_decomp
proxy_call
set_meta
set_original_aten_op
set_proxy_slot
snapshot_fake
```
## torch.fx.experimental.unification.unification_tools
```{eval-rst}
.. currentmodule:: torch.fx.experimental.unification.unification_tools
```
```{eval-rst}
.. automodule:: torch.fx.experimental.unification.unification_tools
```
```{eval-rst}
.. autosummary::
:toctree: generated
:nosignatures:
assoc
assoc_in
dissoc
first
keyfilter
keymap
merge
merge_with
update_in
valfilter
valmap

View File

@ -1134,6 +1134,7 @@ The set of leaf modules can be customized by overriding
.. py:module:: torch.fx.experimental.refinement_types
.. py:module:: torch.fx.experimental.rewriter
.. py:module:: torch.fx.experimental.schema_type_annotation
.. py:module:: torch.fx.experimental.sym_node
.. py:module:: torch.fx.experimental.unification.core
.. py:module:: torch.fx.experimental.unification.dispatch
.. py:module:: torch.fx.experimental.unification.match
@ -1143,6 +1144,7 @@ The set of leaf modules can be customized by overriding
.. py:module:: torch.fx.experimental.unification.multipledispatch.dispatcher
.. py:module:: torch.fx.experimental.unification.multipledispatch.utils
.. py:module:: torch.fx.experimental.unification.multipledispatch.variadic
.. py:module:: torch.fx.experimental.unification.unification_tools
.. py:module:: torch.fx.experimental.unification.utils
.. py:module:: torch.fx.experimental.unification.variable
.. py:module:: torch.fx.experimental.unify_refinements

View File

@ -134,23 +134,6 @@ Quantization to work with this as well.
ObservationType
```
## torch.ao.quantization.backend_config.utils
```{eval-rst}
.. currentmodule:: torch.ao.quantization.backend_config.utils
```
```{eval-rst}
.. autosummary::
:toctree: generated
:nosignatures:
:template: classtemplate.rst
entry_to_pretty_str
pattern_to_human_readable
remove_boolean_dispatch_from_name
```
## torch.ao.quantization.fx.custom_config
This module contains a few CustomConfig classes that's used in both eager mode and FX graph mode quantization
@ -171,30 +154,6 @@ This module contains a few CustomConfig classes that's used in both eager mode a
StandaloneModuleConfigEntry
```
## torch.ao.quantization.fx.utils
```{eval-rst}
.. currentmodule:: torch.ao.quantization.fx.utils
```
```{eval-rst}
.. autosummary::
:toctree: generated
:nosignatures:
:template: classtemplate.rst
all_node_args_except_first
all_node_args_have_no_tensors
collect_producer_nodes
create_getattr_from_value
create_node_from_old_node_preserve_meta
graph_module_from_producer_nodes
maybe_get_next_module
node_arg_is_bias
node_arg_is_weight
return_arg_list
```
## torch.ao.quantization.quantizer
```{eval-rst}

View File

@ -19,91 +19,6 @@
swap_tensors
```
# torch.utils.collect_env
```{eval-rst}
.. automodule:: torch.utils.collect_env
```
```{eval-rst}
.. currentmodule:: torch.utils.collect_env
```
```{eval-rst}
.. autosummary::
:toctree: generated
:nosignatures:
check_release_file
is_xnnpack_available
pretty_str
```
# torch.utils.flop_counter
```{eval-rst}
.. automodule:: torch.utils.flop_counter
```
```{eval-rst}
.. currentmodule:: torch.utils.flop_counter
```
```{eval-rst}
.. autosummary::
:toctree: generated
:nosignatures:
baddbmm_flop
bmm_flop
conv_backward_flop
conv_flop
conv_flop_count
register_flop_formula
sdpa_backward_flop
sdpa_backward_flop_count
sdpa_flop
sdpa_flop_count
shape_wrapper
```
# torch.utils.hipify.hipify_python
```{eval-rst}
.. automodule:: torch.utils.hipify.hipify_python
```
```{eval-rst}
.. currentmodule:: torch.utils.hipify.hipify_python
```
```{eval-rst}
.. autosummary::
:toctree: generated
:nosignatures:
compute_stats
extract_arguments
file_add_header
file_specific_replacement
find_bracket_group
find_closure_group
find_parentheses_group
fix_static_global_kernels
hip_header_magic
hipify
is_caffe2_gpu_file
is_cusparse_file
is_out_of_place
is_pytorch_file
is_special_file
openf
preprocess_file_and_save_result
preprocessor
processKernelLaunches
replace_extern_shared
replace_math_functions
str2bool
```
<!-- This module needs to be documented. Adding here in the meantime
for tracking purposes -->
```{eval-rst}
@ -128,6 +43,7 @@ for tracking purposes -->
.. py:module:: torch.utils.benchmark.utils.valgrind_wrapper.timer_interface
.. py:module:: torch.utils.bundled_inputs
.. py:module:: torch.utils.checkpoint
.. py:module:: torch.utils.collect_env
.. py:module:: torch.utils.cpp_backtrace
.. py:module:: torch.utils.cpp_extension
.. py:module:: torch.utils.data.backward_compatibility
@ -164,8 +80,10 @@ for tracking purposes -->
.. py:module:: torch.utils.data.sampler
.. py:module:: torch.utils.dlpack
.. py:module:: torch.utils.file_baton
.. py:module:: torch.utils.flop_counter
.. py:module:: torch.utils.hipify.constants
.. py:module:: torch.utils.hipify.cuda_to_hip_mappings
.. py:module:: torch.utils.hipify.hipify_python
.. py:module:: torch.utils.hipify.version
.. py:module:: torch.utils.hooks
.. py:module:: torch.utils.jit.log_extract

View File

@ -630,37 +630,6 @@ def mirror_files_into_torchgen() -> None:
raise RuntimeError("Check the file paths in `mirror_files_into_torchgen()`")
def mirror_inductor_external_kernels() -> None:
"""
Copy external kernels into Inductor so they are importable.
"""
paths = [
(
CWD / "torch/_inductor/kernel/vendored_templates/cutedsl_grouped_gemm.py",
CWD
/ "third_party/cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py",
),
]
for new_path, orig_path in paths:
# Create the dirs involved in new_path if they don't exist
if not new_path.exists():
new_path.parent.mkdir(parents=True, exist_ok=True)
# Copy the files from the orig location to the new location
if orig_path.is_file():
shutil.copyfile(orig_path, new_path)
continue
if orig_path.is_dir():
if new_path.exists():
# copytree fails if the tree exists already, so remove it.
shutil.rmtree(new_path)
shutil.copytree(orig_path, new_path)
continue
raise RuntimeError(
"Check the file paths in `mirror_inductor_external_kernels()`"
)
# ATTENTION: THIS IS AI SLOP
def extract_variant_from_version(version: str) -> str:
"""Extract variant from version string, defaulting to 'cpu'."""
@ -1646,7 +1615,6 @@ def main() -> None:
mirror_files_into_torchgen()
if RUN_BUILD_DEPS:
build_deps()
mirror_inductor_external_kernels()
(
ext_modules,
@ -1681,7 +1649,6 @@ def main() -> None:
"_inductor/codegen/aoti_runtime/*.cpp",
"_inductor/script.ld",
"_inductor/kernel/flex/templates/*.jinja",
"_inductor/kernel/templates/*.jinja",
"_export/serde/*.yaml",
"_export/serde/*.thrift",
"share/cmake/ATen/*.cmake",

View File

@ -204,16 +204,14 @@ class DistConvolutionOpsTest(DTensorTestBase):
self.assertTrue(b_dt.grad is not None)
self.assertTrue(x_dt.grad is None)
def _run_single_arg_fwd(
self, model, arg, placements=None
) -> tuple[torch.Tensor, torch.Tensor]:
def _run_single_arg_fwd(self, model, arg) -> tuple[torch.Tensor, torch.Tensor]:
"""Given model and arg, runs fwd model local and distbuted given device_mesh"""
device_mesh = self.build_device_mesh()
model_copy = copy.deepcopy(model).to(device=self.device_type)
dist_model = distribute_module(model, device_mesh, _conv_fn)
arg_dt = DTensor.from_local(arg, device_mesh, placements)
arg_dt = DTensor.from_local(arg, device_mesh, [Replicate()])
out_dt = dist_model(arg_dt.to(device=self.device_type))
out = model_copy(arg_dt.full_tensor())
out = model_copy(arg)
return (out_dt.full_tensor(), out)
@with_comms
@ -221,20 +219,22 @@ class DistConvolutionOpsTest(DTensorTestBase):
model = nn.Conv1d(64, 64, 3, padding=1)
x = torch.randn(1, 64, 8, device=self.device_type)
out_dt, out = self._run_single_arg_fwd(model, x)
self.assertEqual(out_dt, out)
self.assertEqual(out_dt.shape, out.shape)
@with_comms
def test_conv3d(self):
model = nn.Conv3d(64, 64, 3, padding=1)
x = torch.randn(1, 64, 8, 8, 8, device=self.device_type)
out_dt, out = self._run_single_arg_fwd(model, x, [Shard(0)])
self.assertEqual(out_dt, out)
out_dt, out = self._run_single_arg_fwd(model, x)
self.assertEqual(out_dt.shape, out.shape)
DistConvolutionOpsTestWithLocalTensor = create_local_tensor_test_class(
DistConvolutionOpsTest,
# Send / recv ops are not supported
skipped_tests=[
"test_conv1d",
"test_conv3d",
"test_conv_backward_none_grad_inp",
"test_depthwise_convolution",
"test_downsampling_convolution",

View File

@ -2,6 +2,7 @@
# Owner(s): ["oncall: distributed"]
import contextlib
import copy
import itertools
import unittest
@ -21,8 +22,9 @@ from torch.distributed.tensor import (
)
from torch.distributed.tensor._collective_utils import shard_dim_alltoall
from torch.distributed.tensor._dtensor_spec import ShardOrderEntry
from torch.distributed.tensor._redistribute import redistribute_local_tensor
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.placement_types import _StridedShard, MaskPartial
from torch.distributed.tensor.placement_types import _StridedShard
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
@ -33,11 +35,7 @@ from torch.testing._internal.common_utils import (
from torch.testing._internal.distributed._tensor.common_dtensor import (
create_local_tensor_test_class,
DTensorTestBase,
generate_shard_orders,
make_full_tensor,
map_local_tensor_for_rank,
patched_distribute_tensor as _distribute_tensor,
redistribute,
with_comms,
)
from torch.utils._debug_mode import DebugMode
@ -787,6 +785,88 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
else:
return ""
# TODO(zpcore): remove once the native redistribute supports shard_order arg
def redistribute(
self,
dtensor_input,
device_mesh,
placements,
shard_order,
use_graph_based_transform=True,
):
"""
wrapper function to support shard_order for redistribution
This is a simpler version of Redistribute, only considers the forward.
"""
if placements is None:
placements = self._shard_order_to_placement(shard_order, device_mesh)
placements = tuple(placements)
old_spec = dtensor_input._spec
new_spec = copy.deepcopy(old_spec)
new_spec.placements = placements
if shard_order is not None:
new_spec.shard_order = shard_order
else:
new_spec.shard_order = ()
if old_spec == new_spec:
return dtensor_input
dtensor_input = DTensor.from_local(
redistribute_local_tensor(
dtensor_input.to_local(),
old_spec,
new_spec,
use_graph_based_transform=use_graph_based_transform,
),
device_mesh,
)
dtensor_input._spec = copy.deepcopy(new_spec)
return dtensor_input # returns DTensor
# TODO(zpcore): remove once the native distribute_tensor supports
# shard_order arg
def distribute_tensor(
self,
input_tensor,
device_mesh,
placements,
shard_order,
use_graph_based_transform=True,
):
"""wrapper function to support shard_order for tensor distribution"""
if placements is None:
placements = self._shard_order_to_placement(shard_order, device_mesh)
placements = tuple(placements)
tensor_dt = distribute_tensor(input_tensor, device_mesh, placements)
# fix the shard order
return self.redistribute(
tensor_dt, device_mesh, placements, shard_order, use_graph_based_transform
)
# TODO(zpcore): remove once the native redistribute supports shard_order arg
def full_tensor(self, dtensor_input):
"""wrapper function to support DTensor.full_tensor"""
return self.redistribute(
dtensor_input, dtensor_input.device_mesh, placements=None, shard_order=()
).to_local()
def _shard_order_to_placement(self, shard_order, mesh):
"""convert shard_order to placement with only Replicate() and Shard()"""
placements = [Replicate() for _ in range(mesh.ndim)]
if shard_order is not None:
for entry in shard_order:
tensor_dim = entry.tensor_dim
mesh_dims = entry.mesh_dims
for mesh_dim in mesh_dims:
placements[mesh_dim] = Shard(tensor_dim)
return tuple(placements)
def _convert_shard_order_dict_to_ShardOrder(self, shard_order):
"""Convert shard_order dict to ShardOrder"""
return tuple(
ShardOrderEntry(tensor_dim=tensor_dim, mesh_dims=tuple(mesh_dims))
for tensor_dim, mesh_dims in shard_order.items()
)
@with_comms
def test_ordered_redistribute(self):
"""Test ordered redistribution with various sharding syntaxes"""
@ -847,11 +927,13 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
for idx, ((src_placement, src_order), (dst_placement, dst_order)) in enumerate(
sharding_src_dst_pairs_with_expected_trace
):
sharded_dt = _distribute_tensor(
sharded_dt = self.distribute_tensor(
input_data.clone(), mesh, src_placement, shard_order=src_order
)
with DebugMode(record_torchfunction=False) as debug_mode:
sharded_dt = redistribute(sharded_dt, mesh, dst_placement, dst_order)
sharded_dt = self.redistribute(
sharded_dt, mesh, dst_placement, dst_order
)
trace_str = self._extract_redistribute_trace_from_debug_mode(
debug_mode.debug_string()
)
@ -875,11 +957,49 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
trace_str,
"""S(0)[0]S(0)[1]R->S(0)S(1)R->RS(1)R->RS(1)S(0)""",
)
expected_dt = _distribute_tensor(
expected_dt = self.distribute_tensor(
input_data.clone(), mesh, dst_placement, shard_order=dst_order
)
self.assertEqual(sharded_dt.to_local(), expected_dt.to_local())
def generate_shard_orders(self, mesh, tensor_rank):
# Generate all possible sharding placement of tensor with rank
# `tensor_rank` over mesh.
def _split_list(lst: list, N: int):
def compositions(n, k):
if k == 1:
yield [n]
else:
for i in range(1, n - k + 2):
for tail in compositions(n - i, k - 1):
yield [i] + tail
length = len(lst)
for comp in compositions(length, N):
result = []
start = 0
for size in comp:
result.append(lst[start : start + size])
start += size
yield result
all_mesh = list(range(mesh.ndim))
all_device_order = list(itertools.permutations(all_mesh))
for device_order in all_device_order:
# split on device orders, and assign each device order segment to a tensor dim
for num_split in range(1, mesh.ndim + 1):
for splitted_list in _split_list(list(range(mesh.ndim)), num_split):
for tensor_dims in itertools.combinations(
range(tensor_rank), len(splitted_list)
):
shard_order = {}
assert len(tensor_dims) == len(splitted_list)
for tensor_dim, mesh_dims in zip(tensor_dims, splitted_list):
shard_order[tensor_dim] = device_order[
mesh_dims[0] : mesh_dims[-1] + 1
]
yield self._convert_shard_order_dict_to_ShardOrder(shard_order)
@with_comms
def test_generate_shard_orders(self):
"""Check if `generate_shard_orders` generates unique sharding combinations"""
@ -892,7 +1012,7 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
]
for test_input in test_inputs:
all_combinations = []
for shard_order in generate_shard_orders(
for shard_order in self.generate_shard_orders(
test_input["mesh"], test_input["tensor_rank"]
):
all_combinations.append(shard_order) # noqa: PERF402
@ -942,12 +1062,12 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
input_data = torch.randn(tensor_shape, device=self.device_type)
tensor_rank = input_data.ndim
with maybe_disable_local_tensor_mode():
shard_orders = generate_shard_orders(mesh, tensor_rank)
shard_orders = self.generate_shard_orders(mesh, tensor_rank)
for shard_order in shard_orders:
sharded_dt = _distribute_tensor(
sharded_dt = self.distribute_tensor(
input_data.clone(), mesh, placements=None, shard_order=shard_order
)
self.assertEqual(make_full_tensor(sharded_dt), input_data)
self.assertEqual(self.full_tensor(sharded_dt), input_data)
# 2. Verify the correctness of redistribution from DTensor to DTensor.
# This test repeatedly redistributes a DTensor to various ordered
@ -958,20 +1078,20 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
tensor_rank = input_data.ndim
prev_sharded_dt = None
with maybe_disable_local_tensor_mode():
shard_orders = generate_shard_orders(mesh, tensor_rank)
shard_orders = self.generate_shard_orders(mesh, tensor_rank)
for shard_order in shard_orders:
if prev_sharded_dt is None:
prev_sharded_dt = _distribute_tensor(
prev_sharded_dt = self.distribute_tensor(
input_data.clone(),
mesh,
placements=None,
shard_order=shard_order,
)
else:
sharded_dt = redistribute(
sharded_dt = self.redistribute(
prev_sharded_dt, mesh, placements=None, shard_order=shard_order
)
self.assertEqual(make_full_tensor(sharded_dt), input_data)
self.assertEqual(self.full_tensor(sharded_dt), input_data)
prev_sharded_dt = sharded_dt
@with_comms
@ -1016,13 +1136,13 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
local_tensor = torch.randn(shape, device=self.device_type)
full_tensor = DTensor.from_local(local_tensor, mesh, placements)
with maybe_disable_local_tensor_mode():
shard_orders = generate_shard_orders(mesh, len(shape))
shard_orders = self.generate_shard_orders(mesh, len(shape))
for shard_order in shard_orders:
sharded_dt = redistribute(
sharded_dt = self.redistribute(
full_tensor, mesh, placements=None, shard_order=shard_order
)
self.assertEqual(
make_full_tensor(sharded_dt), make_full_tensor(full_tensor)
self.full_tensor(sharded_dt), self.full_tensor(full_tensor)
)
@unittest.skip(
@ -1032,20 +1152,24 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
@with_comms
def test_ordered_redistribute_for_special_placement(self):
"""Test ordered redistribution with special placement"""
from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
torch.manual_seed(21)
mesh = init_device_mesh(self.device_type, (8,))
input_data = torch.randn((8, 8), device=self.device_type)
src_placement = [Shard(1)]
tgt_placement = [
(MaskPartial(offset_shape=torch.Size([10, 20]), offset_dim=0),)
(_MaskPartial(offset_shape=torch.Size([10, 20]), offset_dim=0),)
]
sharded_dt = _distribute_tensor(
sharded_dt = self.distribute_tensor(
input_data.clone(),
mesh,
src_placement,
shard_order=(ShardOrderEntry(tensor_dim=1, mesh_dims=(0,)),),
)
sharded_dt = redistribute(sharded_dt, mesh, tgt_placement, shard_order=None)
sharded_dt = self.redistribute(
sharded_dt, mesh, tgt_placement, shard_order=None
)
@with_comms
def test_shard_order_same_data_as_strided_shard(self):
@ -1055,7 +1179,7 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
strided_placement = [_StridedShard(-2, split_factor=2), Shard(-2)]
x_strided_dt = distribute_tensor(x, device_mesh, strided_placement)
# specify right-to-left order use ordered shard
x_ordered_dt = _distribute_tensor(
x_ordered_dt = self.distribute_tensor(
x,
device_mesh,
placements=[Shard(0), Shard(0)],

View File

@ -34,10 +34,6 @@ from torch.distributed.tensor.placement_types import (
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
generate_shard_orders,
LocalDTensorTestBase,
patched_distribute_tensor as _distribute_tensor,
shard_order_to_placement,
with_comms,
)
@ -778,63 +774,6 @@ class TestStridedSharding(DTensorTestBase):
self.assertEqual(dtensor.full_tensor(), tensor)
class Test_StridedShard_with_shard_order(LocalDTensorTestBase):
@property
def world_size(self) -> int:
return 32
@with_comms
def test_StridedShard_to_shard_order(self):
with LocalTensorMode(ranks=self.world_size):
mesh = DeviceMesh("cpu", torch.arange(self.world_size).view(2, 2, 2, 2, 2))
shard_iter = generate_shard_orders(mesh, 3)
# It takes ~4.8h to complete total 2520 shard order combinations here
# using LocalTensor. So we only randomly pick 25 shard orders to test.
all_shard_order = list(shard_iter)
import random
random.seed(42)
shard_order_choices = random.sample(
all_shard_order, min(25, len(all_shard_order))
)
x = torch.randn(32, 32, 32)
for shard_order in shard_order_choices:
a = _distribute_tensor(x, mesh, None, shard_order)
placement_without_stridedshard = shard_order_to_placement(
shard_order, mesh
)
placements_with_stridedshard = (
DTensorSpec._convert_shard_order_to_StridedShard(
shard_order, placement_without_stridedshard, mesh
)
)
b = distribute_tensor(x, mesh, placements_with_stridedshard)
shard_order_from_stridedshard = (
DTensorSpec._maybe_convert_StridedShard_to_shard_order(
placements_with_stridedshard, mesh
)
)
self.assertEqual(shard_order, shard_order_from_stridedshard)
self.assertEqual(a.to_local(), b.to_local())
@with_comms
def test_StridedShard_not_convertible_to_shard_order(self):
with LocalTensorMode(ranks=self.world_size):
mesh = DeviceMesh("cpu", torch.arange(self.world_size).view(4, 8))
unconvertible_placements_list = [
[_StridedShard(0, split_factor=2), _StridedShard(1, split_factor=2)],
[_StridedShard(0, split_factor=2), Shard(1)],
[_StridedShard(1, split_factor=16), Shard(1)],
]
for placements in unconvertible_placements_list:
shard_order = DTensorSpec._maybe_convert_StridedShard_to_shard_order(
tuple(placements), mesh
)
self.assertIsNone(shard_order)
class Test2DStridedLocalShard(DTensorTestBase):
@property
def world_size(self):

View File

@ -206,10 +206,6 @@ class TestPyCodeCache(TestCase):
.decode()
.strip()
)
# XPU have extra lines, so get the last line, refer https://github.com/intel/torch-xpu-ops/issues/2261
if torch.xpu.is_available():
wrapper_path = wrapper_path.splitlines()[-1]
hit = hit.splitlines()[-1]
self.assertEqual(hit, "1")
with open(wrapper_path) as f:

View File

@ -1,154 +0,0 @@
# Owner(s): ["module: inductor"]
import unittest
import torch
from torch import Tensor
from torch._inductor import config
from torch._inductor.codegen.cuda.cuda_env import is_datacenter_blackwell_arch
from torch._inductor.test_case import run_tests, TestCase as InductorTestCase
from torch._inductor.utils import ensure_cute_available
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
)
@unittest.skipIf(
not (ensure_cute_available() and is_datacenter_blackwell_arch()),
"CuTeDSL library or Blackwell device not available",
)
@instantiate_parametrized_tests
class TestCuTeDSLGroupedGemm(InductorTestCase):
def _get_inputs(
self,
group_size: int,
M_hint: int,
K: int,
N: int,
device: str,
dtype: torch.dtype,
alignment: int = 16,
) -> tuple[Tensor, Tensor, Tensor]:
# --- Random, tile-aligned M sizes ---
M_sizes = (
torch.randint(1, (M_hint // alignment) + 1, (group_size,), dtype=torch.int)
* alignment
)
M_total = torch.sum(M_sizes).item()
# --- Construct input tensors ---
A = torch.randn(int(M_total), K, dtype=dtype, device=device) * 0.1
B = torch.randn((group_size, K, N), dtype=dtype, device=device) * 0.01
# --- Build offsets (no leading zero, strictly increasing) ---
offsets = torch.cumsum(M_sizes, dim=0).to(dtype=torch.int32, device=device)
return (A, B, offsets)
@parametrize("group_size", (2, 8))
@parametrize("M_hint", (256, 1024))
@parametrize("K", (64, 128))
@parametrize("N", (128, 256))
def test_grouped_gemm_basic(self, group_size: int, M_hint: int, K: int, N: int):
device = "cuda"
dtype = torch.bfloat16
A, B, offsets = self._get_inputs(group_size, M_hint, K, N, device, dtype)
def grouped_gemm_fn(A_packed, B_batched, offs):
return torch._grouped_mm(A_packed, B_batched, offs=offs)
# Eager execution
c_eager = grouped_gemm_fn(A, B, offsets)
# Test with Cute backend
with config.patch(
{
"max_autotune": True,
"max_autotune_gemm_backends": "CUTEDSL",
"test_configs.autotune_choice_name_regex": "cutedsl",
"autotune_fallback_to_aten": False,
}
):
grouped_gemm_compiled = torch.compile(
grouped_gemm_fn, backend="inductor", dynamic=False
)
c_compiled = grouped_gemm_compiled(A, B, offsets)
self.assertEqual(c_eager.dtype, dtype)
self.assertEqual(c_compiled.dtype, dtype)
torch.testing.assert_close(c_eager, c_compiled)
@parametrize("layout_A", ("contiguous", "offset", "padded", "view"))
@parametrize("layout_B", ("contiguous", "broadcasted"))
def test_grouped_gemm_assorted_layouts(
self,
layout_A: str,
layout_B: str,
):
device = "cuda"
dtype = torch.bfloat16
G, K, N = 8, 64, 128
M_sizes = [128] * G
sum_M = sum(M_sizes)
offsets = torch.tensor(
[sum(M_sizes[: i + 1]) for i in range(G)], dtype=torch.int32, device=device
)
A_base = torch.randn(sum_M, K, device=device, dtype=dtype)
A = A_base
if layout_A == "offset":
# allocate bigger buffer than needed, use nonzero storage offset
storage = torch.randn(sum_M * K + 512, device=device, dtype=dtype)
offset = 128 # skip first 128 elements
A = torch.as_strided(storage[offset:], (sum_M, K), (K, 1))
elif layout_A == "padded":
# simulate row pitch > K (row_stride = K + pad)
row_pitch = K + 8
storage = torch.randn(sum_M * row_pitch, device=device, dtype=dtype)
A = torch.as_strided(storage, (sum_M, K), (row_pitch, 1))
elif layout_A == "view":
A_storage = torch.randn(sum_M * K, device=device, dtype=dtype)
A = A_storage.view(sum_M, K)
assert A._base is not None
assert A.shape == (sum_M, K)
B = torch.randn((G, K, N), dtype=dtype, device=device) * 0.01
if layout_B == "broadcasted":
# Broadcast B across groups (zero stride along G)
B = B[0].expand(G, K, N)
assert B.stride(0) == 0
def grouped_gemm_fn(A_packed, B_batched, offs):
return torch._grouped_mm(A_packed, B_batched, offs=offs)
# --- eager ---
c_eager = grouped_gemm_fn(A, B, offsets)
# --- compiled (CUTE backend) ---
with config.patch(
{
"max_autotune": True,
"max_autotune_gemm_backends": "CUTEDSL",
"test_configs.autotune_choice_name_regex": "cutedsl",
"autotune_fallback_to_aten": False,
}
):
grouped_gemm_compiled = torch.compile(
grouped_gemm_fn, backend="inductor", dynamic=False
)
c_compiled = grouped_gemm_compiled(A, B, offsets)
self.assertEqual(c_eager.dtype, dtype)
self.assertEqual(c_compiled.dtype, dtype)
torch.testing.assert_close(c_eager, c_compiled)
if __name__ == "__main__":
run_tests()
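# Editorial sketch (an assumption about the torch._grouped_mm semantics the
# tests above rely on; grouped_mm_reference is a hypothetical helper, not
# part of this file): an eager reference for the packed 2D-A / batched 3D-B
# case, where offsets hold the exclusive end row of each group.
def grouped_mm_reference(A: Tensor, B: Tensor, offsets: Tensor) -> Tensor:
    outputs, start = [], 0
    for i, end in enumerate(offsets.tolist()):
        # rows [start, end) of the packed A belong to group i
        outputs.append(A[start:end] @ B[i])
        start = end
    return torch.cat(outputs, dim=0)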

View File

@@ -482,8 +482,8 @@ class TestExecutionTrace(TestCase):
@unittest.skipIf(IS_WINDOWS, "torch.compile does not support WINDOWS")
@unittest.skipIf(
(not has_triton()) or (not TEST_CUDA),
"need triton and device CUDA availability to run",
(not has_triton()) or (not TEST_CUDA and not TEST_XPU),
"need triton and device(CUDA or XPU) availability to run",
)
@skipCPUIf(True, "skip CPU device for testing profiling triton")
def test_triton_fx_graph_with_et(self, device):

View File

@@ -2005,10 +2005,6 @@ assert KinetoStepTracker.current_step() == initial_step + 2 * niters
report = json.load(f)
self._validate_basic_json(report["traceEvents"], with_cuda)
@unittest.skipIf(
torch.xpu.is_available(),
"XPU Trace event ends too late! Refer https://github.com/intel/torch-xpu-ops/issues/2263",
)
@unittest.skipIf(not kineto_available(), "Kineto is required")
@skipIfTorchDynamo("profiler gets ignored if dynamo activated")
def test_basic_chrome_trace(self):
@@ -2162,10 +2158,7 @@ assert KinetoStepTracker.current_step() == initial_step + 2 * niters
@skipIfTorchDynamo("profiler gets ignored if dynamo activated")
def test_basic_profile(self):
# test a really basic profile to make sure no erroneous aten ops are run
acc = torch.accelerator.current_accelerator()
self.assertIsNotNone(acc)
device = acc.type
x = torch.randn(4, device=device)
x = torch.randn(4, device="cuda")
with torch.profiler.profile(with_stack=True) as p:
x *= 2
names = [e.name for e in p.events()]
@@ -2232,7 +2225,6 @@ assert KinetoStepTracker.current_step() == initial_step + 2 * niters
@unittest.skipIf(
torch.cuda.is_available(), "CUDA complains about forking after init"
)
@unittest.skipIf(torch.xpu.is_available(), "XPU complains about forking after init")
@unittest.skipIf(IS_WINDOWS, "can't use os.fork() on Windows")
def test_forked_process(self):
# Induce a pid cache by running the profiler with payload

View File

@@ -263,7 +263,13 @@ S390X_BLOCKLIST = [
XPU_BLOCKLIST = [
"test_autograd",
"profiler/test_cpp_thread",
"profiler/test_execution_trace",
"profiler/test_memory_profiler",
"profiler/test_profiler",
"profiler/test_profiler_tree",
"profiler/test_record_function",
"profiler/test_torch_tidy",
"test_openreg",
]

View File

@@ -1,236 +1,245 @@
{
"EndToEndLSTM (__main__.RNNTest)": 190.48799641927084,
"MultiheadAttention (__main__.ModulesTest)": 141.2663370768229,
"test__adaptive_avg_pool2d (__main__.CPUReproTests)": 82.87333234151204,
"test_after_aot_cpu_runtime_error (__main__.MinifierIsolateTests)": 70.6538565499442,
"test_aot_autograd_disable_functionalization_symbolic_exhaustive_nn_functional_max_pool1d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 123.34033711751302,
"test_aot_autograd_disable_functionalization_symbolic_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 171.25450134277344,
"test_aot_autograd_disable_functionalization_symbolic_exhaustive_nn_functional_max_pool3d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 119.71899922688802,
"test_aot_autograd_disable_functionalization_symbolic_exhaustive_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 69.35733322870163,
"test_aot_autograd_symbolic_exhaustive_linalg_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 63.64533233642578,
"test_aot_autograd_symbolic_exhaustive_masked_norm_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 63.672952016194664,
"test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool1d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 138.04000091552734,
"test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 172.1344985961914,
"test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool3d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 114.02050018310547,
"test_aot_autograd_symbolic_exhaustive_ormqr_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 67.25642830984933,
"test_aot_autograd_symbolic_exhaustive_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 65.3350003560384,
"test_aot_autograd_symbolic_module_exhaustive_nn_TransformerDecoderLayer_cpu_float32 (__main__.TestEagerFusionModuleInfoCPU)": 120.95249938964844,
"test_associative_scan_partial_grad_combine_mode_generic_compile_mode_compile_dynamic_shape_reverse_False_cpu (__main__.AssociativeScanTests)": 86.97774887084961,
"test_associative_scan_partial_grad_combine_mode_generic_compile_mode_compile_dynamic_shape_reverse_True_cpu (__main__.AssociativeScanTests)": 100.90774917602539,
"test_avg_pool3d_backward2_cpu (__main__.CpuTests)": 1144.3935089111328,
"test_avg_pool3d_backward2_cuda (__main__.GPUTests)": 222.58500061035156,
"test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 501.10033162434894,
"test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 517.1875050862631,
"test_avg_pool3d_backward2_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 113.88125228881836,
"test_avg_pool3d_backward2_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 235.77350616455078,
"test_backward_nn_functional_multi_head_attention_forward_cpu_float32 (__main__.TestCompositeComplianceCPU)": 74.6155014038086,
"test_backward_nn_functional_multi_head_attention_forward_cuda_float32 (__main__.TestCompositeComplianceCUDA)": 66.63325119018555,
"test_basic_cpu (__main__.EfficientConvBNEvalCpuTests)": 216.2968317667643,
"test_basic_cuda (__main__.EfficientConvBNEvalGpuTests)": 153.0915012359619,
"test_cat_2k_args (__main__.TestTEFuserDynamic)": 108.80471753561869,
"test_cat_2k_args (__main__.TestTEFuserStatic)": 102.20949847949669,
"test_checkpointing_without_reentrant_input_requires_grad_False (__main__.TestAutogradWithCompiledAutograd)": 311.7026621500651,
"test_checkpointing_without_reentrant_input_requires_grad_True (__main__.TestAutogradWithCompiledAutograd)": 395.0001729329427,
"test_collect_callgrind (__main__.TestBenchmarkUtils)": 348.6218566894531,
"test_comprehensive_diff_cuda_complex128 (__main__.TestDecompCUDA)": 98.71574974060059,
"test_comprehensive_diff_cuda_complex64 (__main__.TestDecompCUDA)": 97.68499946594238,
"test_comprehensive_diff_cuda_float32 (__main__.TestDecompCUDA)": 65.0557508468628,
"test_comprehensive_diff_cuda_float64 (__main__.TestDecompCUDA)": 65.86899948120117,
"test_comprehensive_gradient_cuda_complex64 (__main__.TestDecompCUDA)": 97.15880012512207,
"test_comprehensive_grid_sampler_2d_cpu_bfloat16 (__main__.TestDecompCPU)": 103.20700073242188,
"test_comprehensive_grid_sampler_2d_cpu_float16 (__main__.TestDecompCPU)": 102.74033610026042,
"test_comprehensive_grid_sampler_2d_cpu_float32 (__main__.TestDecompCPU)": 460.4286702473958,
"test_comprehensive_grid_sampler_2d_cpu_float64 (__main__.TestDecompCPU)": 435.62066650390625,
"test_comprehensive_grid_sampler_2d_cuda_bfloat16 (__main__.TestDecompCUDA)": 287.3090057373047,
"test_comprehensive_grid_sampler_2d_cuda_float16 (__main__.TestDecompCUDA)": 265.1860008239746,
"test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestDecompCUDA)": 1235.7365112304688,
"test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 68.20825004577637,
"test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestDecompCUDA)": 1281.2615051269531,
"test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 71.90750026702881,
"test_comprehensive_linalg_householder_product_cuda_complex64 (__main__.TestDecompCUDA)": 79.04633331298828,
"test_comprehensive_linalg_lu_factor_ex_cuda_complex128 (__main__.TestDecompCUDA)": 68.10879821777344,
"test_comprehensive_linalg_lu_solve_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 71.43025207519531,
"test_comprehensive_linalg_lu_solve_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 68.94575023651123,
"test_comprehensive_linalg_solve_triangular_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 72.93649864196777,
"test_comprehensive_linalg_solve_triangular_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 72.46275043487549,
"test_comprehensive_linalg_svd_cuda_complex128 (__main__.TestDecompCUDA)": 64.10650062561035,
"test_comprehensive_linalg_svd_cuda_complex64 (__main__.TestDecompCUDA)": 67.03124904632568,
"test_comprehensive_linalg_svd_cuda_float64 (__main__.TestDecompCUDA)": 64.32800025939942,
"test_comprehensive_linalg_vector_norm_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 96.41353665865384,
"test_comprehensive_linalg_vector_norm_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 100.17661388103778,
"test_comprehensive_masked_norm_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 110.95025062561035,
"test_comprehensive_masked_norm_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 108.06550025939941,
"test_comprehensive_masked_norm_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 104.24150085449219,
"test_comprehensive_nn_functional_conv_transpose3d_cuda_complex128 (__main__.TestDecompCUDA)": 63.453749656677246,
"test_comprehensive_nn_functional_conv_transpose3d_cuda_complex64 (__main__.TestDecompCUDA)": 61.739999771118164,
"test_comprehensive_nn_functional_gaussian_nll_loss_cpu_float32 (__main__.TestDecompCPU)": 69.96549987792969,
"test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestDecompCUDA)": 113.65749931335449,
"test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float64 (__main__.TestDecompCUDA)": 106.57500076293945,
"test_comprehensive_nn_functional_grid_sample_cpu_float32 (__main__.TestDecompCPU)": 117.54049682617188,
"test_comprehensive_nn_functional_grid_sample_cpu_float64 (__main__.TestDecompCPU)": 116.19766489664714,
"test_comprehensive_nn_functional_grid_sample_cuda_float32 (__main__.TestDecompCUDA)": 272.48475646972656,
"test_comprehensive_nn_functional_grid_sample_cuda_float64 (__main__.TestDecompCUDA)": 248.12175369262695,
"test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 (__main__.TestDecompCUDA)": 79.66900062561035,
"test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 81.52649879455566,
"test_comprehensive_nn_functional_interpolate_bicubic_cuda_float64 (__main__.TestDecompCUDA)": 79.29400062561035,
"test_comprehensive_nn_functional_interpolate_bicubic_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 82.40349960327148,
"test_comprehensive_nn_functional_interpolate_trilinear_cuda_float32 (__main__.TestDecompCUDA)": 128.42924880981445,
"test_comprehensive_nn_functional_interpolate_trilinear_cuda_float64 (__main__.TestDecompCUDA)": 125.03675079345703,
"test_comprehensive_nn_functional_max_pool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 1264.9732360839844,
"test_comprehensive_nn_functional_max_pool2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 1250.7332458496094,
"test_comprehensive_nn_functional_max_pool2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 1255.0684814453125,
"test_comprehensive_nn_functional_max_pool3d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 574.4627532958984,
"test_comprehensive_nn_functional_max_pool3d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 581.7282485961914,
"test_comprehensive_nn_functional_max_unpool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 65.052001953125,
"test_comprehensive_nn_functional_max_unpool2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 61.19200134277344,
"test_comprehensive_nn_functional_max_unpool2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 63.16874885559082,
"test_comprehensive_ormqr_cpu_complex64 (__main__.TestDecompCPU)": 62.39250183105469,
"test_comprehensive_ormqr_cuda_complex128 (__main__.TestDecompCUDA)": 113.32574844360352,
"test_comprehensive_ormqr_cuda_complex64 (__main__.TestDecompCUDA)": 113.91499900817871,
"test_comprehensive_ormqr_cuda_float32 (__main__.TestDecompCUDA)": 74.42549800872803,
"test_comprehensive_ormqr_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 76.1560001373291,
"test_comprehensive_ormqr_cuda_float64 (__main__.TestDecompCUDA)": 66.76750087738037,
"test_comprehensive_svd_cuda_complex128 (__main__.TestDecompCUDA)": 70.69724941253662,
"test_comprehensive_svd_cuda_complex64 (__main__.TestDecompCUDA)": 69.87625026702881,
"test_constructor_autograd_SparseBSC_cuda (__main__.TestSparseAnyCUDA)": 80.2542495727539,
"test_constructor_autograd_SparseBSR_cuda (__main__.TestSparseAnyCUDA)": 69.0419979095459,
"test_conv1d_basic (__main__.TestXNNPACKConv1dTransformPass)": 117.03342655726841,
"test_conv1d_with_relu_fc (__main__.TestXNNPACKConv1dTransformPass)": 289.50213841029574,
"test_conv2d_binary_broadcast_shapes_cpu (__main__.TestPatternMatcherGenericCPU)": 67.38800048828125,
"test_conv3d_binary_broadcast_shapes_cpu (__main__.TestPatternMatcherGenericCPU)": 145.27399444580078,
"test_conv3d_binary_dynamic_shapes_cpu (__main__.TestDynamicPatternMatcherGenericCPU)": 66.9245999654134,
"test_conv3d_cuda (__main__.AOTInductorTestABICompatibleGpu)": 151.91099548339844,
"test_conv_bn_fuse_cpu (__main__.CpuTests)": 92.79549789428711,
"test_conv_bn_fuse_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 64.60149955749512,
"test_conv_transpose_with_output_size_and_no_batch_dim_ConvTranspose3d_cuda (__main__.TestConvolutionNNDeviceTypeCUDA)": 69.27724676392972,
"test_conv_unary_fusion_nnc (__main__.TestMkldnnFusion)": 76.24971498761859,
"test_correctness_AdamW_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 81.93449974060059,
"test_correctness_Adam_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 78.87700080871582,
"test_count_nonzero_all (__main__.TestBool)": 631.2585144042969,
"test_diff_hyperparams_sharding_strategy_str_full_shard (__main__.TestFSDPUseOrigParamsMultipleParamGroups)": 61.042999267578125,
"test_dispatch_symbolic_meta_outplace_all_strides_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestMetaCUDA)": 84.49850082397461,
"test_dtensor_op_db_nn_functional_poisson_nll_loss_cpu_float32 (__main__.TestLocalDTensorOpsCPU)": 93.03299713134766,
"test_eager_sequence_nr_dynamic_shapes (__main__.DynamicShapesAotAutogradFallbackTests)": 228.46711820714614,
"test_eig_check_magma_cuda_float32 (__main__.TestLinalgCUDA)": 286.29998779296875,
"test_fail_arithmetic_ops.py (__main__.TestTyping)": 68.43842806134906,
"test_fail_random.py (__main__.TestTyping)": 74.83523060725285,
"test_fn_fwgrad_bwgrad_cumprod_cuda_complex128 (__main__.TestFwdGradientsCUDA)": 72.84900093078613,
"test_fn_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 75.86675071716309,
"test_fuse_large_params_cpu (__main__.CpuTests)": 151.4199981689453,
"test_fuse_large_params_cuda (__main__.GPUTests)": 60.351999282836914,
"test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 158.3622828892299,
"test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 149.6796646118164,
"test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 139.97800064086914,
"test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 114.8385009765625,
"test_grad_nn_Transformer_cpu_float64 (__main__.TestModuleCPU)": 84.69736822027909,
"test_grad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 84.62700080871582,
"test_gradgrad_nn_LSTM_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 89.197998046875,
"test_gradgrad_nn_LSTM_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 96.46900177001953,
"test_gradgrad_nn_TransformerDecoderLayer_cuda_float64 (__main__.TestModuleCUDA)": 187.83824920654297,
"test_gradgrad_nn_TransformerEncoder_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 110.49449920654297,
"test_gradgrad_nn_TransformerEncoder_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 124.90424919128418,
"test_gradgrad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 518.4157485961914,
"test_indirect_device_assert (__main__.TritonCodeGenTests)": 304.6440022786458,
"test_inductor_dynamic_shapes_broadcasting_dynamic_shapes (__main__.DynamicShapesReproTests)": 143.82052836698645,
"test_inductor_no_recursionerror_on_for_loops_dynamic_shapes (__main__.DynamicShapesReproTests)": 77.4985705784389,
"test_inplace_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 76.06225109100342,
"test_inputs_overlapping_with_mutation_stress_dynamic_shapes (__main__.DynamicShapesAotAutogradFallbackTests)": 138.9222858973912,
"test_jit_cuda_archflags (__main__.TestCppExtensionJIT)": 120.62233225504558,
"test_linalg_solve_triangular_large_cuda_complex128 (__main__.TestLinalgCUDA)": 148.1219940185547,
"test_linalg_solve_triangular_large_cuda_complex64 (__main__.TestLinalgCUDA)": 109.34200286865234,
"test_linear_binary_cpp_wrapper (__main__.TestCppWrapper)": 119.36233266194661,
"test_linear_binary_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 127.95700073242188,
"test_list_clearing_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 61.64850175380707,
"test_longformer_chunk_dynamic_shapes (__main__.DynamicShapesReproTests)": 105.3174296787807,
"test_low_memory_max_pool_dilation_1_dim_3_cpu_halide (__main__.HalideCpuTests)": 585.9210001627604,
"test_low_memory_max_pool_dilation_2_dim_3_cpu_halide (__main__.HalideCpuTests)": 504.3250020345052,
"test_lstm_cpu (__main__.TestMkldnnCPU)": 86.21566645304362,
"test_many_overlapping_inputs_does_not_explode_guards_dynamic_shapes (__main__.DynamicShapesReproTests)": 129.277715410505,
"test_max_autotune_addmm_max_autotune_gemm_backends_CK_x_shape2 (__main__.TestCKBackend)": 64.24800109863281,
"test_max_autotune_precompile_matmul_max_autotune_gemm_backends_CKTILE_autotune_in_subproc_False_use_aoti_False (__main__.TestCKBackend)": 77.23899841308594,
"test_max_autotune_precompile_matmul_max_autotune_gemm_backends_CKTILE_autotune_in_subproc_False_use_aoti_True (__main__.TestCKBackend)": 65.15649795532227,
"test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 62.579833984375,
"test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 64.6555004119873,
"test_pattern_matcher_multi_user_cpu (__main__.CpuTritonTests)": 142.21566772460938,
"test_proper_exit (__main__.TestDataLoader)": 267.74214717320035,
"test_proper_exit (__main__.TestDataLoaderPersistentWorkers)": 266.6539971487863,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 101.97100067138672,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 97.3346659342448,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True (__main__.TestPatternMatcher)": 81.50300216674805,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 104.61333465576172,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 99.41133371988933,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False (__main__.TestPatternMatcher)": 73.37100219726562,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 95.30900065104167,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 96.61750030517578,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True (__main__.TestPatternMatcher)": 79.33600234985352,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 101.2393315633138,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 103.18400192260742,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False (__main__.TestPatternMatcher)": 75.4114990234375,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 96.52833302815755,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 99.72700119018555,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 100.61966705322266,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 102.2750015258789,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 95.17449951171875,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 97.96749877929688,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 106.44049835205078,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 101.7173334757487,
"test_quick_core_backward__unsafe_masked_index_cpu_float64 (__main__.TestDecompCPU)": 531.5236612955729,
"test_quick_core_backward__unsafe_masked_index_cuda_float64 (__main__.TestDecompCUDA)": 1077.4210205078125,
"test_quick_core_backward__unsafe_masked_index_put_accumulate_cpu_float64 (__main__.TestDecompCPU)": 812.0880126953125,
"test_quick_core_backward__unsafe_masked_index_put_accumulate_cuda_float64 (__main__.TestDecompCUDA)": 1347.9365234375,
"test_quick_core_backward_nn_functional_max_unpool3d_grad_cpu_float64 (__main__.TestDecompCPU)": 88.93533070882161,
"test_quick_core_backward_nn_functional_max_unpool3d_grad_cuda_float64 (__main__.TestDecompCUDA)": 269.01949310302734,
"test_quick_core_backward_roll_cpu_float64 (__main__.TestDecompCPU)": 131.99799601236978,
"test_quick_core_backward_roll_cuda_float64 (__main__.TestDecompCUDA)": 232.36275100708008,
"test_quick_core_backward_select_scatter_cpu_float64 (__main__.TestDecompCPU)": 69.80400085449219,
"test_quick_core_backward_select_scatter_cuda_float64 (__main__.TestDecompCUDA)": 134.3415012359619,
"test_quick_core_backward_split_cuda_float64 (__main__.TestDecompCUDA)": 67.51749992370605,
"test_quick_core_backward_split_with_sizes_copy_cpu_float64 (__main__.TestDecompCPU)": 91.21066792805989,
"test_quick_core_backward_split_with_sizes_copy_cuda_float64 (__main__.TestDecompCUDA)": 170.97775268554688,
"test_quick_core_backward_std_cpu_float64 (__main__.TestDecompCPU)": 61.608266321818036,
"test_quick_core_backward_std_cuda_float64 (__main__.TestDecompCUDA)": 110.62575149536133,
"test_register_spills_cuda (__main__.BenchmarkFusionGpuTest)": 63.59499969482422,
"test_replicatepad_64bit_indexing_cuda_float16 (__main__.TestNNDeviceTypeCUDA)": 88.68299865722656,
"test_rnn_decomp_module_nn_LSTM_train_mode_cuda_float32 (__main__.TestDecompCUDA)": 91.50320053100586,
"test_runtime_checks_large_cpu (__main__.AOTInductorTestABICompatibleCpu)": 66.10774898529053,
"test_runtime_checks_large_cpu_with_stack_allocation (__main__.AOTInductorTestABICompatibleCpuWithStackAllocation)": 66.20533180236816,
"test_runtime_checks_large_cuda (__main__.AOTInductorTestABICompatibleGpu)": 243.1092529296875,
"test_save_load_large_string_attribute (__main__.TestSaveLoad)": 105.01200103759766,
"test_sdpa_kernel_ctx_manager2_dynamic_shapes (__main__.DynamicShapesCtxManagerTests)": 107.93685695103237,
"test_shuffler_iterdatapipe (__main__.IntegrationTestDataLoaderDataPipe)": 142.38899993896484,
"test_slow_tasks (__main__.TestFunctionalAutogradBenchmark)": 119.90166600545247,
"test_sort_bool_cpu (__main__.CpuTritonTests)": 346.2856750488281,
"test_sort_dynamic_shape_with_check_cuda (__main__.TestInductorDynamicCUDA)": 423.09974098205566,
"test_sort_stable_cuda (__main__.GPUTests)": 117.61659927368164,
"test_sort_transpose_cpu (__main__.CpuTritonTests)": 378.31200154622394,
"test_svd_lowrank_cuda_complex128 (__main__.TestLinalgCUDA)": 222.822007894516,
"test_terminate_handler_on_crash (__main__.TestTorch)": 143.31728431156702,
"test_terminate_signal (__main__.ForkTest)": 168.20485967184817,
"test_terminate_signal (__main__.ParallelForkServerShouldWorkTest)": 168.19242484867573,
"test_terminate_signal (__main__.SpawnTest)": 172.16428443363733,
"test_thnn_conv_strided_padded_dilated (__main__.TestConvolutionNN)": 93.30639710426331,
"test_train_parity_multi_group (__main__.TestFullyShard1DTrainingCore)": 163.89743041992188,
"test_train_parity_with_activation_checkpointing (__main__.TestFullyShard1DTrainingCompose)": 60.47671399797712,
"test_triton_bsr_scatter_mm_blocksize_64_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 63.39550018310547,
"test_triton_bsr_softmax_cuda_bfloat16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 173.53924942016602,
"test_triton_bsr_softmax_cuda_float16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 175.3212537765503,
"test_triton_bsr_softmax_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 122.20649909973145,
"test_variant_consistency_jit_nn_functional_max_pool2d_cpu_float32 (__main__.TestJitCPU)": 99.9885025024414,
"test_variant_consistency_jit_nn_functional_max_pool2d_cuda_float32 (__main__.TestJitCUDA)": 71.64024829864502,
"test_view_ops (__main__.TestViewOpsWithLocalTensor)": 73.45887422561646,
"test_vmapjvpvjp_linalg_lstsq_grad_oriented_cpu_float32 (__main__.TestOperatorsCPU)": 95.75249862670898,
"test_vmapjvpvjp_linalg_lstsq_grad_oriented_cuda_float32 (__main__.TestOperatorsCUDA)": 61.858001708984375,
"test_vmapjvpvjp_linalg_lu_solve_cpu_float32 (__main__.TestOperatorsCPU)": 65.11023766653878,
"test_vmapjvpvjp_linalg_lu_solve_cuda_float32 (__main__.TestOperatorsCUDA)": 66.35274982452393,
"test_vmapjvpvjp_linalg_svd_cuda_float32 (__main__.TestOperatorsCUDA)": 61.196499824523926,
"test_vmapjvpvjp_max_pool2d_with_indices_backward_cpu_float32 (__main__.TestOperatorsCPU)": 73.75380906604585,
"test_vmapjvpvjp_max_pool2d_with_indices_backward_cuda_float32 (__main__.TestOperatorsCUDA)": 73.64649868011475,
"test_vmapjvpvjp_nn_functional_max_pool2d_cpu_float32 (__main__.TestOperatorsCPU)": 75.09799966358003,
"test_vmapjvpvjp_nn_functional_max_pool2d_cuda_float32 (__main__.TestOperatorsCUDA)": 70.51450157165527,
"test_vmapjvpvjp_unbind_cpu_float32 (__main__.TestOperatorsCPU)": 66.21433276221866,
"test_vmapjvpvjp_unbind_cuda_float32 (__main__.TestOperatorsCUDA)": 73.20024871826172,
"test_vmapvjpvjp_linalg_lstsq_cuda_float32 (__main__.TestOperatorsCUDA)": 88.1349983215332,
"test_vmapvjpvjp_meshgrid_list_of_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 76.89924907684326,
"test_vmapvjpvjp_meshgrid_variadic_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 77.32975196838379,
"test_vmapvjpvjp_nn_functional_bilinear_cuda_float32 (__main__.TestOperatorsCUDA)": 120.09600067138672
"EndToEndLSTM (__main__.RNNTest)": 207.89400227864584,
"MultiheadAttention (__main__.ModulesTest)": 141.1396687825521,
"test_AllenaiLongformerBase_repro_cpu_halide (__main__.HalideCpuTests)": 214.02366638183594,
"test__adaptive_avg_pool2d (__main__.CPUReproTests)": 77.26125049591064,
"test_adaptive_max_pool2d1_cpu_halide (__main__.HalideCpuTests)": 116.37000020345052,
"test_after_aot_cpu_runtime_error (__main__.MinifierIsolateTests)": 69.25722334120009,
"test_after_aot_gpu_runtime_error (__main__.MinifierIsolateTests)": 65.84466807047527,
"test_alexnet_prefix_cpu_halide (__main__.HalideCpuTests)": 178.41399637858072,
"test_aot_autograd_disable_functionalization_symbolic_exhaustive_linalg_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 63.55014337812151,
"test_aot_autograd_disable_functionalization_symbolic_exhaustive_nn_functional_max_pool1d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 122.18047623407273,
"test_aot_autograd_disable_functionalization_symbolic_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 192.6405719575428,
"test_aot_autograd_disable_functionalization_symbolic_exhaustive_nn_functional_max_pool3d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 111.27904801141648,
"test_aot_autograd_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 60.906999588012695,
"test_aot_autograd_symbolic_exhaustive_linalg_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 62.244998931884766,
"test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool1d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 150.04100036621094,
"test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 191.85050201416016,
"test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool3d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 111.9276631673177,
"test_aot_autograd_symbolic_exhaustive_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 67.31450271606445,
"test_aot_autograd_symbolic_module_exhaustive_nn_TransformerDecoderLayer_cpu_float32 (__main__.TestEagerFusionModuleInfoCPU)": 125.24066416422527,
"test_associative_scan_partial_grad_combine_mode_generic_compile_mode_compile_dynamic_shape_reverse_False_cpu (__main__.AssociativeScanTests)": 86.47783279418945,
"test_associative_scan_partial_grad_combine_mode_generic_compile_mode_compile_dynamic_shape_reverse_True_cpu (__main__.AssociativeScanTests)": 100.46250025431316,
"test_avg_pool3d_backward2_cpu (__main__.CpuTests)": 1031.0534973144531,
"test_avg_pool3d_backward2_cuda (__main__.GPUTests)": 239.67400105794272,
"test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 495.0447726779514,
"test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 490.18524169921875,
"test_avg_pool3d_backward2_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 144.06477737426758,
"test_avg_pool3d_backward2_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 342.20416259765625,
"test_avg_pool3d_backward_cpu_halide (__main__.HalideCpuTests)": 62.01366678873698,
"test_backward_nn_functional_multi_head_attention_forward_cpu_float32 (__main__.TestCompositeComplianceCPU)": 71.07200050354004,
"test_backward_nn_functional_multi_head_attention_forward_cuda_float32 (__main__.TestCompositeComplianceCUDA)": 73.9221674601237,
"test_basic_cpu (__main__.EfficientConvBNEvalCpuTests)": 226.0122528076172,
"test_basic_cuda (__main__.EfficientConvBNEvalGpuTests)": 144.97249857584634,
"test_checkpointing_without_reentrant_input_requires_grad_False (__main__.TestAutogradWithCompiledAutograd)": 303.20537185668945,
"test_checkpointing_without_reentrant_input_requires_grad_True (__main__.TestAutogradWithCompiledAutograd)": 386.0518798828125,
"test_collect_callgrind (__main__.TestBenchmarkUtils)": 291.2442270914714,
"test_comprehensive_diff_cuda_complex128 (__main__.TestDecompCUDA)": 95.87866719563802,
"test_comprehensive_diff_cuda_complex64 (__main__.TestDecompCUDA)": 98.38716634114583,
"test_comprehensive_diff_cuda_float32 (__main__.TestDecompCUDA)": 69.08016649881999,
"test_comprehensive_diff_cuda_float64 (__main__.TestDecompCUDA)": 69.88233311971028,
"test_comprehensive_grid_sampler_2d_cpu_bfloat16 (__main__.TestDecompCPU)": 104.17599995930989,
"test_comprehensive_grid_sampler_2d_cpu_float16 (__main__.TestDecompCPU)": 97.41800308227539,
"test_comprehensive_grid_sampler_2d_cpu_float32 (__main__.TestDecompCPU)": 474.6719970703125,
"test_comprehensive_grid_sampler_2d_cpu_float64 (__main__.TestDecompCPU)": 440.4375,
"test_comprehensive_grid_sampler_2d_cuda_bfloat16 (__main__.TestDecompCUDA)": 293.3983332316081,
"test_comprehensive_grid_sampler_2d_cuda_float16 (__main__.TestDecompCUDA)": 238.7328338623047,
"test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestDecompCUDA)": 1218.4906717936199,
"test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 68.73516782124837,
"test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestDecompCUDA)": 1156.0123494466145,
"test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 72.13916714986165,
"test_comprehensive_linalg_lu_solve_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 74.90450032552083,
"test_comprehensive_linalg_lu_solve_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 70.42100016276042,
"test_comprehensive_linalg_solve_triangular_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 72.98883310953777,
"test_comprehensive_linalg_solve_triangular_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 73.34433364868164,
"test_comprehensive_linalg_svd_cuda_complex128 (__main__.TestDecompCUDA)": 61.38016573588053,
"test_comprehensive_linalg_svd_cuda_complex64 (__main__.TestDecompCUDA)": 67.52783330281575,
"test_comprehensive_masked_norm_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 111.06333287556966,
"test_comprehensive_masked_norm_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 110.19833374023438,
"test_comprehensive_masked_norm_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 113.10083134969075,
"test_comprehensive_nn_functional_conv_transpose3d_cuda_complex128 (__main__.TestDecompCUDA)": 63.23766644795736,
"test_comprehensive_nn_functional_conv_transpose3d_cuda_complex64 (__main__.TestDecompCUDA)": 70.18666712443034,
"test_comprehensive_nn_functional_gaussian_nll_loss_cpu_float32 (__main__.TestDecompCPU)": 62.61399841308594,
"test_comprehensive_nn_functional_gaussian_nll_loss_cpu_float64 (__main__.TestDecompCPU)": 67.7816670735677,
"test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestDecompCUDA)": 121.6183344523112,
"test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float64 (__main__.TestDecompCUDA)": 107.30266698201497,
"test_comprehensive_nn_functional_grid_sample_cpu_float32 (__main__.TestDecompCPU)": 130.8143310546875,
"test_comprehensive_nn_functional_grid_sample_cpu_float64 (__main__.TestDecompCPU)": 127.27633412679036,
"test_comprehensive_nn_functional_grid_sample_cuda_float32 (__main__.TestDecompCUDA)": 303.55183664957684,
"test_comprehensive_nn_functional_grid_sample_cuda_float64 (__main__.TestDecompCUDA)": 234.41216532389322,
"test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 (__main__.TestDecompCUDA)": 85.3436673482259,
"test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 80.9688326517741,
"test_comprehensive_nn_functional_interpolate_bicubic_cuda_float64 (__main__.TestDecompCUDA)": 82.55149968465169,
"test_comprehensive_nn_functional_interpolate_bicubic_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 82.37966791788737,
"test_comprehensive_nn_functional_interpolate_trilinear_cuda_float32 (__main__.TestDecompCUDA)": 129.88233184814453,
"test_comprehensive_nn_functional_interpolate_trilinear_cuda_float64 (__main__.TestDecompCUDA)": 129.4015007019043,
"test_comprehensive_nn_functional_max_pool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 1282.3826497395833,
"test_comprehensive_nn_functional_max_pool2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 1270.64599609375,
"test_comprehensive_nn_functional_max_pool2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 1297.9046630859375,
"test_comprehensive_nn_functional_max_pool3d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 545.2034962972006,
"test_comprehensive_nn_functional_max_pool3d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 572.5616760253906,
"test_comprehensive_nn_functional_max_unpool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 64.40316645304362,
"test_comprehensive_nn_functional_max_unpool2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 64.68383344014485,
"test_comprehensive_nn_functional_max_unpool2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 61.48333422342936,
"test_comprehensive_ormqr_cpu_complex64 (__main__.TestDecompCPU)": 61.959999084472656,
"test_comprehensive_ormqr_cuda_complex128 (__main__.TestDecompCUDA)": 105.79100036621094,
"test_comprehensive_ormqr_cuda_complex64 (__main__.TestDecompCUDA)": 122.34666570027669,
"test_comprehensive_ormqr_cuda_float32 (__main__.TestDecompCUDA)": 68.7205015818278,
"test_comprehensive_ormqr_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 74.2183329264323,
"test_comprehensive_ormqr_cuda_float64 (__main__.TestDecompCUDA)": 66.86883227030437,
"test_comprehensive_svd_cuda_complex128 (__main__.TestDecompCUDA)": 77.48183314005534,
"test_comprehensive_svd_cuda_complex64 (__main__.TestDecompCUDA)": 79.1564998626709,
"test_constructor_autograd_SparseBSC_cuda (__main__.TestSparseAnyCUDA)": 160.41250228881836,
"test_constructor_autograd_SparseBSR_cuda (__main__.TestSparseAnyCUDA)": 79.10633341471355,
"test_constructor_autograd_SparseCSC_cuda (__main__.TestSparseAnyCUDA)": 60.106833140055336,
"test_conv1d_basic (__main__.TestXNNPACKConv1dTransformPass)": 221.3586196899414,
"test_conv1d_with_relu_fc (__main__.TestXNNPACKConv1dTransformPass)": 504.3203754425049,
"test_conv2d_binary_broadcast_shapes_cpu (__main__.TestPatternMatcherGenericCPU)": 78.03233337402344,
"test_conv3d_binary_broadcast_shapes_cpu (__main__.TestPatternMatcherGenericCPU)": 152.302001953125,
"test_conv3d_cuda (__main__.AOTInductorTestABICompatibleGpu)": 152.99433390299478,
"test_conv_bn_fuse_cpu (__main__.CpuTests)": 96.25399971008301,
"test_conv_bn_fuse_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 75.70275068283081,
"test_conv_transpose_with_output_size_and_no_batch_dim_ConvTranspose3d_cuda (__main__.TestConvolutionNNDeviceTypeCUDA)": 139.14399747674665,
"test_conv_unary_fusion_nnc (__main__.TestMkldnnFusion)": 72.7847490310669,
"test_correctness_AdamW_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 91.59966786702473,
"test_correctness_Adam_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 87.57833353678386,
"test_count_nonzero_all (__main__.TestBool)": 664.9986343383789,
"test_cp_flex_attention_document_mask (__main__.CPFlexAttentionTest)": 78.31500244140625,
"test_ddp_uneven_inputs (__main__.TestDistBackendWithSpawn)": 385.24249792099,
"test_dispatch_symbolic_meta_outplace_all_strides_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestMetaCUDA)": 84.70466740926106,
"test_dtensor_op_db_nn_functional_gaussian_nll_loss_cpu_float32 (__main__.TestLocalDTensorOpsCPU)": 685.0679931640625,
"test_dtensor_op_db_nn_functional_gaussian_nll_loss_cpu_float32 (__main__.TestMultiThreadedDTensorOpsCPU)": 86.26266733805339,
"test_eig_check_magma_cuda_float32 (__main__.TestLinalgCUDA)": 292.93699645996094,
"test_error_detection_and_propagation (__main__.NcclErrorHandlingTest)": 66.84199905395508,
"test_fail_arithmetic_ops.py (__main__.TestTyping)": 69.56212568283081,
"test_fail_creation_ops.py (__main__.TestTyping)": 69.80560022989908,
"test_fn_fwgrad_bwgrad_cumprod_cuda_complex128 (__main__.TestFwdGradientsCUDA)": 73.36666552225749,
"test_fn_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 90.40366744995117,
"test_fuse_large_params_cpu (__main__.CpuTests)": 132.73199844360352,
"test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 150.16662406921387,
"test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 159.28499794006348,
"test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 165.19283294677734,
"test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 151.12366739908853,
"test_grad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 84.61699930826823,
"test_gradgrad_nn_LSTM_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 110.00600179036458,
"test_gradgrad_nn_LSTM_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 122.3759994506836,
"test_gradgrad_nn_TransformerDecoderLayer_cuda_float64 (__main__.TestModuleCUDA)": 190.89249674479166,
"test_gradgrad_nn_TransformerEncoder_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 149.6598358154297,
"test_gradgrad_nn_TransformerEncoder_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 146.07766723632812,
"test_gradgrad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 532.8139902750651,
"test_graph_partition_refcount_cuda (__main__.GPUTests)": 69.78400001525878,
"test_graph_partition_refcount_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 267.04988850487604,
"test_graph_partition_refcount_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 273.54955800374347,
"test_grid_sampler_2d_cpu_halide (__main__.HalideCpuTests)": 195.84733072916666,
"test_indirect_device_assert (__main__.TritonCodeGenTests)": 326.0143330891927,
"test_inductor_no_recursionerror_on_for_loops_dynamic_shapes (__main__.DynamicShapesReproTests)": 66.96037435531616,
"test_inplace_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 77.44933319091797,
"test_inputs_overlapping_with_mutation_stress_dynamic_shapes (__main__.DynamicShapesAotAutogradFallbackTests)": 126.81488884819879,
"test_jit_cuda_archflags (__main__.TestCppExtensionJIT)": 118.70199839274089,
"test_linalg_solve_triangular_large_cuda_complex128 (__main__.TestLinalgCUDA)": 129.20266723632812,
"test_linalg_solve_triangular_large_cuda_complex64 (__main__.TestLinalgCUDA)": 97.18800099690755,
"test_linear_binary_cpp_wrapper (__main__.TestCppWrapper)": 130.3183339436849,
"test_linear_binary_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 140.43233235677084,
"test_list_clearing_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 293.122774971856,
"test_lobpcg_ortho_cuda_float64 (__main__.TestLinalgCUDA)": 63.835832277933754,
"test_longformer_chunk_dynamic_shapes (__main__.DynamicShapesReproTests)": 106.77049922943115,
"test_lstm_cpu (__main__.TestMkldnnCPU)": 100.89649963378906,
"test_many_overlapping_inputs_does_not_explode_guards_dynamic_shapes (__main__.DynamicShapesReproTests)": 140.07424926757812,
"test_max_autotune_addmm_max_autotune_gemm_backends_CK_x_shape2 (__main__.TestCKBackend)": 72.90299733479817,
"test_max_autotune_addmm_search_space_EXHAUSTIVE_dynamic_True (__main__.TestMaxAutotuneSubproc)": 82.62433369954427,
"test_max_autotune_precompile_matmul_max_autotune_gemm_backends_CKTILE_autotune_in_subproc_False_use_aoti_False (__main__.TestCKBackend)": 87.51499938964844,
"test_max_autotune_precompile_matmul_max_autotune_gemm_backends_CKTILE_autotune_in_subproc_True_use_aoti_True (__main__.TestCKBackend)": 71.22416591644287,
"test_max_pool2d2_cpu_halide (__main__.HalideCpuTests)": 424.50966389973956,
"test_max_pool2d3_cpu_halide (__main__.HalideCpuTests)": 134.14600626627603,
"test_max_pool2d5_cpu_halide (__main__.HalideCpuTests)": 358.88099161783856,
"test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 63.58866712782118,
"test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 62.68674945831299,
"test_memory_format_operators_cuda (__main__.TestTorchDeviceTypeCUDA)": 65.85794713936355,
"test_ordered_distribute_all_combination (__main__.DistributeWithDeviceOrderTest)": 103.6923344930013,
"test_ordered_redistribute_with_partial (__main__.DistributeWithDeviceOrderTest)": 187.6953328450521,
"test_ordered_redistribute_with_partial (__main__.DistributeWithDeviceOrderTestWithLocalTensor)": 370.27442932128906,
"test_proper_exit (__main__.TestDataLoader)": 227.83111148410373,
"test_proper_exit (__main__.TestDataLoaderPersistentWorkers)": 227.1901126437717,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 105.52099990844727,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 106.50249862670898,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True (__main__.TestPatternMatcher)": 92.52400207519531,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 111.75499725341797,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 107.40500259399414,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False (__main__.TestPatternMatcher)": 83.80450057983398,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 107.46599833170573,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 96.65650177001953,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True (__main__.TestPatternMatcher)": 83.4114990234375,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 107.47100067138672,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 108.55533345540364,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False (__main__.TestPatternMatcher)": 89.23666381835938,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 105.13900375366211,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 100.14550018310547,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 107.33649826049805,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 102.08150100708008,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 97.59600067138672,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 104.82933553059895,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 114.43099721272786,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 110.40333302815755,
"test_quick_core_backward__unsafe_masked_index_cpu_float64 (__main__.TestDecompCPU)": 567.2765197753906,
"test_quick_core_backward__unsafe_masked_index_cuda_float64 (__main__.TestDecompCUDA)": 1032.5083312988281,
"test_quick_core_backward__unsafe_masked_index_put_accumulate_cpu_float64 (__main__.TestDecompCPU)": 852.7170003255209,
"test_quick_core_backward__unsafe_masked_index_put_accumulate_cuda_float64 (__main__.TestDecompCUDA)": 1361.954854329427,
"test_quick_core_backward_nn_functional_max_unpool3d_grad_cpu_float64 (__main__.TestDecompCPU)": 77.385498046875,
"test_quick_core_backward_nn_functional_max_unpool3d_grad_cuda_float64 (__main__.TestDecompCUDA)": 265.0193354288737,
"test_quick_core_backward_roll_cpu_float64 (__main__.TestDecompCPU)": 115.31749725341797,
"test_quick_core_backward_roll_cuda_float64 (__main__.TestDecompCUDA)": 245.27666727701822,
"test_quick_core_backward_select_scatter_cpu_float64 (__main__.TestDecompCPU)": 71.75300216674805,
"test_quick_core_backward_select_scatter_cuda_float64 (__main__.TestDecompCUDA)": 141.8895009358724,
"test_quick_core_backward_split_cuda_float64 (__main__.TestDecompCUDA)": 71.15749994913737,
"test_quick_core_backward_split_with_sizes_copy_cpu_float64 (__main__.TestDecompCPU)": 90.59066772460938,
"test_quick_core_backward_split_with_sizes_copy_cuda_float64 (__main__.TestDecompCUDA)": 173.73916625976562,
"test_quick_core_backward_std_cuda_float64 (__main__.TestDecompCUDA)": 110.65066655476888,
"test_register_spills_cuda (__main__.BenchmarkFusionCudaTest)": 99.21799850463867,
"test_replicatepad_64bit_indexing_cuda_float16 (__main__.TestNNDeviceTypeCUDA)": 90.86299896240234,
"test_rosenbrock_sparse_with_lrsched_False_SGD_cuda_float64 (__main__.TestOptimRenewedCUDA)": 66.57050196329753,
"test_rosenbrock_sparse_with_lrsched_True_SGD_cuda_float64 (__main__.TestOptimRenewedCUDA)": 69.65149958928426,
"test_runtime_checks_large_cpu (__main__.AOTInductorTestABICompatibleCpu)": 78.13350168863933,
"test_runtime_checks_large_cpu_with_stack_allocation (__main__.AOTInductorTestABICompatibleCpuWithStackAllocation)": 76.85255601671007,
"test_runtime_checks_large_cuda (__main__.AOTInductorTestABICompatibleGpu)": 333.04866282145184,
"test_save_load_large_string_attribute (__main__.TestSaveLoad)": 146.96599833170572,
"test_sdpa_kernel_ctx_manager2_dynamic_shapes (__main__.DynamicShapesCtxManagerTests)": 160.4881100124783,
"test_shuffler_iterdatapipe (__main__.IntegrationTestDataLoaderDataPipe)": 124.10055626763238,
"test_slow_tasks (__main__.TestFunctionalAutogradBenchmark)": 117.38410907321506,
"test_sort_dynamic_shape_with_check_cuda (__main__.TestInductorDynamicCUDA)": 710.2327779134115,
"test_sort_stable_cpu (__main__.CpuTritonTests)": 1324.4399820963542,
"test_sort_stable_cuda (__main__.GPUTests)": 76.83109970092774,
"test_split_cumsum_cpu (__main__.CpuTritonTests)": 88.58433532714844,
"test_svd_lowrank_cuda_complex128 (__main__.TestLinalgCUDA)": 160.1271684964498,
"test_tensor_split (__main__.TestVmapOperators)": 79.18955569393519,
"test_terminate_handler_on_crash (__main__.TestTorch)": 111.30388899644215,
"test_terminate_signal (__main__.ForkTest)": 132.3458870516883,
"test_terminate_signal (__main__.ParallelForkServerShouldWorkTest)": 132.2043343567186,
"test_terminate_signal (__main__.SpawnTest)": 136.1005539894104,
"test_torchvision_smoke (__main__.TestTensorBoardPytorchGraph)": 76.20899939537048,
"test_train_parity_multi_group_unshard_async_op (__main__.TestFullyShard1DTrainingCore)": 63.82099969046457,
"test_triton_bsr_scatter_mm_blocksize_64_cuda_bfloat16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 61.925000508626304,
"test_triton_bsr_scatter_mm_blocksize_64_cuda_float16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 60.89849980672201,
"test_triton_bsr_scatter_mm_blocksize_64_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 66.88233375549316,
"test_triton_bsr_softmax_cuda_bfloat16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 144.9854990641276,
"test_triton_bsr_softmax_cuda_float16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 144.4044977823893,
"test_triton_bsr_softmax_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 108.19166437784831,
"test_unary_ops (__main__.TestTEFuserDynamic)": 96.32655514611139,
"test_unary_ops (__main__.TestTEFuserStatic)": 105.33362591266632,
"test_upsample_bicubic2d_cpu_halide (__main__.HalideCpuTests)": 97.8336664835612,
"test_variant_consistency_jit_nn_functional_max_pool2d_cpu_float32 (__main__.TestJitCPU)": 82.86566925048828,
"test_variant_consistency_jit_nn_functional_max_pool2d_cuda_float32 (__main__.TestJitCUDA)": 68.26500002543132,
"test_views1_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 97.1120007832845,
"test_vmapjvpvjp_linalg_lstsq_grad_oriented_cpu_float32 (__main__.TestOperatorsCPU)": 88.24766794840495,
"test_vmapjvpvjp_linalg_lstsq_grad_oriented_cuda_float32 (__main__.TestOperatorsCUDA)": 65.41266759236653,
"test_vmapjvpvjp_linalg_lu_solve_cuda_float32 (__main__.TestOperatorsCUDA)": 74.75533294677734,
"test_vmapjvpvjp_linalg_multi_dot_cuda_float32 (__main__.TestOperatorsCUDA)": 73.52500089009602,
"test_vmapjvpvjp_linalg_svd_cuda_float32 (__main__.TestOperatorsCUDA)": 73.85466639200847,
"test_vmapjvpvjp_max_pool2d_with_indices_backward_cuda_float32 (__main__.TestOperatorsCUDA)": 98.39650090535481,
"test_vmapjvpvjp_nn_functional_conv2d_cpu_float32 (__main__.TestOperatorsCPU)": 61.39695285615467,
"test_vmapjvpvjp_nn_functional_max_pool2d_cuda_float32 (__main__.TestOperatorsCUDA)": 77.88249842325847,
"test_vmapjvpvjp_svd_cuda_float32 (__main__.TestOperatorsCUDA)": 73.0695006052653,
"test_vmapjvpvjp_unbind_cuda_float32 (__main__.TestOperatorsCUDA)": 81.86250114440918,
"test_vmapvjpvjp_meshgrid_list_of_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 98.63116455078125,
"test_vmapvjpvjp_meshgrid_variadic_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 94.85683314005534,
"test_vmapvjpvjp_nn_functional_bilinear_cuda_float32 (__main__.TestOperatorsCUDA)": 173.00183614095053
}

View File

@@ -239,12 +239,6 @@ class TestAccelerator(TestCase):
self.assertEqual(torch.accelerator.max_memory_allocated(), prev_max_allocated)
self.assertEqual(torch.accelerator.max_memory_reserved(), prev_max_reserved)
@unittest.skipIf(TEST_MPS, "MPS doesn't support torch.accelerator memory API!")
def test_get_memory_info(self):
free_bytes, total_bytes = torch.accelerator.get_memory_info()
self.assertGreaterEqual(free_bytes, 0)
self.assertGreaterEqual(total_bytes, 0)
if __name__ == "__main__":
run_tests()

View File

@@ -3728,6 +3728,7 @@ class TestSparse(TestSparseBase):
@coalescedonoff
@dtypes(*floating_and_complex_types())
@dtypesIfMPS(*all_mps_types())
@expectedFailureMPS
@dtypesIfCUDA(*floating_types_and(*[torch.half] if SM53OrLater and not TEST_WITH_ROCM else [],
*[torch.bfloat16] if SM80OrLater and not TEST_WITH_ROCM else [],
torch.complex64,
@@ -3824,9 +3825,9 @@ class TestSparse(TestSparseBase):
def different_dtypes():
a, i_a, v_a = self._gen_sparse(2, 10, [2, 2], dtype, device, coalesced)
b, i_b, v_b = self._gen_sparse(2, 10, [2, 2], dtype, device, coalesced)
r2 = torch.sparse.mm(a.to(torch.float32), a.to(torch.float16))
r2 = torch.sparse.mm(a.to(torch.float64), a.to(torch.float32))
self.assertRaisesRegex(RuntimeError, 'mat1 dtype Float does not match mat2 dtype Half', different_dtypes)
self.assertRaisesRegex(RuntimeError, 'mat1 dtype Double does not match mat2 dtype Float', different_dtypes)
def test_backward_noncontiguous():
# Sparse.mm backward used to be wrong with non-contiguous grads,

View File

@@ -206,8 +206,7 @@ if __name__ == "__main__":
test_multi_process(model, input)
print(torch.xpu.device_count())
"""
# XPU has extra lines, so get the last line; refer to https://github.com/intel/torch-xpu-ops/issues/2261
rc = check_output(test_script).splitlines()[-1]
rc = check_output(test_script)
self.assertEqual(rc, str(torch.xpu.device_count()))
def test_streams(self):

View File

@@ -2491,7 +2491,6 @@ def _accelerator_emptyCache() -> None: ...
def _accelerator_getDeviceStats(device_index: _int) -> dict[str, Any]: ...
def _accelerator_resetAccumulatedStats(device_index: _int) -> None: ...
def _accelerator_resetPeakStats(device_index: _int) -> None: ...
def _accelerator_getMemoryInfo(device_index: _int) -> tuple[_int, _int]: ...
def _accelerator_setAllocatorSettings(env: str) -> None: ...
# Defined in torch/csrc/jit/python/python_tracer.cpp

View File

@ -550,10 +550,6 @@ max_autotune_flex_search_space: Literal["DEFAULT", "EXHAUSTIVE"] = os.environ.ge
"TORCHINDUCTOR_MAX_AUTOTUNE_FLEX_SEARCH_SPACE", "DEFAULT"
).upper() # type: ignore[assignment]
cutedsl_enable_autotuning: bool = (
os.environ.get("CUTEDSL_ENABLE_AUTOTUNING", "0") == "1"
)
# DEPRECATED. This setting is ignored.
autotune_fallback_to_aten = False

View File

@ -1,8 +1,6 @@
# mypy: allow-untyped-defs
import logging
from collections.abc import Sequence
from functools import partial
from pathlib import Path
from typing import Any
import torch
@ -14,7 +12,6 @@ from torch.fx.experimental.symbolic_shapes import has_free_unbacked_symbols
from .. import config
from ..codegen.wrapper import PythonWrapperCodegen
from ..ir import _IntLike, Layout, TensorBox
from ..utils import load_template
log = logging.getLogger(__name__)
@ -257,7 +254,3 @@ def is_batch_stride_largest_or_zero(mat1, mat2, layout) -> bool:
return False
return True
_KERNEL_TEMPLATE_DIR = Path(__file__).parent / "templates"
load_kernel_template = partial(load_template, template_dir=_KERNEL_TEMPLATE_DIR)

View File

@ -1,13 +1,11 @@
# mypy: allow-untyped-defs
import logging
from dataclasses import asdict, dataclass
from dataclasses import dataclass
from typing import Any, Optional
import torch
from torch._dynamo.utils import counters
from torch._inductor.codegen.cutedsl.cutedsl_template import CuteDSLTemplate
from torch._inductor.runtime.triton_compat import tl
from torch._inductor.template_heuristics.cutedsl import get_groupgemm_configs
from torch._inductor.virtualized import V
from torch.utils._triton import has_triton
@ -24,13 +22,11 @@ from ..utils import (
get_num_sms,
has_free_symbols,
use_aten_gemm_kernels,
use_blackwell_cutedsl_grouped_mm,
use_triton_template,
)
from .mm_common import (
_is_static_problem,
check_supported_striding,
load_kernel_template,
persistent_grouped_mm_grid,
)
@ -517,11 +513,6 @@ triton_scaled_grouped_mm_template = TritonTemplate(
source=triton_grouped_mm_source,
)
cutedsl_grouped_mm_template = CuteDSLTemplate(
name="grouped_gemm_cutedsl",
source=load_kernel_template("cutedsl_mm_grouped"),
)
def grouped_mm_args(
mat1: TensorBox,
@ -723,44 +714,43 @@ def _tuned_grouped_mm_common(
# Checking only for the equality of corresponding dims of
# multiplicands here, relying on meta function checks for
# everything else.
if len(m1_size) == 2:
if len(m2_size) == 2:
m, k1 = m1_size
k2, _ = m2_size
# pyrefly: ignore [missing-attribute]
g = offs.get_size()[0]
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = True, True
else:
# pyrefly: ignore [missing-attribute]
g1 = offs.layout.size[0]
m, k1 = m1_size
g2, k2, _ = m2_size
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = True, False
else:
if len(m2_size) == 2:
# pyrefly: ignore [missing-attribute]
g1 = offs.layout.size[0]
g2, m, k1 = m1_size
k2, _ = m2_size
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = False, True
else:
g1, m, k1 = m1_size
g2, k2, _ = m2_size
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = False, False
if (
is_nonzero
and use_triton_template(layout)
and can_use_triton_kernel(mat_a, mat_b, offs, bias, scale_result)
):
scaled = scale_a is not None
if len(m1_size) == 2:
if len(m2_size) == 2:
m, k1 = m1_size
k2, _ = m2_size
# pyrefly: ignore [missing-attribute]
g = offs.get_size()[0]
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = True, True
else:
# pyrefly: ignore [missing-attribute]
g1 = offs.layout.size[0]
m, k1 = m1_size
g2, k2, _ = m2_size
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = True, False
else:
if len(m2_size) == 2:
# pyrefly: ignore [missing-attribute]
g1 = offs.layout.size[0]
g2, m, k1 = m1_size
k2, _ = m2_size
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = False, True
else:
g1, m, k1 = m1_size
g2, k2, _ = m2_size
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
V.graph.sizevars.check_equals(k1, k2)
a_is_2d, b_is_2d = False, False
a_is_k_major = mat_a.get_stride()[-1] == 1
b_is_k_major = mat_b.get_stride()[-2] == 1
@ -798,22 +788,6 @@ def _tuned_grouped_mm_common(
**config.kwargs,
)
if use_blackwell_cutedsl_grouped_mm(
mat_a, mat_b, layout, a_is_2d, b_is_2d, offs, bias, scale_result
):
for config in get_groupgemm_configs():
kwargs = dict(
ACC_DTYPE="cutlass.Float32",
)
cutedsl_grouped_mm_template.maybe_append_choice(
choices,
input_nodes=input_nodes,
layout=layout,
**kwargs,
**asdict(config),
)
input_gen_fns = {
4: lambda x: create_offsets(
x, m1_size, m2_size, offs.get_size() if offs is not None else None
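To make the shape bookkeeping in the hunk above easier to follow, here is a small standalone sketch (illustrative names only, not Inductor's API) of how the four 2D/3D operand combinations determine the group count `g` and the `a_is_2d`/`b_is_2d` flags:

```python
def classify_grouped_mm(m1_size, m2_size, offs_len):
    # mat1 is (m, k) or (g, m, k); mat2 is (k, n) or (g, k, n); the group
    # count comes from `offs` only when both operands are 2D.
    a_is_2d, b_is_2d = len(m1_size) == 2, len(m2_size) == 2
    if a_is_2d and b_is_2d:
        (m, k1), (k2, _n) = m1_size, m2_size
        g = offs_len
    elif a_is_2d:
        (m, k1), (g, k2, _n) = m1_size, m2_size
    elif b_is_2d:
        (g, m, k1), (k2, _n) = m1_size, m2_size
    else:
        (g, m, k1), (g2, k2, _n) = m1_size, m2_size
        assert g == g2, "group dims must agree"
    assert k1 == k2, "contraction dims must agree"
    return g, m, a_is_2d, b_is_2d

print(classify_grouped_mm((8, 16), (4, 16, 32), offs_len=4))  # (4, 8, True, False)
```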

View File

@ -1,333 +0,0 @@
import functools
from torch._inductor.runtime.runtime_utils import ceildiv
from cutlass.utils import TensorMapUpdateMode
{{gen_defines()}}
# ---- Import GroupedGemm implementation, copied at PyTorch build time from the Cutlass repository: cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py ----
from torch._inductor.kernel.vendored_templates.cutedsl_grouped_gemm import (
GroupedGemmKernel,
)
# Note about caching:
# Each instantiated CuTeDSL grouped GEMM kernel file generated by Inductor
# maintains its own local caching system. At this stage, all compile-time
# constexprs (e.g., TILE_M, TILE_N, CLUSTER_M/N, USE_2_CTA) and the kernel
# name itself ({{kernel_name}}) are permanently baked into the file, so they
# do not need to be included in any cache key.
#
# The caching mechanism is split into two levels:
#
# 1. prep_cache
# Caches the compiled executor for build_group_ptrs_from_bases(). This
# kernel depends only on the tensor shapes, strides, and dtypes of A/B/C,
# and can therefore be safely reused across runs with different group
# partitioning (`offs`).
#
# 2. gemm_cache
# Caches the compiled Grouped GEMM executor. Its key extends the prep
# cache key with hardware- and grid-specific parameters:
# (prep_cache_key, max_active_clusters, total_num_clusters).
# This is necessary because different `offs` tensors can change the
# per-group problem sizes and thus alter `total_num_clusters`, which in
# turn changes the grid shape and persistent scheduler configuration.
# Kernels compiled for one grid cannot be safely reused for another.
#
#
# Additionally, note the @lru_cache decorator on get_hardware_info(). Empirically,
# hw.get_max_active_clusters() triggers significant MLIR recompilation overhead,
# despite depending only on the GPU type. We cache this function to mitigate
# redundant recompiles even when shape/stride/dtype cache misses force kernel
# regeneration. A follow-up study will investigate the root cause.
prep_cache = {}
gemm_cache = {}
@functools.lru_cache
def get_hardware_info():
hw = cutlass.utils.HardwareInfo()
sm_count = hw.get_max_active_clusters(1)
max_active_clusters = hw.get_max_active_clusters(CLUSTER_M * CLUSTER_N)
return (sm_count, max_active_clusters)
def get_prep_cache_key(input_a, input_b, output):
"""
Returns a tuple key for caching the preprocessing kernel executor based on kernel name,
shapes, strides, and dtypes of input/output tensors.
"""
return (
tuple(input_a.shape),
tuple(input_a.stride()),
input_a.dtype,
tuple(input_b.shape),
tuple(input_b.stride()),
input_b.dtype,
tuple(output.shape),
tuple(output.stride()),
output.dtype,
)
def get_gemm_cache_key(prep_cache_key, max_active_clusters, total_num_clusters):
"""
Returns a tuple key for caching the gemm kernel executor by extending the
prep cache key with hardware- and grid-specific parameters.
"""
return (
prep_cache_key,
max_active_clusters,
total_num_clusters,
)
@cute.kernel
def build_group_ptrs_from_bases_kernel(
base_A_u64: cutlass.Int64, # device addr of input_a (bytes)
base_B_u64: cutlass.Int64, # device addr of input_b (bytes)
base_C_u64: cutlass.Int64, # device addr of Output (bytes)
offs: cute.Tensor, # [G], cutlass.Int32/64 cumulative
K: cutlass.Constexpr,
N: cutlass.Constexpr,
sizeof_element: cutlass.Int32, # bytes
# -------- STRIDES (in ELEMENTS) --------
stride_A_m_elems: cutlass.Constexpr, # A.stride(0)
stride_A_k_elems: cutlass.Constexpr, # A.stride(1)
stride_B0_elems: cutlass.Constexpr, # B.stride(0)
stride_Bk_elems: cutlass.Constexpr, # B.stride(1)
stride_Bn_elems: cutlass.Constexpr, # B.stride(2)
stride_C_m_elems: cutlass.Constexpr, # C.stride(0)
stride_C_n_elems: cutlass.Constexpr, # C.stride(1)
# -------- OUTPUTS --------
out_ptrs: cute.Tensor, # [G,3] cutlass.Int64: (A_ptr, B_ptr, C_ptr)
out_problem: cute.Tensor, # [G,4] cutlass.Int32: (m_g, n, k, 1)
out_strides_abc: cute.Tensor, # [G,3,2] cutlass.Int32 [[A_m,A_k],[B_n,B_k],[C_m,C_n]]
):
tidx, _, _ = cute.arch.thread_idx()
g = tidx
m_beg_i32 = 0
if g > 0:
m_beg_i32 = offs[g - 1]
m_end_i32 = offs[g]
m_g_i32 = m_end_i32 - m_beg_i32
a_byte_off = (
cutlass.Int64(m_beg_i32) * stride_A_m_elems * cutlass.Int64(sizeof_element)
)
c_byte_off = (
cutlass.Int64(m_beg_i32) * stride_C_m_elems * cutlass.Int64(sizeof_element)
)
b_byte_off = cutlass.Int64(g) * stride_B0_elems * cutlass.Int64(sizeof_element)
# ---- pointers ----
out_ptrs[g, 0] = base_A_u64 + a_byte_off
out_ptrs[g, 1] = base_B_u64 + b_byte_off
out_ptrs[g, 2] = base_C_u64 + c_byte_off
# ---- (m, n, k, 1) ----
out_problem[g, 0] = m_g_i32
out_problem[g, 1] = N
out_problem[g, 2] = K
out_problem[g, 3] = cutlass.Int32(1)
# ---- strides ----
out_strides_abc[g, 0, 0] = cutlass.Int32(stride_A_m_elems)
out_strides_abc[g, 0, 1] = cutlass.Int32(stride_A_k_elems)
out_strides_abc[g, 1, 0] = cutlass.Int32(stride_Bn_elems)
out_strides_abc[g, 1, 1] = cutlass.Int32(stride_Bk_elems)
out_strides_abc[g, 2, 0] = cutlass.Int32(stride_C_m_elems)
out_strides_abc[g, 2, 1] = cutlass.Int32(stride_C_n_elems)
@cute.jit
def launch_build_group_ptrs_from_bases(
base_A_u64: cutlass.Int64,
base_B_u64: cutlass.Int64,
base_C_u64: cutlass.Int64,
offs: cute.Tensor,
G: cutlass.Constexpr,
K: cutlass.Constexpr,
N: cutlass.Constexpr,
sizeof_element: cutlass.Constexpr,
stride_A_m_elems: cutlass.Constexpr,
stride_A_k_elems: cutlass.Constexpr,
stride_B0_elems: cutlass.Constexpr,
stride_Bk_elems: cutlass.Constexpr,
stride_Bn_elems: cutlass.Constexpr,
stride_C_m_elems: cutlass.Constexpr,
stride_C_n_elems: cutlass.Constexpr,
out_ptrs: cute.Tensor, # [G,3] cutlass.Int64
out_problem: cute.Tensor, # [G,4] cutlass.Int32
out_strides_abc: cute.Tensor, # [3,2] cutlass.Int32
stream: cuda.CUstream,
):
build_group_ptrs_from_bases_kernel(
base_A_u64,
base_B_u64,
base_C_u64,
offs,
K,
N,
sizeof_element,
stride_A_m_elems,
stride_A_k_elems,
stride_B0_elems,
stride_Bk_elems,
stride_Bn_elems,
stride_C_m_elems,
stride_C_n_elems,
out_ptrs,
out_problem,
out_strides_abc,
).launch(grid=(1, 1, 1), block=(G, 1, 1), stream=stream)
{{def_kernel("input_a", "input_b", "input_a_offs")}}
stream = cuda.CUstream(stream)
input_b = input_b.transpose(1, 2)
sumM, K = input_a.shape
G, N, Kb = input_b.shape
dev = input_a.device
base_A_u64 = int(input_a.data_ptr())
base_B_u64 = int(input_b.data_ptr())
base_C_u64 = int({{get_output()}}.data_ptr())
ptrs_t = torch.empty((G, 3), device=dev, dtype=torch.int64)
probs_t = torch.empty((G, 4), device=dev, dtype=torch.int32)
strides_t = torch.empty((G, 3, 2), device=dev, dtype=torch.int32)
ptrs = from_dlpack(ptrs_t)
probs = from_dlpack(probs_t)
strides = from_dlpack(strides_t)
prep_cache_key = get_prep_cache_key(input_a, input_b, {{get_output()}})
prep_executor = prep_cache.get(prep_cache_key)
if prep_executor is None:
sizeof_element = int(input_a.element_size())
sA_m, sA_k = map(int, input_a.stride())
sB_0, sB_n, sB_k = map(int, input_b.stride())
sC_m, sC_n = map(int, {{get_output()}}.stride())
prep_executor = cute.compile(
launch_build_group_ptrs_from_bases,
base_A_u64=base_A_u64,
base_B_u64=base_B_u64,
base_C_u64=base_C_u64,
offs=from_dlpack(input_a_offs),
G=int(G),
K=int(K),
N=int(N),
sizeof_element=sizeof_element,
stride_A_m_elems=sA_m,
stride_A_k_elems=sA_k,
stride_B0_elems=sB_0,
stride_Bk_elems=sB_k,
stride_Bn_elems=sB_n,
stride_C_m_elems=sC_m,
stride_C_n_elems=sC_n,
out_ptrs=ptrs,
out_problem=probs,
out_strides_abc=strides,
stream=stream,
)
prep_cache[prep_cache_key] = prep_executor
prep_executor(
base_A_u64=base_A_u64,
base_B_u64=base_B_u64,
base_C_u64=base_C_u64,
offs=from_dlpack(input_a_offs),
out_ptrs=ptrs,
out_problem=probs,
out_strides_abc=strides,
stream=stream,
)
# --- Tensormap workspace per SM ---
num_tensormap_buffers, max_active_clusters = get_hardware_info()
tensormap_shape = (
num_tensormap_buffers,
GroupedGemmKernel.num_tensormaps,
GroupedGemmKernel.bytes_per_tensormap // 8,
)
tensormap_workspace_t = torch.empty(tensormap_shape, device=dev, dtype=torch.int64)
tensormap_workspace = from_dlpack(tensormap_workspace_t)
# --- Total clusters ---
def compute_total_num_clusters(
problem_sizes_mnkl,
cluster_tile_shape_mn,
):
total_num_clusters = 0
for m, n, _, _ in problem_sizes_mnkl:
num_clusters_mn = tuple(
ceildiv(x, y) for x, y in zip((m, n), cluster_tile_shape_mn)
)
total_num_clusters += functools.reduce(lambda x, y: x * y, num_clusters_mn)
return total_num_clusters
# Compute cluster tile shape
def compute_cluster_tile_shape(
mma_tiler_mn,
cluster_shape_mn,
use_2cta_instrs,
):
cta_tile_shape_mn = list(mma_tiler_mn)
if use_2cta_instrs:
cta_tile_shape_mn[0] = cta_tile_shape_mn[0] // 2
return tuple(x * y for x, y in zip(cta_tile_shape_mn, cluster_shape_mn))
cluster_tile_shape_mn = compute_cluster_tile_shape(
(TILE_M, TILE_N), (CLUSTER_M, CLUSTER_N), bool(USE_2_CTA)
)
total_num_clusters = int(compute_total_num_clusters(probs_t, cluster_tile_shape_mn))
gemm_cache_key = get_gemm_cache_key(
prep_cache_key, max_active_clusters, total_num_clusters
)
gemm_executor = gemm_cache.get(gemm_cache_key)
if gemm_executor is None:
grouped_gemm = GroupedGemmKernel(
acc_dtype=ACC_DTYPE,
use_2cta_instrs=USE_2_CTA,
mma_tiler_mn=(TILE_M, TILE_N),
cluster_shape_mn=(CLUSTER_M, CLUSTER_N),
tensormap_update_mode=TENSORMAP_UPDATE_MODE,
)
gemm_executor = cute.compile(
grouped_gemm,
from_dlpack(input_a.unsqueeze(-1), assumed_align=16),
from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16),
from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16),
G,
probs,
strides,
ptrs,
total_num_clusters,
tensormap_workspace,
max_active_clusters,
stream,
)
gemm_cache[gemm_cache_key] = gemm_executor
gemm_executor(
from_dlpack(input_a.unsqueeze(-1), assumed_align=16),
from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16),
from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16),
probs,
strides,
ptrs,
tensormap_workspace,
stream,
)
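For readers skimming the deleted template above, a minimal sketch of the two-level caching scheme its "Note about caching" describes; the helper names here are illustrative, not the template's own code:

```python
prep_cache, gemm_cache = {}, {}

def prep_key(a, b, out):
    # Shapes, strides, and dtypes of A/B/C; the kernel name and compile-time
    # constexprs are baked into the generated file, so they are not part of the key.
    return tuple((tuple(t.shape), tuple(t.stride()), t.dtype) for t in (a, b, out))

def get_prep_executor(a, b, out, compile_prep_fn):
    key = prep_key(a, b, out)
    if key not in prep_cache:
        prep_cache[key] = compile_prep_fn()  # reusable across different `offs`
    return prep_cache[key]

def get_gemm_executor(a, b, out, max_active_clusters, total_num_clusters, compile_gemm_fn):
    # Grid/hardware parameters extend the prep key because a different `offs`
    # tensor can change total_num_clusters and therefore the grid shape.
    key = (prep_key(a, b, out), max_active_clusters, total_num_clusters)
    if key not in gemm_cache:
        gemm_cache[key] = compile_gemm_fn()  # compile only on a genuine cache miss
    return gemm_cache[key]
```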

View File

@ -1,141 +0,0 @@
from dataclasses import dataclass
from enum import auto, Enum
from itertools import product
import torch._inductor.config as config
class TensorMapUpdateMode(Enum):
"""Enum mirroring cutlass.utils.TensorMapUpdateMode to decouple this file from a cutlass dependency."""
SMEM = auto()
GMEM = auto()
@dataclass(frozen=True)
class CuTeGemmConfig:
TILE_M: int = 128
TILE_N: int = 192
CLUSTER_M: int = 2
CLUSTER_N: int = 1
USE_2_CTA: bool = False
TENSORMAP_UPDATE_MODE: TensorMapUpdateMode = TensorMapUpdateMode.SMEM
def get_exhaustive_groupgemm_configs() -> list[CuTeGemmConfig]:
"""
Returns the exhaustive configuration set for the Blackwell CuTeDSL Grouped GEMM kernel.
For information regarding valid config sets, see:
https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/grouped_gemm.py
"""
# Tile_n is always the same regardless of 2cta
tile_n_vals = [32, 64, 96, 128, 160, 192, 224, 256]
# Valid clusters
clusters_no_2cta = [
(1, 1),
(1, 2),
(1, 4),
(1, 8),
(1, 16),
(2, 1),
(2, 2),
(2, 4),
(2, 8),
(4, 1),
(4, 2),
(4, 4),
(8, 1),
(8, 2),
(16, 1),
]
clusters_2cta = [
(2, 1),
(2, 2),
(2, 4),
(2, 8),
(4, 1),
(4, 2),
(4, 4),
(8, 1),
(8, 2),
(16, 1),
]
configs: list[CuTeGemmConfig] = []
for use_2cta, cluster_set, tile_m_range in [
(False, clusters_no_2cta, [64, 128]),
(True, clusters_2cta, [128, 256]),
]:
for tensormap_update_mode, tile_m, tile_n, (cluster_m, cluster_n) in product(
[TensorMapUpdateMode.SMEM, TensorMapUpdateMode.GMEM],
tile_m_range,
tile_n_vals,
cluster_set,
):
configs.append(
CuTeGemmConfig(
tile_m,
tile_n,
cluster_m,
cluster_n,
USE_2_CTA=use_2cta,
TENSORMAP_UPDATE_MODE=tensormap_update_mode,
)
)
return configs
def get_default_groupgemm_configs() -> list[CuTeGemmConfig]:
"""
Returns the default configuration set for the Blackwell CuTeDSL Grouped GEMM kernel.
"""
config_tuples = [
(128, 256, 2, 1, False, TensorMapUpdateMode.SMEM),
(256, 160, 2, 1, True, TensorMapUpdateMode.GMEM),
(256, 256, 2, 1, True, TensorMapUpdateMode.GMEM),
(64, 32, 1, 1, False, TensorMapUpdateMode.GMEM),
(64, 256, 1, 2, False, TensorMapUpdateMode.SMEM),
(128, 256, 1, 2, False, TensorMapUpdateMode.SMEM),
(256, 256, 2, 2, True, TensorMapUpdateMode.GMEM),
(128, 256, 1, 2, False, TensorMapUpdateMode.GMEM),
(64, 32, 1, 1, False, TensorMapUpdateMode.SMEM),
(256, 256, 2, 1, True, TensorMapUpdateMode.SMEM),
(128, 256, 1, 1, False, TensorMapUpdateMode.GMEM),
(256, 256, 8, 1, True, TensorMapUpdateMode.GMEM),
(64, 32, 1, 2, False, TensorMapUpdateMode.SMEM),
(256, 192, 2, 1, True, TensorMapUpdateMode.GMEM),
(256, 256, 2, 2, True, TensorMapUpdateMode.SMEM),
(128, 96, 1, 2, False, TensorMapUpdateMode.SMEM),
(64, 192, 1, 1, False, TensorMapUpdateMode.SMEM),
(64, 64, 1, 1, False, TensorMapUpdateMode.GMEM),
(64, 192, 1, 1, False, TensorMapUpdateMode.GMEM),
(128, 64, 1, 1, False, TensorMapUpdateMode.GMEM),
(64, 160, 1, 1, False, TensorMapUpdateMode.GMEM),
(64, 256, 1, 1, False, TensorMapUpdateMode.GMEM),
]
return [CuTeGemmConfig(*args) for args in config_tuples]
def get_groupgemm_configs() -> list[CuTeGemmConfig]:
"""
Returns the configuration set for the Blackwell CuTeDSL Grouped GEMM kernel.
Note: CuTeDSL autotuning is still experimental — enabling it may trigger kernel launch failures
or unstable results. By default, autotuning is disabled and we return only
a single baseline config.
"""
if (
config.cutedsl_enable_autotuning
and config.max_autotune_gemm_search_space == "EXHAUSTIVE"
):
return get_exhaustive_groupgemm_configs()
elif config.cutedsl_enable_autotuning:
return get_default_groupgemm_configs()
else:
return [get_default_groupgemm_configs()[0]]
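As a quick back-of-the-envelope check on the size of the removed exhaustive search space (assuming the TILE_N and cluster lists shown above):

```python
# 2 tensormap modes x 2 TILE_M values x 8 TILE_N values x per-branch cluster count
no_2cta = 2 * 2 * 8 * 15    # 480 configs without 2-CTA instructions
with_2cta = 2 * 2 * 8 * 10  # 320 configs with 2-CTA instructions
print(no_2cta + with_2cta)  # 800 candidate CuTeGemmConfig entries
```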

View File

@ -1911,84 +1911,6 @@ def use_triton_blackwell_tma_template(
return has_triton_tensor_descriptor_host_tma() and is_datacenter_blackwell_arch()
@functools.lru_cache(maxsize=1)
def ensure_cute_available() -> bool:
"""Check if CuTeDSL is importable; cache the result for reuse.
Call ensure_cute_available.cache_clear() after installing CuTeDSL
in the same interpreter to retry the import.
"""
try:
return importlib.util.find_spec("cutlass.cute") is not None
except ImportError:
return False
def use_blackwell_cutedsl_grouped_mm(
mat_a: Any,
mat_b: Any,
layout: Layout,
a_is_2d: bool,
b_is_2d: bool,
offs: Optional[Any],
bias: Optional[Any],
scale_result: Optional[Any],
) -> bool:
"""
Returns True if we can use the blackwell kernel for grouped mm.
Required conditions:
1. CuTeDSL backend is enabled
2. CuTeDSL is available
3. We are on a blackwell arch
4. The dtype is bf16
5. Max autotune or max autotune gemm is enabled
6. A, B, and the output are 16B aligned
7. We are not using dynamic shapes
8. A is 2d
9. B is 3d
10. Offsets are provided
11. Bias and Scale are not provided
"""
if not ensure_cute_available():
return False
if not _use_autotune_backend("CUTEDSL"):
return False
from .codegen.cuda.cuda_env import is_datacenter_blackwell_arch
if not is_gpu(layout.device.type):
return False
if not is_datacenter_blackwell_arch():
return False
layout_dtypes = [torch.bfloat16]
if not _use_template_for_gpu(layout, layout_dtypes):
return False
if not (config.max_autotune or config.max_autotune_gemm):
return False
# Checks for 16B ptr and stride alignment
if not can_use_tma(mat_a, mat_b, output_layout=layout):
return False
if any(is_dynamic(x) for x in [mat_a, mat_b]):
return False
if not a_is_2d or b_is_2d:
return False
if offs is None:
return False
if bias is not None or scale_result is not None:
return False
return True
def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool:
from .virtualized import V

View File

@ -10,7 +10,6 @@ import torch
from ._utils import _device_t, _get_device_index
from .memory import (
empty_cache,
get_memory_info,
max_memory_allocated,
max_memory_reserved,
memory_allocated,
@ -26,10 +25,9 @@ __all__ = [
"current_device_idx", # deprecated
"current_device_index",
"current_stream",
"empty_cache",
"device_count",
"device_index",
"empty_cache",
"get_memory_info",
"is_available",
"max_memory_allocated",
"max_memory_reserved",

View File

@ -8,7 +8,6 @@ from ._utils import _device_t, _get_device_index
__all__ = [
"empty_cache",
"get_memory_info",
"max_memory_allocated",
"max_memory_reserved",
"memory_allocated",
@ -88,9 +87,6 @@ def memory_stats(device_index: _device_t = None, /) -> OrderedDict[str, Any]:
If not given, use :func:`torch.accelerator.current_device_index` by default.
If a :class:`torch.device` or str is provided, its type must match the current
:ref:`accelerator<accelerators>` device type.
Returns:
OrderedDict[str, Any]: an ordered dictionary mapping statistic names to their values.
"""
if not torch._C._accelerator_isAllocatorInitialized():
return OrderedDict()
@ -121,9 +117,6 @@ def memory_allocated(device_index: _device_t = None, /) -> int:
If not given, use :func:`torch.accelerator.current_device_index` by default.
If a :class:`torch.device` or str is provided, its type must match the current
:ref:`accelerator<accelerators>` device type.
Returns:
int: the current memory occupied by live tensors (in bytes) within the current process.
"""
return memory_stats(device_index).get("allocated_bytes.all.current", 0)
@ -141,9 +134,6 @@ def max_memory_allocated(device_index: _device_t = None, /) -> int:
If not given, use :func:`torch.accelerator.current_device_index` by default.
If a :class:`torch.device` or str is provided, its type must match the current
:ref:`accelerator<accelerators>` device type.
Returns:
int: the peak memory occupied by live tensors (in bytes) within the current process.
"""
return memory_stats(device_index).get("allocated_bytes.all.peak", 0)
@ -157,9 +147,6 @@ def memory_reserved(device_index: _device_t = None, /) -> int:
If not given, use :func:`torch.accelerator.current_device_index` by default.
If a :class:`torch.device` or str is provided, its type must match the current
:ref:`accelerator<accelerators>` device type.
Returns:
int: the current memory reserved by PyTorch (in bytes) within the current process.
"""
return memory_stats(device_index).get("reserved_bytes.all.current", 0)
@ -177,9 +164,6 @@ def max_memory_reserved(device_index: _device_t = None, /) -> int:
If not given, use :func:`torch.accelerator.current_device_index` by default.
If a :class:`torch.device` or str is provided, its type must match the current
:ref:`accelerator<accelerators>` device type.
Returns:
int: the peak memory reserved by PyTorch (in bytes) within the current process.
"""
return memory_stats(device_index).get("reserved_bytes.all.peak", 0)
@ -216,21 +200,3 @@ def reset_peak_memory_stats(device_index: _device_t = None, /) -> None:
"""
device_index = _get_device_index(device_index, optional=True)
return torch._C._accelerator_resetPeakStats(device_index)
def get_memory_info(device_index: _device_t = None, /) -> tuple[int, int]:
r"""Return the current device memory information for a given device index.
Args:
device_index (:class:`torch.device`, str, int, optional): the index of the device to target.
If not given, use :func:`torch.accelerator.current_device_index` by default.
If a :class:`torch.device` or str is provided, its type must match the current
:ref:`accelerator<accelerators>` device type.
Returns:
tuple[int, int]: a tuple of two integers (free_memory, total_memory) in bytes.
The first value is the free memory on the device (available across all processes and applications),
The second value is the device's total hardware memory capacity.
"""
device_index = _get_device_index(device_index, optional=True)
return torch._C._accelerator_getMemoryInfo(device_index)
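For reference, a hedged usage sketch of the API removed here, mirroring the deleted test_get_memory_info; it no longer runs once this PR lands:

```python
import torch

# Usage of the removed torch.accelerator.get_memory_info:
# both values are byte counts and free_bytes <= total_bytes.
free_bytes, total_bytes = torch.accelerator.get_memory_info()
print(f"{free_bytes / 2**30:.1f} GiB free of {total_bytes / 2**30:.1f} GiB")
```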

View File

@ -195,12 +195,10 @@ def get_new_attr_name_with_prefix(prefix: str) -> Callable:
def collect_producer_nodes(node: Node) -> Optional[list[Node]]:
r"""Starting from a target node, trace back until we hit input or
getattr node. This is used to extract the chain of operators
starting from getattr to the target node, for example::
def forward(self, x):
observed = self.observer(self.weight)
return F.linear(x, observed)
starting from getattr to the target node, for example
def forward(self, x):
observed = self.observer(self.weight)
return F.linear(x, observed)
collect_producer_nodes(observed) will either return a list of nodes that
produce the observed node, or None if we can't extract a self-contained
graph without free variables (inputs of the forward function).

View File

@ -138,13 +138,6 @@ void initModule(PyObject* module) {
at::accelerator::resetPeakStats(device_index);
});
m.def("_accelerator_getMemoryInfo", [](c10::DeviceIndex device_index) {
const auto device_type = at::accelerator::getAccelerator(true).value();
torch::utils::maybe_initialize_device(device_type);
py::gil_scoped_release no_gil;
return at::accelerator::getMemoryInfo(device_index);
});
m.def("_accelerator_setAllocatorSettings", [](std::string env) {
c10::CachingAllocator::setAllocatorSettings(env);
});

View File

@ -386,8 +386,23 @@ static void bindGetDeviceProperties(PyObject* module) {
static void initXpuMethodBindings(PyObject* module) {
auto m = py::handle(module).cast<py::module>();
m.def("_xpu_getMemoryInfo", [](c10::DeviceIndex device_index) {
py::gil_scoped_release no_gil;
return at::getDeviceAllocator(at::kXPU)->getMemoryInfo(device_index);
#if SYCL_COMPILER_VERSION >= 20250000
auto total = at::xpu::getDeviceProperties(device_index)->global_mem_size;
auto& device = c10::xpu::get_raw_device(device_index);
TORCH_CHECK(
device.has(sycl::aspect::ext_intel_free_memory),
"The device (",
at::xpu::getDeviceProperties(device_index)->name,
") doesn't support querying the available free memory. ",
"You can file an issue at https://github.com/pytorch/pytorch/issues ",
"to help us prioritize its implementation.");
auto free = device.get_info<sycl::ext::intel::info::device::free_memory>();
return std::make_tuple(free, total);
#else
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"torch.xpu.mem_get_info requires PyTorch to be built with SYCL compiler version 2025.0.0 or newer.");
#endif
});
m.def(
"_xpu_getStreamFromExternal",

View File

@ -1,5 +1,4 @@
import itertools
import math
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, cast, NamedTuple, Optional
@ -8,7 +7,6 @@ import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.placement_types import (
_StridedShard,
MaskPartial,
Partial,
Placement,
Replicate,
@ -129,185 +127,6 @@ class DTensorSpec:
)
return default_shard_order
@staticmethod
def _convert_shard_order_to_StridedShard(
shard_order: ShardOrder, placements: tuple[Placement, ...], mesh: DeviceMesh
) -> tuple[Placement, ...]:
"""
Convert ShardOrder to placements with _StridedShard.
This function converts a ShardOrder specification into a tuple of Placement objects,
using _StridedShard when a tensor dimension is sharded across multiple mesh dimensions
in a non-default order. The split_factor of each _StridedShard is determined by the
product of mesh dimension sizes that appear earlier in the shard order but later in
the placement tuple.
Args:
shard_order: ShardOrder specification indicating which tensor dimensions are
sharded on which mesh dimensions and in what execution order.
placements: Tuple of Placement objects that does not contain _StridedShard.
mesh: DeviceMesh containing the size information for each mesh dimension.
Returns:
Updated tuple of Placement objects with Shard or _StridedShard placements.
Algorithm:
For each ShardOrderEntry in shard_order:
- For each mesh dimension in the entry's mesh_dims (in order):
- Calculate split_factor as the product of mesh sizes for all mesh dimensions
that appear:
1. Earlier in the shard order (lower index in mesh_dims), and
2. Later in the placement tuple (higher mesh dimension index)
- If split_factor == 1: use normal Shard
- Otherwise: use _StridedShard with the calculated split_factor
Example:
>>> # xdoctest: +SKIP("Requires DeviceMesh")
>>> # Tensor dimension 0 sharded on mesh dims [2, 0, 1] in that order
>>> # mesh = DeviceMesh([4, 3, 2]) # sizes: mesh[0]=4, mesh[1]=3, mesh[2]=2
>>> shard_order = (ShardOrderEntry(tensor_dim=0, mesh_dims=(2, 0, 1)),)
>>> placements = (Shard(0), Shard(0), Shard(0))
>>> # For mesh_dim=2 (index 0 in mesh_dims): no earlier dims, split_factor=1
>>> # -> placements[2] = Shard(0)
>>> # For mesh_dim=0 (index 1 in mesh_dims): mesh_dim=2 is earlier and has index 2>0
>>> # -> split_factor = mesh.size(2) = 2
>>> # -> placements[0] = _StridedShard(0, split_factor=2)
>>> # For mesh_dim=1 (index 2 in mesh_dims): mesh_dim=2 is earlier and has index 2>1
>>> # -> split_factor = mesh.size(2) = 2
>>> # -> placements[1] = _StridedShard(0, split_factor=2)
>>> # Result: (_StridedShard(0, sf=2), _StridedShard(0, sf=2), Shard(0))
"""
placements_list = list(placements)
for entry in shard_order:
tensor_dim = entry.tensor_dim
mesh_dims = entry.mesh_dims
for idx in range(len(mesh_dims)):
# TODO(zpcore): split_factor from `view` and `shard order`
# should be able to be multiplied into one. Need to loosen the
# condition here.
mesh_dim = mesh_dims[idx]
if type(placements[mesh_dim]) is not Shard:
raise ValueError(
f"Only Shard placement can be converted to _StridedShard, "
f"found {placements[mesh_dim]} in {placements=}."
)
split_factor = math.prod(
mesh.size(i) for i in mesh_dims[:idx] if i > mesh_dim
)
if split_factor == 1:
# use normal Shard
placements_list[mesh_dim] = Shard(tensor_dim)
else:
placements_list[mesh_dim] = _StridedShard(
tensor_dim, split_factor=split_factor
)
return tuple(placements_list)
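A small standalone sketch (mesh represented as a plain list of dim sizes, not the DTensor API) of the split_factor rule spelled out in the docstring above; a split_factor of 1 corresponds to a plain Shard, anything larger to _StridedShard:

```python
import math

def split_factors(mesh_sizes, mesh_dims_in_shard_order):
    # For each mesh dim, multiply the sizes of mesh dims that come earlier in
    # the shard order but sit later in the placement tuple (higher mesh index).
    return {
        mesh_dim: math.prod(
            mesh_sizes[i] for i in mesh_dims_in_shard_order[:idx] if i > mesh_dim
        )
        for idx, mesh_dim in enumerate(mesh_dims_in_shard_order)
    }

# Docstring example: mesh sizes [4, 3, 2], tensor dim 0 sharded in order (2, 0, 1)
print(split_factors([4, 3, 2], (2, 0, 1)))  # {2: 1, 0: 2, 1: 2}
```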
@staticmethod
def _maybe_convert_StridedShard_to_shard_order(
placements: tuple[Placement, ...], mesh: DeviceMesh
) -> Optional[ShardOrder]:
"""
Try to convert _StridedShard placements to ShardOrder.
This is the inverse of `_convert_shard_order_to_StridedShard`. It reconstructs the shard
order by examining the split_factor of each _StridedShard and determining its position
in the execution order. If the _StridedShard configuration cannot be represented as a
valid ShardOrder (i.e., there's no shard order that produces the observed split_factors),
this function returns None.
Args:
placements: Tuple of Placement objects that may contain _StridedShard.
mesh: DeviceMesh containing the size information for each mesh dimension.
Returns:
ShardOrder if conversion is possible, None otherwise. For placements without
_StridedShard, returns the default shard order.
Algorithm:
1. If no _StridedShard in placements, return default shard order
2. Create an empty list for each tensor dimension to represent mesh dim ordering
3. Iterate through placements in reverse order (right to left):
- For each Shard/_StridedShard on a tensor dimension:
- Extract its split_factor (1 for Shard, split_factor for _StridedShard)
- Find the position in mesh_dims_order where accumulated_sf equals split_factor
- accumulated_sf is the product of mesh sizes of mesh dimensions that appear
earlier in mesh_dims_order (lower indices)
- Insert mesh_dim at the found position
4. If no valid position found for any split_factor, return None (unable to convert)
5. Construct ShardOrderEntry for each tensor dimension from mesh_dims_order
Example:
>>> # xdoctest: +SKIP("Requires DeviceMesh")
>>> # mesh = DeviceMesh([4, 3, 2]) # sizes: mesh[0]=4, mesh[1]=3, mesh[2]=2
>>> # placements = (_StridedShard(0, sf=2), _StridedShard(0, sf=2), Shard(0))
>>> # Process tensor_dim=0 from right to left:
>>> # - mesh_dim=2: Shard(0) with sf=1
>>> # Try position 0: accumulated_sf=1, matches! Insert at position 0
>>> # Current mesh_dims_order order: [2]
>>> # - mesh_dim=1: _StridedShard(0, sf=2) with sf=2
>>> # Try position 0: accumulated_sf=1, no match
>>> # Try position 1: accumulated_sf=1*mesh.size(2)=2, matches! Insert at position 1
>>> # Current mesh_dims_order order: [2, 1]
>>> # - mesh_dim=0: _StridedShard(0, sf=2) with sf=2
>>> # Try position 0: accumulated_sf=1, no match
>>> # Try position 1: accumulated_sf=1*mesh.size(2)=2, matches! Insert at position 1
>>> # Final mesh_dims_order order: [2, 0, 1]
>>> # Result: ShardOrder((ShardOrderEntry(tensor_dim=0, mesh_dims=(2, 0, 1)),))
>>> # This means: first shard on mesh_dim=2, then mesh_dim=0, then mesh_dim=1
Note:
This function validates that _StridedShard can be represented as a ShardOrder.
Not all _StridedShard configurations are valid - the split_factor must match
the product of mesh sizes in some execution order.
"""
if not any(isinstance(p, _StridedShard) for p in placements):
return DTensorSpec.compute_default_shard_order(placements)
max_tensor_dim = (
max([i.dim for i in placements if isinstance(i, Shard | _StridedShard)]) + 1
)
shard_order = []
tensor_dim_to_mesh_dims_order: list[list[int]] = [
[] for i in range(max_tensor_dim)
]
for mesh_dim in reversed(range(len(placements))):
cur_placement = placements[mesh_dim]
# _StridedShard may not be a subclass of Shard in the future, so write it this way:
if isinstance(cur_placement, Shard | _StridedShard):
tensor_dim = cur_placement.dim
mesh_dims_order = tensor_dim_to_mesh_dims_order[tensor_dim]
cur_sf = 1
if isinstance(cur_placement, _StridedShard):
cur_sf = cur_placement.split_factor
accumulated_sf = 1
find_order = False
for i in range(len(mesh_dims_order) + 1):
if accumulated_sf == cur_sf:
mesh_dims_order.insert(i, mesh_dim)
find_order = True
break
if i < len(mesh_dims_order):
accumulated_sf *= mesh.size(mesh_dims_order[i])
if not find_order:
# _StridedShard is not convertible to ShardOrder
return None
else:
if not isinstance(cur_placement, Replicate | Partial | MaskPartial):
raise ValueError(
f"Unsupported placement type {type(cur_placement)} encountered in "
f"{placements}; expected Replicate, Partial, or MaskPartial."
)
for tensor_dim in range(max_tensor_dim):
if len(tensor_dim_to_mesh_dims_order[tensor_dim]) > 0:
shard_order.append(
ShardOrderEntry(
tensor_dim=tensor_dim,
mesh_dims=tuple(tensor_dim_to_mesh_dims_order[tensor_dim]),
)
)
return tuple(shard_order)
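And a matching standalone sketch of the inverse reconstruction described in the docstring (again with the mesh as a plain list of dim sizes; not the DTensor API):

```python
def reconstruct_order(mesh_sizes, shards):
    # `shards`: (mesh_dim, split_factor) pairs for one tensor dim, in placement
    # order; split_factor is 1 for a plain Shard. Scan right-to-left and insert
    # each mesh dim where the accumulated size product equals its split_factor.
    order: list[int] = []
    for mesh_dim, sf in reversed(shards):
        acc = 1
        for pos in range(len(order) + 1):
            if acc == sf:
                order.insert(pos, mesh_dim)
                break
            if pos < len(order):
                acc *= mesh_sizes[order[pos]]
        else:
            return None  # no position matches: not expressible as a shard order
    return order

# Docstring example: mesh sizes [4, 3, 2], split factors (2, 2, 1) on mesh dims 0, 1, 2
print(reconstruct_order([4, 3, 2], [(0, 2), (1, 2), (2, 1)]))  # [2, 0, 1]
```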
def _verify_shard_order(self, shard_order: ShardOrder) -> None:
"""Verify that the shard_order is valid and matches the placements."""
total_shard = 0

View File

@ -11,10 +11,7 @@ import torch.distributed.tensor._api as dtensor
aten = torch.ops.aten
def _requires_data_exchange(padding, dim_map) -> bool:
# Data exchange is not needed if the tensor is only sharded across the batch dim
if all(x == -1 for x in dim_map[1:]):
return False
def _requires_data_exchange(padding):
# TODO: whether data exchange is required is currently determined by padding
return padding[-1] != 0
@ -110,7 +107,6 @@ def tp_convolution(
op_call: torch._ops.OpOverload,
local_tensor_args: tuple[object, ...],
local_tensor_kwargs: dict[str, object],
dim_map: list[int],
) -> object:
assert op_call == aten.convolution.default
assert len(local_tensor_args) == 9
@ -124,7 +120,7 @@ def tp_convolution(
assert _is_supported(in_tensor.shape, weight.shape, stride, padding, dilation)
assert isinstance(padding, list)
if not _requires_data_exchange(padding, dim_map):
if not _requires_data_exchange(padding):
local_results = op_call(*local_tensor_args, **local_tensor_kwargs)
return local_results
else:
@ -164,7 +160,6 @@ def tp_convolution_backward(
op_call: torch._ops.OpOverload,
local_tensor_args: tuple[object, ...],
local_tensor_kwargs: dict[str, object],
dim_map: list[int],
) -> object:
assert op_call == aten.convolution_backward.default
assert len(local_tensor_args) == 11
@ -179,7 +174,7 @@ def tp_convolution_backward(
assert _is_supported(in_tensor.shape, weight.shape, stride, padding, dilation)
assert isinstance(padding, list)
if not _requires_data_exchange(padding, dim_map):
if not _requires_data_exchange(padding):
local_results = op_call(*local_tensor_args, **local_tensor_kwargs)
return local_results
else:
@ -244,18 +239,15 @@ def convolution_handler(
dtensor.DTensor._op_dispatcher.sharding_propagator.propagate(op_info)
output_sharding = op_info.output_sharding
assert output_sharding is not None, "output sharding should not be None"
output_spec = output_sharding.output_spec
assert isinstance(output_spec, dtensor.DTensorSpec)
# local propagation
local_results = tp_convolution(
op_call,
tuple(op_info.local_args),
op_info.local_kwargs,
output_spec.dim_map,
op_call, tuple(op_info.local_args), op_info.local_kwargs
)
return dtensor.DTensor._op_dispatcher.wrap(local_results, output_spec)
return dtensor.DTensor._op_dispatcher.wrap(
local_results, output_sharding.output_spec
)
def convolution_backward_handler(
@ -278,14 +270,10 @@ def convolution_backward_handler(
dtensor.DTensor._op_dispatcher.sharding_propagator.propagate(op_info)
output_sharding = op_info.output_sharding
assert output_sharding is not None, "output sharding should not be None"
assert isinstance(op_info.flat_args_schema[0], dtensor.DTensorSpec)
# local propagation
local_results = tp_convolution_backward(
op_call,
tuple(op_info.local_args),
op_info.local_kwargs,
op_info.flat_args_schema[0].dim_map,
op_call, tuple(op_info.local_args), op_info.local_kwargs
)
return dtensor.DTensor._op_dispatcher.wrap(

View File

@ -20320,7 +20320,6 @@ op_db: list[OpInfo] = [
torch.float32: 1e-4}),),
dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
supports_sparse=True,
supports_sparse_csr=True,
supports_sparse_csc=True,
supports_sparse_bsr=True,

View File

@ -3,7 +3,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import contextlib
import copy
import functools
import itertools
import sys
@ -33,8 +32,6 @@ from torch.distributed.tensor import (
Replicate,
Shard,
)
from torch.distributed.tensor._dtensor_spec import ShardOrderEntry
from torch.distributed.tensor._redistribute import redistribute_local_tensor
from torch.distributed.tensor.parallel import (
ColwiseParallel,
parallelize_module,
@ -821,125 +818,3 @@ def map_local_for_rank(rank, func):
def reduce_local_int(val, func):
return func(val.node._local_ints)
def _convert_shard_order_dict_to_ShardOrder(shard_order):
"""Convert shard_order dict to ShardOrder"""
return tuple(
ShardOrderEntry(tensor_dim=tensor_dim, mesh_dims=tuple(mesh_dims))
for tensor_dim, mesh_dims in shard_order.items()
)
# TODO(zpcore): remove once the native redistribute supports shard_order arg
def redistribute(
dtensor_input,
device_mesh,
placements,
shard_order,
use_graph_based_transform=True,
):
"""
Wrapper function to support shard_order for redistribution.
This is a simpler version of Redistribute that only considers the forward pass.
"""
if placements is None:
placements = shard_order_to_placement(shard_order, device_mesh)
placements = tuple(placements)
old_spec = dtensor_input._spec
new_spec = copy.deepcopy(old_spec)
new_spec.placements = placements
if shard_order is not None:
new_spec.shard_order = shard_order
else:
new_spec.shard_order = ()
if old_spec == new_spec:
return dtensor_input
dtensor_input = DTensor.from_local(
redistribute_local_tensor(
dtensor_input.to_local(),
old_spec,
new_spec,
use_graph_based_transform=use_graph_based_transform,
),
device_mesh,
)
dtensor_input._spec = copy.deepcopy(new_spec)
return dtensor_input # returns DTensor
# TODO(zpcore): remove once the native distribute_tensor supports
# shard_order arg
def patched_distribute_tensor(
input_tensor,
device_mesh,
placements,
shard_order,
use_graph_based_transform=True,
):
"""wrapper function to support shard_order for tensor distribution"""
if placements is None:
placements = shard_order_to_placement(shard_order, device_mesh)
placements = tuple(placements)
tensor_dt = distribute_tensor(input_tensor, device_mesh, placements)
# fix the shard order
return redistribute(
tensor_dt, device_mesh, placements, shard_order, use_graph_based_transform
)
# TODO(zpcore): remove once the native redistribute supports shard_order arg
def make_full_tensor(dtensor_input):
"""wrapper function to support DTensor.full_tensor"""
return redistribute(
dtensor_input, dtensor_input.device_mesh, placements=None, shard_order=()
).to_local()
def shard_order_to_placement(shard_order, mesh):
"""convert shard_order to placement with only Replicate() and Shard()"""
placements: list[Any] = [Replicate() for _ in range(mesh.ndim)]
if shard_order is not None:
for entry in shard_order:
tensor_dim = entry.tensor_dim
mesh_dims = entry.mesh_dims
for mesh_dim in mesh_dims:
placements[mesh_dim] = Shard(tensor_dim)
return tuple(placements)
def generate_shard_orders(mesh, tensor_rank):
# Generate all possible sharding placements of a tensor with rank
# `tensor_rank` over the mesh.
def _split_list(lst: list, N: int):
def compositions(n: int, k: int):
# yields lists of length k, positive ints summing to n
for cuts in itertools.combinations(range(1, n), k - 1):
# add 0 and n as sentinels, then take consecutive differences
yield [b - a for a, b in itertools.pairwise((0, *cuts, n))]
length = len(lst)
for comp in compositions(length, N):
result = []
start = 0
for size in comp:
result.append(lst[start : start + size])
start += size
yield result
all_mesh = list(range(mesh.ndim))
all_device_order = list(itertools.permutations(all_mesh))
for device_order in all_device_order:
# split on device orders, and assign each device order segment to a tensor dim
for num_split in range(1, mesh.ndim + 1):
for splitted_list in _split_list(list(range(mesh.ndim)), num_split):
for tensor_dims in itertools.combinations(
range(tensor_rank), len(splitted_list)
):
shard_order = {}
assert len(tensor_dims) == len(splitted_list)
for tensor_dim, mesh_dims in zip(tensor_dims, splitted_list):
shard_order[tensor_dim] = device_order[
mesh_dims[0] : mesh_dims[-1] + 1
]
yield _convert_shard_order_dict_to_ShardOrder(shard_order)
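As a quick check on the inner helper removed above, compositions(n, k) enumerates all length-k lists of positive ints summing to n (self-contained sketch; requires Python 3.10+ for itertools.pairwise):

```python
import itertools

def compositions(n, k):
    # Mirrors the helper inside generate_shard_orders: cut points partition
    # the range [0, n) into k consecutive, non-empty segments.
    for cuts in itertools.combinations(range(1, n), k - 1):
        yield [b - a for a, b in itertools.pairwise((0, *cuts, n))]

print(list(compositions(4, 2)))  # [[1, 3], [2, 2], [3, 1]]
```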

View File

@ -529,14 +529,11 @@ RE_EXTERN_SHARED = re.compile(r"extern\s+([\w\(\)]+)?\s*__shared__\s+([\w:<>\s]+
def replace_extern_shared(input_string):
"""
Match 'extern __shared__ type foo[];' syntax and use HIP_DYNAMIC_SHARED() MACRO instead.
See: https://github.com/ROCm/hip/blob/master/docs/markdown/hip_kernel_language.md#__shared__
Examples:
"extern __shared__ char smemChar[];"
=> "HIP_DYNAMIC_SHARED( char, smemChar)"
"extern __shared__ unsigned char smem[];"
=> "HIP_DYNAMIC_SHARED( unsigned char, my_smem)"
"""Match extern __shared__ type foo[]; syntax and use HIP_DYNAMIC_SHARED() MACRO instead.
https://github.com/ROCm/hip/blob/master/docs/markdown/hip_kernel_language.md#__shared__
Example:
"extern __shared__ char smemChar[];" => "HIP_DYNAMIC_SHARED( char, smemChar)"
"extern __shared__ unsigned char smem[];" => "HIP_DYNAMIC_SHARED( unsigned char, my_smem)"
"""
output_string = input_string
output_string = RE_EXTERN_SHARED.sub(
@ -1046,17 +1043,14 @@ RE_INCLUDE = re.compile(r"#include .*\n")
def extract_arguments(start, string):
"""
Return the list of arguments in the upcoming function parameter closure.
Example:
""" Return the list of arguments in the upcoming function parameter closure.
Example:
string (input): '(blocks, threads, 0, THCState_getCurrentStream(state))'
arguments (output):
[
{'start': 1, 'end': 7},
{'start': 8, 'end': 16},
{'start': 17, 'end': 19},
{'start': 20, 'end': 53}
]
'[{'start': 1, 'end': 7},
{'start': 8, 'end': 16},
{'start': 17, 'end': 19},
{'start': 20, 'end': 53}]'
"""
arguments = []

View File

@ -190,7 +190,6 @@ def mem_get_info(device: _device_t = None) -> tuple[int, int]:
int: the memory available on the device in units of bytes.
int: the total memory on the device in units of bytes
"""
_lazy_init()
device = _get_device_index(device, optional=True)
return torch._C._xpu_getMemoryInfo(device)
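A hedged usage sketch (not part of the diff) of the function whose hunk is shown above; it returns a (free, total) pair of byte counts for the selected XPU device:

```python
import torch

if torch.xpu.is_available():
    free_bytes, total_bytes = torch.xpu.mem_get_info()
    print(f"{free_bytes} of {total_bytes} bytes free")
```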