Mirror of https://github.com/pytorch/pytorch.git, synced 2025-11-12 14:54:55 +08:00

Compare commits: 5 commits (documentat... → gh/zpcore/...)

| Author | SHA1 | Date |
|---|---|---|
|  | f3d17e14ec |  |
|  | d5bc6fdce6 |  |
|  | b80c03d695 |  |
|  | f82985890f |  |
|  | 4ac82f37b9 |  |
@@ -260,8 +260,8 @@ case "$tag" in
     HALIDE=yes
     TRITON=yes
     ;;
-  pytorch-linux-jammy-cuda12.8-py3.12-pallas)
-    CUDA_VERSION=12.8.1
+  pytorch-linux-jammy-cuda13.0-py3.12-pallas)
+    CUDA_VERSION=13.0.0
     ANACONDA_PYTHON_VERSION=3.12
     GCC_VERSION=11
     PALLAS=yes
@@ -168,16 +168,14 @@ if [[ "$BUILD_ENVIRONMENT" == *xpu* ]]; then
    # shellcheck disable=SC1091
    source /opt/intel/oneapi/compiler/latest/env/vars.sh
    # shellcheck disable=SC1091
    source /opt/intel/oneapi/umf/latest/env/vars.sh
    # shellcheck disable=SC1091
    source /opt/intel/oneapi/ccl/latest/env/vars.sh
    # shellcheck disable=SC1091
    source /opt/intel/oneapi/mpi/latest/env/vars.sh
    # shellcheck disable=SC1091
    source /opt/intel/oneapi/pti/latest/env/vars.sh
    # Enable XCCL build
    export USE_XCCL=1
    export USE_MPI=0
    # XPU kineto feature dependencies are not fully ready, disable kineto build as temp WA
    export USE_KINETO=0
    export TORCH_XPU_ARCH_LIST=pvc
fi

@@ -208,8 +208,6 @@ if [[ "$BUILD_ENVIRONMENT" == *xpu* ]]; then
    source /opt/intel/oneapi/ccl/latest/env/vars.sh
    # shellcheck disable=SC1091
    source /opt/intel/oneapi/mpi/latest/env/vars.sh
    # shellcheck disable=SC1091
    source /opt/intel/oneapi/pti/latest/env/vars.sh
    # Check XPU status before testing
    timeout 30 xpu-smi discovery || true
fi
@@ -339,7 +337,7 @@ test_python() {

test_python_smoke() {
  # Smoke tests for H100/B200
- time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune inductor/test_cutedsl_grouped_mm $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
+ time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
  assert_git_not_dirty
}
.github/ci_commit_pins/xla.txt (vendored, 2 changed lines)
@@ -1 +1 @@
-e4d25697f9dc5eedaf8f0a5bf085c62c5455a53a
+c8b09f5f77d6bf6fb7ed7a9aa83e5d8156b3a5e9
.github/workflows/docker-builds.yml (vendored, 2 changed lines)
@@ -67,7 +67,7 @@ jobs:
          pytorch-linux-jammy-py3.10-gcc11,
          pytorch-linux-jammy-py3-gcc11-inductor-benchmarks,
          pytorch-linux-jammy-py3.12-halide,
-         pytorch-linux-jammy-cuda12.8-py3.12-pallas,
+         pytorch-linux-jammy-cuda13.0-py3.12-pallas,
          pytorch-linux-jammy-xpu-n-1-py3,
          pytorch-linux-noble-xpu-n-py3,
          pytorch-linux-noble-xpu-n-py3-inductor-benchmarks,
.github/workflows/inductor-unittest.yml (vendored, 4 changed lines)
@@ -86,8 +86,8 @@ jobs:
    uses: ./.github/workflows/_linux-build.yml
    needs: get-label-type
    with:
-     build-environment: linux-jammy-cuda12.8-py3.12-gcc11
-     docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-py3.12-pallas
+     build-environment: linux-jammy-py3.12-gcc11
+     docker-image-name: ci-image:pytorch-linux-jammy-cuda13.0-py3.12-pallas
      cuda-arch-list: '8.9'
      runner: linux.8xlarge.memory
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
.gitignore (vendored, 1 changed line)
@@ -127,7 +127,6 @@ torch/test/
 torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h
 torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h
 torch/version.py
-torch/_inductor/kernel/vendored_templates/*
 minifier_launcher.py
 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_fwd_d*
 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd_d*
@@ -94,11 +94,6 @@ TORCH_API inline void resetPeakStats(c10::DeviceIndex device_index) {
   at::getDeviceAllocator(device_type)->resetPeakStats(device_index);
 }

-TORCH_API inline std::pair<size_t, size_t> getMemoryInfo(
-    c10::DeviceIndex device_index) {
-  const auto device_type = getAccelerator(true).value();
-  return at::getDeviceAllocator(device_type)->getMemoryInfo(device_index);
-}
 } // namespace at::accelerator

 namespace at {
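The deleted helper above returned a {free, total} byte pair for the current accelerator by forwarding to the allocator-level `getMemoryInfo`. As a point of reference, here is a minimal Python sketch of the same query through the long-standing public CUDA API; the device index 0 is just an illustrative choice.

```python
# Minimal sketch: query free/total device memory from Python. This mirrors the
# {free, total} pair the removed C++ helper returned, but goes through the
# public torch.cuda API rather than the at::accelerator entry point.
import torch

if torch.cuda.is_available():
    free_bytes, total_bytes = torch.cuda.mem_get_info(0)
    print(f"free = {free_bytes / 2**30:.2f} GiB, total = {total_bytes / 2**30:.2f} GiB")
```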
@@ -157,8 +157,6 @@ constexpr DispatchKeySet kKeysToPropagateToWrapper({
    DispatchKey::Negative,
    DispatchKey::Conjugate,
    DispatchKey::XLA,
    DispatchKey::XPU,
    DispatchKey::HPU,
    DispatchKey::CUDA,
    DispatchKey::CPU,
    DispatchKey::PrivateUse1,
@@ -4292,7 +4292,6 @@
   dispatch:
     SparseCPU: sparse_sparse_matmul_cpu
     SparseCUDA: sparse_sparse_matmul_cuda
-    SparseMPS: sparse_sparse_matmul_mps
   autogen: _sparse_sparse_matmul.out

 - func: mode(Tensor self, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices)

@@ -9833,7 +9832,7 @@
   structured_delegate: erfinv.out
   variants: method, function
   dispatch:
-    SparseCPU, SparseCUDA, SparseMPS: erfinv_sparse
+    SparseCPU, SparseCUDA: erfinv_sparse
     SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: erfinv_sparse_csr
   tags: pointwise

@@ -9842,7 +9841,7 @@
   structured_delegate: erfinv.out
   variants: method
   dispatch:
-    SparseCPU, SparseCUDA, SparseMPS: erfinv_sparse_
+    SparseCPU, SparseCUDA: erfinv_sparse_
     SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: erfinv_sparse_csr_
   tags: pointwise

@@ -9852,7 +9851,7 @@
   structured_inherits: TensorIteratorBase
   dispatch:
     CPU, CUDA, MPS: erfinv_out
-    SparseCPU, SparseCUDA, SparseMPS: erfinv_sparse_out
+    SparseCPU, SparseCUDA: erfinv_sparse_out
     SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: erfinv_sparse_csr_out
   tags: pointwise
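For orientation, the `erfinv` entries above register zero-preserving sparse kernels; only the SparseMPS registration is dropped here, while CPU/CUDA dispatch is unchanged. A small hedged illustration of what the remaining SparseCPU dispatch provides:

```python
# erfinv applied to a sparse COO tensor stays sparse: erfinv(0) == 0, so only
# the stored values are transformed. This exercises the SparseCPU kernel kept
# in the dispatch table above.
import torch

x = torch.tensor([[0.0, 0.5], [-0.5, 0.0]]).to_sparse()
y = torch.erfinv(x)
print(y.to_dense())
```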
@@ -10,10 +10,6 @@
 #include <ATen/NativeFunctions.h>
 #else
 #include <ATen/ops/_coalesce_native.h>
-#include <ATen/ops/repeat_interleave_native.h>
-#include <ATen/ops/cumsum.h>
-#include <ATen/ops/_sparse_sparse_matmul_native.h>
-#include <ATen/ops/_sparse_coo_tensor_unsafe.h>
 #include <ATen/ops/_sparse_coo_tensor_unsafe_native.h>
 #include <ATen/ops/cat.h>
 #include <ATen/ops/add_native.h>
@@ -892,114 +888,5 @@ static void sparse_mask_intersection_out_mps_kernel(
      /*coalesce_mask=*/false);
}

Tensor sparse_sparse_matmul_mps(const Tensor& mat1_, const Tensor& mat2_) {
  TORCH_CHECK(mat1_.is_sparse() && mat2_.is_sparse(),
      "sparse_sparse_matmul_mps: both inputs must be sparse COO tensors");
  TORCH_CHECK(mat1_.is_mps() && mat2_.is_mps(),
      "sparse_sparse_matmul_mps: both inputs must be on MPS device");
  TORCH_CHECK(mat1_.dim() == 2 && mat2_.dim() == 2,
      "sparse_sparse_matmul_mps: both inputs must be 2D matrices");
  TORCH_CHECK(mat1_.dense_dim() == 0 && mat2_.dense_dim() == 0,
      "sparse_sparse_matmul_mps: only scalar values supported (dense_dim == 0)");
  TORCH_CHECK(mat1_.size(1) == mat2_.size(0),
      "mat1 and mat2 shapes cannot be multiplied (", mat1_.size(0), "x", mat1_.size(1), " and ", mat2_.size(0), "x", mat2_.size(1), ")");
  TORCH_CHECK(mat1_.scalar_type() == mat2_.scalar_type(),
      "sparse_sparse_matmul_mps: mat1 dtype ", mat1_.scalar_type(),
      " does not match mat2 dtype ", mat2_.scalar_type());

  const auto device = mat1_.device();

  auto A = mat1_.coalesce();
  auto B = mat2_.coalesce();

  const auto I = A.size(0);
  const auto K = A.size(1);
  const auto N = B.size(1);

  const auto nnzA = A._nnz();
  const auto nnzB = B._nnz();

  // Early empty result, return an empty, coalesced tensor
  if (I == 0 || N == 0 || K == 0 || nnzA == 0 || nnzB == 0) {
    auto empty_idx = at::empty({2, 0}, at::device(device).dtype(at::kLong));
    auto empty_val = at::empty({0}, at::device(device).dtype(mat1_.scalar_type()));
    auto out = _sparse_coo_tensor_unsafe(empty_idx, empty_val, {I, N}, mat1_.options());
    out._coalesced_(true);
    return out;
  }

  const auto computeDtype = at::result_type(mat1_, mat2_);

  auto A_idx = A._indices().contiguous();
  auto A_val = A._values().to(computeDtype).contiguous();
  auto A_i = A_idx.select(0, 0).contiguous();
  auto A_k = A_idx.select(0, 1).contiguous();

  auto B_idx = B._indices().contiguous();
  auto B_val = B._values().to(computeDtype).contiguous();
  auto B_k = B_idx.select(0, 0).contiguous();
  auto B_j = B_idx.select(0, 1).contiguous();

  // csr-style row pointers for B by k (the shared dimension)
  Tensor row_ptr_B;
  {
    auto batch_ptr = at::tensor({0LL, nnzB}, at::device(device).dtype(at::kLong));
    row_ptr_B = at::empty({K + 1}, at::device(device).dtype(at::kLong));
    build_row_ptr_per_batch_mps(B_k, batch_ptr, /*B=*/1, /*I=*/K, row_ptr_B);
  }

  auto row_ptr_B_lo = row_ptr_B.narrow(0, 0, K);
  auto row_ptr_B_hi = row_ptr_B.narrow(0, 1, K);
  auto deg_B = row_ptr_B_hi.sub(row_ptr_B_lo);

  auto counts = deg_B.index_select(0, A_k);

  const int64_t P = counts.sum().item<int64_t>();
  if (P == 0) {
    auto empty_idx = at::empty({2, 0}, at::device(device).dtype(at::kLong));
    auto empty_val = at::empty({0}, at::device(device).dtype(mat1_.scalar_type()));
    auto out = _sparse_coo_tensor_unsafe(empty_idx, empty_val, {I, N}, mat1_.options());
    out._coalesced_(true);
    return out;
  }

  auto group_ids = repeat_interleave_mps(counts);

  // exclusive cumsum of counts
  auto offsets = cumsum(counts, /*dim=*/0).sub(counts);
  auto offsets_gather = offsets.index_select(0, group_ids);
  auto within = at::arange(P, at::device(device).dtype(at::kLong)).sub(offsets_gather);

  // Map each output element to its source B row and position
  auto k_per_out = A_k.index_select(0, group_ids);
  auto start_in_B = row_ptr_B.index_select(0, k_per_out);
  auto seg_index = start_in_B.add(within);

  // Assemble candidate coo pairs and values
  auto i_out = A_i.index_select(0, group_ids).contiguous();
  auto j_out = B_j.index_select(0, seg_index).contiguous();
  auto vA_out = A_val.index_select(0, group_ids).contiguous();
  auto vB_out = B_val.index_select(0, seg_index).contiguous();
  auto v_out = vA_out.mul(vB_out);

  // build (2, P) indices
  auto out_indices = at::empty({2, P}, at::device(device).dtype(at::kLong)).contiguous();
  out_indices.select(0, 0).copy_(i_out);
  out_indices.select(0, 1).copy_(j_out);

  auto result = _sparse_coo_tensor_unsafe(
      out_indices, v_out, {I, N}, mat1_.options().dtype(computeDtype));

  result = result.coalesce();

  if (result.scalar_type() != mat1_.scalar_type()) {
    auto cast_vals = result._values().to(mat1_.scalar_type());
    auto out = _sparse_coo_tensor_unsafe(result._indices(), cast_vals, {I, N}, mat1_.options());
    out._coalesced_(true);
    return out;
  }
  return result;
}

REGISTER_MPS_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_mps_kernel);
} // namespace at::native
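The hunk above removes the MPS sparse-sparse matmul kernel entirely (114 lines shrink to 5 lines of context). For readers who want to see what it computed, here is a hedged, CPU-only sketch of the same expansion strategy written with public PyTorch ops; `sparse_sparse_matmul_reference` and all variable names are illustrative and not part of any PyTorch API.

```python
# CPU sketch of the expansion strategy in the removed kernel: every stored
# A[i, k] is paired with every stored B[k, j], the products are materialized,
# and coalescing sums duplicates into the final C[i, j].
import torch

def sparse_sparse_matmul_reference(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    A, B = A.coalesce(), B.coalesce()
    I, K = A.shape
    K2, N = B.shape
    assert K == K2, "inner dimensions must match"

    a_i, a_k = A.indices()                    # row / col of each stored A entry
    a_val = A.values()
    b_k, b_j = B.indices()                    # coalesced, so b_k is sorted
    b_val = B.values()

    # CSR-style row pointers for B keyed by k (the shared dimension).
    deg_b = torch.bincount(b_k, minlength=K)  # stored entries of B per row k
    row_ptr_b = torch.zeros(K + 1, dtype=torch.long)
    row_ptr_b[1:] = torch.cumsum(deg_b, dim=0)

    # Each stored A entry at column k pairs with every stored B entry in row k.
    counts = deg_b[a_k]                       # candidate products per A entry
    group_ids = torch.repeat_interleave(counts)
    offsets = torch.cumsum(counts, dim=0) - counts           # exclusive cumsum
    within = torch.arange(int(counts.sum())) - offsets[group_ids]
    seg_index = row_ptr_b[a_k[group_ids]] + within           # index into B's nnz

    i_out = a_i[group_ids]
    j_out = b_j[seg_index]
    v_out = a_val[group_ids] * b_val[seg_index]

    out = torch.sparse_coo_tensor(torch.stack([i_out, j_out]), v_out, (I, N))
    return out.coalesce()                     # sums duplicate (i, j) products

# Quick check against dense matmul.
A = torch.eye(4).to_sparse()
B = torch.rand(4, 3).to_sparse()
assert torch.allclose(
    sparse_sparse_matmul_reference(A, B).to_dense(), A.to_dense() @ B.to_dense()
)
```

The final coalesce is what sums duplicate (i, j) candidates, mirroring `result.coalesce()` in the removed kernel.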
@@ -96,10 +96,6 @@ struct C10_API DeviceAllocator : public c10::Allocator {

  // Resets peak memory usage statistics for the specified device
  virtual void resetPeakStats(c10::DeviceIndex device) = 0;

  // Return the free memory size and total memory size in bytes for the
  // specified device.
  virtual std::pair<size_t, size_t> getMemoryInfo(c10::DeviceIndex device) = 0;
};

// This function is used to get the DeviceAllocator for a specific device type
@@ -345,13 +345,6 @@ class CUDAAllocator : public DeviceAllocator {
       c10::DeviceIndex device,
       std::shared_ptr<AllocatorState> pps) = 0;
   virtual std::string name() = 0;
-  std::pair<size_t, size_t> getMemoryInfo(c10::DeviceIndex device) override {
-    c10::DeviceGuard device_guard({at::kCUDA, device});
-    size_t free = 0;
-    size_t total = 0;
-    C10_CUDA_CHECK(cudaMemGetInfo(&free, &total));
-    return {free, total};
-  }
 };

 // Allocator object, statically initialized
@@ -926,14 +926,15 @@ class DeviceCachingAllocator {
           (release_cached_blocks() && alloc_block(params, true));
     }
     if (!block_found) {
-      const auto& raw_device = c10::xpu::get_raw_device(device);
-      const auto device_total =
-          raw_device.get_info<sycl::info::device::global_mem_size>();
+      c10::xpu::DeviceProp device_prop;
+      c10::xpu::get_device_properties(&device_prop, device);
+      auto device_total = device_prop.global_mem_size;
       // Estimate the available device memory when the SYCL runtime does not
       // support the corresponding aspect (ext_intel_free_memory).
-      size_t device_free = device_total -
+      size_t device_free = device_prop.global_mem_size -
           stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)]
               .current;
+      auto& raw_device = c10::xpu::get_raw_device(device);
       // TODO: Remove the aspect check once the SYCL runtime bug is fixed on
       // affected devices.
       if (raw_device.has(sycl::aspect::ext_intel_free_memory)) {
@@ -1051,37 +1052,21 @@ class DeviceCachingAllocator {
       }
     }

-  std::pair<size_t, size_t> getMemoryInfo() {
-    const auto& device = c10::xpu::get_raw_device(device_index);
-    const size_t total = device.get_info<sycl::info::device::global_mem_size>();
-    TORCH_CHECK(
-        device.has(sycl::aspect::ext_intel_free_memory),
-        "The device (",
-        device.get_info<sycl::info::device::name>(),
-        ") doesn't support querying the available free memory. ",
-        "You can file an issue at https://github.com/pytorch/pytorch/issues ",
-        "to help us prioritize its implementation.");
-    const size_t free =
-        device.get_info<sycl::ext::intel::info::device::free_memory>();
-    return {free, total};
-  }
-
   double getMemoryFraction() {
     if (!set_fraction) {
       return 1.0;
     }

-    const auto device_total =
-        xpu::get_raw_device(device_index)
-            .get_info<sycl::info::device::global_mem_size>();
+    c10::xpu::DeviceProp device_prop;
+    c10::xpu::get_device_properties(&device_prop, device_index);
     return static_cast<double>(allowed_memory_maximum) /
-        static_cast<double>(device_total);
+        static_cast<double>(device_prop.global_mem_size);
   }

   void setMemoryFraction(double fraction) {
-    const auto device_total =
-        xpu::get_raw_device(device_index)
-            .get_info<sycl::info::device::global_mem_size>();
+    c10::xpu::DeviceProp device_prop;
+    c10::xpu::get_device_properties(&device_prop, device_index);
+    auto device_total = device_prop.global_mem_size;
     allowed_memory_maximum = static_cast<size_t>(fraction * device_total);
     set_fraction = true;
   }
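The fraction bookkeeping above caps how much of the device's total memory the caching allocator may reserve. A hedged Python-level sketch of the same idea, shown with the CUDA API as the closest stable public analogue (the XPU hunk above concerns the equivalent C++ path):

```python
# Cap the caching allocator at a fraction of total device memory from Python.
import torch

if torch.cuda.is_available():
    torch.cuda.set_per_process_memory_fraction(0.5, device=0)
    total = torch.cuda.get_device_properties(0).total_memory
    print(f"allocator capped at ~{0.5 * total / 2**30:.1f} GiB")
```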
@@ -1255,11 +1240,6 @@ class XPUAllocator : public DeviceAllocator {
         c10::xpu::get_raw_device(dev_to_access));
   }

-  std::pair<size_t, size_t> getMemoryInfo(DeviceIndex device) override {
-    assertValidDevice(device);
-    return device_allocators[device]->getMemoryInfo();
-  }
-
   double getMemoryFraction(DeviceIndex device) {
     assertValidDevice(device);
     return device_allocators[device]->getMemoryFraction();
@@ -40,7 +40,6 @@
   :nosignatures:

   empty_cache
-  get_memory_info
   max_memory_allocated
   max_memory_reserved
   memory_allocated
@ -382,6 +382,20 @@ coverage_ignore_functions = [
|
||||
# torch.ao.quantization.backend_config.tensorrt
|
||||
"get_tensorrt_backend_config",
|
||||
"get_tensorrt_backend_config_dict",
|
||||
# torch.ao.quantization.backend_config.utils
|
||||
"entry_to_pretty_str",
|
||||
"get_fused_module_classes",
|
||||
"get_fuser_method_mapping",
|
||||
"get_fusion_pattern_to_extra_inputs_getter",
|
||||
"get_fusion_pattern_to_root_node_getter",
|
||||
"get_module_to_qat_module",
|
||||
"get_pattern_to_dtype_configs",
|
||||
"get_pattern_to_input_type_to_index",
|
||||
"get_qat_module_classes",
|
||||
"get_root_module_to_quantized_reference_module",
|
||||
"pattern_to_human_readable",
|
||||
"remove_boolean_dispatch_from_name",
|
||||
# torch.ao.quantization.backend_config.x86
|
||||
"get_x86_backend_config",
|
||||
# torch.ao.quantization.fuse_modules
|
||||
"fuse_known_modules",
|
||||
@ -412,6 +426,25 @@ coverage_ignore_functions = [
|
||||
"insert_observers_for_model",
|
||||
"prepare",
|
||||
"propagate_dtypes_for_known_nodes",
|
||||
# torch.ao.quantization.fx.utils
|
||||
"all_node_args_except_first",
|
||||
"all_node_args_have_no_tensors",
|
||||
"assert_and_get_unique_device",
|
||||
"collect_producer_nodes",
|
||||
"create_getattr_from_value",
|
||||
"create_node_from_old_node_preserve_meta",
|
||||
"get_custom_module_class_keys",
|
||||
"get_linear_prepack_op_for_dtype",
|
||||
"get_new_attr_name_with_prefix",
|
||||
"get_non_observable_arg_indexes_and_types",
|
||||
"get_qconv_prepack_op",
|
||||
"get_skipped_module_name_and_classes",
|
||||
"graph_module_from_producer_nodes",
|
||||
"maybe_get_next_module",
|
||||
"node_arg_is_bias",
|
||||
"node_arg_is_weight",
|
||||
"return_arg_list",
|
||||
# torch.ao.quantization.pt2e.graph_utils
|
||||
"bfs_trace_with_node_process",
|
||||
"find_sequential_partitions",
|
||||
"get_equivalent_types",
|
||||
@ -827,10 +860,80 @@ coverage_ignore_functions = [
|
||||
"get_latency_of_one_partition",
|
||||
"get_latency_of_partitioned_graph",
|
||||
"get_partition_to_latency_mapping",
|
||||
# torch.fx.experimental.proxy_tensor
|
||||
"decompose",
|
||||
"disable_autocast_cache",
|
||||
"disable_proxy_modes_tracing",
|
||||
"dispatch_trace",
|
||||
"extract_val",
|
||||
"fake_signature",
|
||||
"fetch_sym_proxy",
|
||||
"fetch_object_proxy",
|
||||
"get_innermost_proxy_mode",
|
||||
"get_isolated_graphmodule",
|
||||
"get_proxy_slot",
|
||||
"get_torch_dispatch_modes",
|
||||
"has_proxy_slot",
|
||||
"is_sym_node",
|
||||
"maybe_handle_decomp",
|
||||
"proxy_call",
|
||||
"set_meta",
|
||||
"set_original_aten_op",
|
||||
"set_proxy_slot",
|
||||
"snapshot_fake",
|
||||
"thunkify",
|
||||
"track_tensor",
|
||||
"track_tensor_tree",
|
||||
"wrap_key",
|
||||
"wrapper_and_args_for_make_fx",
|
||||
# torch.fx.experimental.recording
|
||||
"record_shapeenv_event",
|
||||
"replay_shape_env_events",
|
||||
"shape_env_check_state_equal",
|
||||
# torch.fx.experimental.sym_node
|
||||
"ceil_impl",
|
||||
"floor_ceil_helper",
|
||||
"floor_impl",
|
||||
"method_to_operator",
|
||||
"sympy_is_channels_last_contiguous_2d",
|
||||
"sympy_is_channels_last_contiguous_3d",
|
||||
"sympy_is_channels_last_strides_2d",
|
||||
"sympy_is_channels_last_strides_3d",
|
||||
"sympy_is_channels_last_strides_generic",
|
||||
"sympy_is_contiguous",
|
||||
"sympy_is_contiguous_generic",
|
||||
"to_node",
|
||||
"wrap_node",
|
||||
"sym_sqrt",
|
||||
# torch.fx.experimental.symbolic_shapes
|
||||
"bind_symbols",
|
||||
"cast_symbool_to_symint_guardless",
|
||||
"create_contiguous",
|
||||
"error",
|
||||
"eval_guards",
|
||||
"eval_is_non_overlapping_and_dense",
|
||||
"expect_true",
|
||||
"find_symbol_binding_fx_nodes",
|
||||
"free_symbols",
|
||||
"free_unbacked_symbols",
|
||||
"fx_placeholder_targets",
|
||||
"fx_placeholder_vals",
|
||||
"guard_bool",
|
||||
"guard_float",
|
||||
"guard_int",
|
||||
"guard_scalar",
|
||||
"has_hint",
|
||||
"has_symbolic_sizes_strides",
|
||||
"is_channels_last_contiguous_2d",
|
||||
"is_channels_last_contiguous_3d",
|
||||
"is_channels_last_strides_2d",
|
||||
"is_channels_last_strides_3d",
|
||||
"is_contiguous",
|
||||
"is_non_overlapping_and_dense_indicator",
|
||||
"is_nested_int",
|
||||
"is_symbol_binding_fx_node",
|
||||
"is_symbolic",
|
||||
# torch.fx.experimental.unification.core
|
||||
"reify",
|
||||
# torch.fx.experimental.unification.match
|
||||
"edge",
|
||||
@ -868,6 +971,24 @@ coverage_ignore_functions = [
|
||||
"reverse_dict",
|
||||
# torch.fx.experimental.unification.multipledispatch.variadic
|
||||
"isvariadic",
|
||||
# torch.fx.experimental.unification.unification_tools
|
||||
"assoc",
|
||||
"assoc_in",
|
||||
"dissoc",
|
||||
"first",
|
||||
"get_in",
|
||||
"getter",
|
||||
"groupby",
|
||||
"itemfilter",
|
||||
"itemmap",
|
||||
"keyfilter",
|
||||
"keymap",
|
||||
"merge",
|
||||
"merge_with",
|
||||
"update_in",
|
||||
"valfilter",
|
||||
"valmap",
|
||||
# torch.fx.experimental.unification.utils
|
||||
"freeze",
|
||||
"hashable",
|
||||
"raises",
|
||||
@ -1308,8 +1429,319 @@ coverage_ignore_functions = [
|
||||
# torch.onnx.symbolic_opset7
|
||||
"max",
|
||||
"min",
|
||||
# torch.onnx.symbolic_opset8
|
||||
"addmm",
|
||||
"bmm",
|
||||
"empty",
|
||||
"empty_like",
|
||||
"flatten",
|
||||
"full",
|
||||
"full_like",
|
||||
"gt",
|
||||
"lt",
|
||||
"matmul",
|
||||
"mm",
|
||||
"ones",
|
||||
"ones_like",
|
||||
"prelu",
|
||||
"repeat",
|
||||
"zeros",
|
||||
"zeros_like",
|
||||
# torch.onnx.symbolic_opset9
|
||||
"abs",
|
||||
"acos",
|
||||
"adaptive_avg_pool1d",
|
||||
"adaptive_avg_pool2d",
|
||||
"adaptive_avg_pool3d",
|
||||
"adaptive_max_pool1d",
|
||||
"adaptive_max_pool2d",
|
||||
"adaptive_max_pool3d",
|
||||
"add",
|
||||
"addcmul",
|
||||
"addmm",
|
||||
"alias",
|
||||
"amax",
|
||||
"amin",
|
||||
"aminmax",
|
||||
"arange",
|
||||
"argmax",
|
||||
"argmin",
|
||||
"as_strided",
|
||||
"as_tensor",
|
||||
"asin",
|
||||
"atan",
|
||||
"atan2",
|
||||
"avg_pool1d",
|
||||
"avg_pool2d",
|
||||
"avg_pool3d",
|
||||
"baddbmm",
|
||||
"batch_norm",
|
||||
"bernoulli",
|
||||
"bitwise_not",
|
||||
"bitwise_or",
|
||||
"bmm",
|
||||
"broadcast_tensors",
|
||||
"broadcast_to",
|
||||
"bucketize",
|
||||
"cat",
|
||||
"cdist",
|
||||
"ceil",
|
||||
"clamp",
|
||||
"clamp_max",
|
||||
"clamp_min",
|
||||
"clone",
|
||||
"constant_pad_nd",
|
||||
"contiguous",
|
||||
"conv1d",
|
||||
"conv2d",
|
||||
"conv3d",
|
||||
"conv_tbc",
|
||||
"conv_transpose1d",
|
||||
"conv_transpose2d",
|
||||
"conv_transpose3d",
|
||||
"convert_element_type",
|
||||
"convolution",
|
||||
"cos",
|
||||
"cosine_similarity",
|
||||
"cross",
|
||||
"cumsum",
|
||||
"detach",
|
||||
"dim",
|
||||
"div",
|
||||
"dot",
|
||||
"dropout",
|
||||
"elu",
|
||||
"embedding",
|
||||
"embedding_bag",
|
||||
"empty",
|
||||
"empty_like",
|
||||
"eq",
|
||||
"erf",
|
||||
"exp",
|
||||
"expand",
|
||||
"expand_as",
|
||||
"eye",
|
||||
"fill",
|
||||
"flatten",
|
||||
"floor",
|
||||
"floor_divide",
|
||||
"floordiv",
|
||||
"frobenius_norm",
|
||||
"full",
|
||||
"full_like",
|
||||
"gather",
|
||||
"ge",
|
||||
"gelu",
|
||||
"get_pool_ceil_padding",
|
||||
"glu",
|
||||
"group_norm",
|
||||
"gru",
|
||||
"gt",
|
||||
"hann_window",
|
||||
"hardshrink",
|
||||
"hardsigmoid",
|
||||
"hardswish",
|
||||
"hardtanh",
|
||||
"index",
|
||||
"index_add",
|
||||
"index_copy",
|
||||
"index_fill",
|
||||
"index_put",
|
||||
"index_select",
|
||||
"instance_norm",
|
||||
"is_floating_point",
|
||||
"is_pinned",
|
||||
"isnan",
|
||||
"item",
|
||||
"kl_div",
|
||||
"layer_norm",
|
||||
"le",
|
||||
"leaky_relu",
|
||||
"lerp",
|
||||
"lift",
|
||||
"linalg_cross",
|
||||
"linalg_matrix_norm",
|
||||
"linalg_norm",
|
||||
"linalg_vector_norm",
|
||||
"linear",
|
||||
"linspace",
|
||||
"log",
|
||||
"log10",
|
||||
"log1p",
|
||||
"log2",
|
||||
"log_sigmoid",
|
||||
"log_softmax",
|
||||
"logical_and",
|
||||
"logical_not",
|
||||
"logical_or",
|
||||
"logical_xor",
|
||||
"logit",
|
||||
"logsumexp",
|
||||
"lstm",
|
||||
"lstm_cell",
|
||||
"lt",
|
||||
"masked_fill",
|
||||
"masked_fill_",
|
||||
"matmul",
|
||||
"max",
|
||||
"max_pool1d",
|
||||
"max_pool1d_with_indices",
|
||||
"max_pool2d",
|
||||
"max_pool2d_with_indices",
|
||||
"max_pool3d",
|
||||
"max_pool3d_with_indices",
|
||||
"maximum",
|
||||
"meshgrid",
|
||||
"min",
|
||||
"minimum",
|
||||
"mish",
|
||||
"mm",
|
||||
"movedim",
|
||||
"mse_loss",
|
||||
"mul",
|
||||
"multinomial",
|
||||
"mv",
|
||||
"narrow",
|
||||
"native_layer_norm",
|
||||
"ne",
|
||||
"neg",
|
||||
"new_empty",
|
||||
"new_full",
|
||||
"new_ones",
|
||||
"new_zeros",
|
||||
"nonzero",
|
||||
"nonzero_numpy",
|
||||
"noop_complex_operators",
|
||||
"norm",
|
||||
"numel",
|
||||
"numpy_T",
|
||||
"one_hot",
|
||||
"ones",
|
||||
"ones_like",
|
||||
"onnx_placeholder",
|
||||
"overload_by_arg_count",
|
||||
"pad",
|
||||
"pairwise_distance",
|
||||
"permute",
|
||||
"pixel_shuffle",
|
||||
"pixel_unshuffle",
|
||||
"pow",
|
||||
"prelu",
|
||||
"prim_constant",
|
||||
"prim_constant_chunk",
|
||||
"prim_constant_split",
|
||||
"prim_data",
|
||||
"prim_device",
|
||||
"prim_dtype",
|
||||
"prim_if",
|
||||
"prim_layout",
|
||||
"prim_list_construct",
|
||||
"prim_list_unpack",
|
||||
"prim_loop",
|
||||
"prim_max",
|
||||
"prim_min",
|
||||
"prim_shape",
|
||||
"prim_tolist",
|
||||
"prim_tuple_construct",
|
||||
"prim_type",
|
||||
"prim_unchecked_cast",
|
||||
"prim_uninitialized",
|
||||
"rand",
|
||||
"rand_like",
|
||||
"randint",
|
||||
"randint_like",
|
||||
"randn",
|
||||
"randn_like",
|
||||
"reciprocal",
|
||||
"reflection_pad",
|
||||
"relu",
|
||||
"relu6",
|
||||
"remainder",
|
||||
"repeat",
|
||||
"repeat_interleave",
|
||||
"replication_pad",
|
||||
"reshape",
|
||||
"reshape_as",
|
||||
"rnn_relu",
|
||||
"rnn_tanh",
|
||||
"roll",
|
||||
"rrelu",
|
||||
"rsqrt",
|
||||
"rsub",
|
||||
"scalar_tensor",
|
||||
"scatter",
|
||||
"scatter_add",
|
||||
"select",
|
||||
"selu",
|
||||
"sigmoid",
|
||||
"sign",
|
||||
"silu",
|
||||
"sin",
|
||||
"size",
|
||||
"slice",
|
||||
"softmax",
|
||||
"softplus",
|
||||
"softshrink",
|
||||
"sort",
|
||||
"split",
|
||||
"split_with_sizes",
|
||||
"sqrt",
|
||||
"square",
|
||||
"squeeze",
|
||||
"stack",
|
||||
"std",
|
||||
"std_mean",
|
||||
"sub",
|
||||
"t",
|
||||
"take",
|
||||
"tan",
|
||||
"tanh",
|
||||
"tanhshrink",
|
||||
"tensor",
|
||||
"threshold",
|
||||
"to",
|
||||
"topk",
|
||||
"transpose",
|
||||
"true_divide",
|
||||
"type_as",
|
||||
"unbind",
|
||||
"unfold",
|
||||
"unsafe_chunk",
|
||||
"unsafe_split",
|
||||
"unsafe_split_with_sizes",
|
||||
"unsqueeze",
|
||||
"unsupported_complex_operators",
|
||||
"unused",
|
||||
"upsample_bilinear2d",
|
||||
"upsample_linear1d",
|
||||
"upsample_nearest1d",
|
||||
"upsample_nearest2d",
|
||||
"upsample_nearest3d",
|
||||
"upsample_trilinear3d",
|
||||
"var",
|
||||
"var_mean",
|
||||
"view",
|
||||
"view_as",
|
||||
"where",
|
||||
"wrap_logical_op_with_cast_to",
|
||||
"wrap_logical_op_with_negation",
|
||||
"zero",
|
||||
"zeros",
|
||||
"zeros_like",
|
||||
# torch.onnx.utils
|
||||
"disable_apex_o2_state_dict_hook",
|
||||
"export",
|
||||
"export_to_pretty_string",
|
||||
"exporter_context",
|
||||
"is_in_onnx_export",
|
||||
"model_signature",
|
||||
"register_custom_op_symbolic",
|
||||
"select_model_mode_for_export",
|
||||
"setup_onnx_logging",
|
||||
"unconvertible_ops",
|
||||
"unpack_quantized_tensor",
|
||||
"warn_on_static_input_change",
|
||||
# torch.onnx.verification
|
||||
"check_export_model_diff",
|
||||
"verify",
|
||||
"verify_aten_graph",
|
||||
@ -1400,6 +1832,32 @@ coverage_ignore_functions = [
|
||||
"noop_context_fn",
|
||||
"set_checkpoint_early_stop",
|
||||
"set_device_states",
|
||||
# torch.utils.collect_env
|
||||
"check_release_file",
|
||||
"get_cachingallocator_config",
|
||||
"get_clang_version",
|
||||
"get_cmake_version",
|
||||
"get_conda_packages",
|
||||
"get_cpu_info",
|
||||
"get_cuda_module_loading_config",
|
||||
"get_cudnn_version",
|
||||
"get_env_info",
|
||||
"get_gcc_version",
|
||||
"get_gpu_info",
|
||||
"get_libc_version",
|
||||
"get_lsb_version",
|
||||
"get_mac_version",
|
||||
"get_nvidia_driver_version",
|
||||
"get_nvidia_smi",
|
||||
"get_os",
|
||||
"get_pip_packages",
|
||||
"get_platform",
|
||||
"get_pretty_env_info",
|
||||
"get_python_platform",
|
||||
"get_running_cuda_version",
|
||||
"get_windows_version",
|
||||
"is_xnnpack_available",
|
||||
"pretty_str",
|
||||
# torch.utils.cpp_backtrace
|
||||
"get_cpp_backtrace",
|
||||
# torch.utils.cpp_extension
|
||||
@ -1463,6 +1921,52 @@ coverage_ignore_functions = [
|
||||
"apply_shuffle_seed",
|
||||
"apply_shuffle_settings",
|
||||
"get_all_graph_pipes",
|
||||
# torch.utils.flop_counter
|
||||
"addmm_flop",
|
||||
"baddbmm_flop",
|
||||
"bmm_flop",
|
||||
"conv_backward_flop",
|
||||
"conv_flop",
|
||||
"conv_flop_count",
|
||||
"convert_num_with_suffix",
|
||||
"get_shape",
|
||||
"get_suffix_str",
|
||||
"mm_flop",
|
||||
"normalize_tuple",
|
||||
"register_flop_formula",
|
||||
"sdpa_backward_flop",
|
||||
"sdpa_backward_flop_count",
|
||||
"sdpa_flop",
|
||||
"sdpa_flop_count",
|
||||
"shape_wrapper",
|
||||
"transpose_shape",
|
||||
# torch.utils.hipify.hipify_python
|
||||
"add_dim3",
|
||||
"compute_stats",
|
||||
"extract_arguments",
|
||||
"file_add_header",
|
||||
"file_specific_replacement",
|
||||
"find_bracket_group",
|
||||
"find_closure_group",
|
||||
"find_parentheses_group",
|
||||
"fix_static_global_kernels",
|
||||
"get_hip_file_path",
|
||||
"hip_header_magic",
|
||||
"hipify",
|
||||
"is_caffe2_gpu_file",
|
||||
"is_cusparse_file",
|
||||
"is_out_of_place",
|
||||
"is_pytorch_file",
|
||||
"is_special_file",
|
||||
"match_extensions",
|
||||
"matched_files_iter",
|
||||
"openf",
|
||||
"preprocess_file_and_save_result",
|
||||
"preprocessor",
|
||||
"processKernelLaunches",
|
||||
"replace_extern_shared",
|
||||
"replace_math_functions",
|
||||
"str2bool",
|
||||
# torch.utils.hooks
|
||||
"unserializable_hook",
|
||||
"warn_if_has_hooks",
|
||||
|
||||
@ -12,37 +12,6 @@ These APIs are experimental and subject to change without notice.
|
||||
.. autoclass:: torch.fx.experimental.sym_node.DynamicInt
|
||||
```
|
||||
|
||||
## torch.fx.experimental.sym_node
|
||||
|
||||
```{eval-rst}
|
||||
.. currentmodule:: torch.fx.experimental.sym_node
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. automodule:: torch.fx.experimental.sym_node
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
|
||||
is_channels_last_contiguous_2d
|
||||
is_channels_last_contiguous_3d
|
||||
is_channels_last_strides_2d
|
||||
is_channels_last_strides_3d
|
||||
is_contiguous
|
||||
is_non_overlapping_and_dense_indicator
|
||||
method_to_operator
|
||||
sympy_is_channels_last_contiguous_2d
|
||||
sympy_is_channels_last_contiguous_3d
|
||||
sympy_is_channels_last_strides_2d
|
||||
sympy_is_channels_last_strides_3d
|
||||
sympy_is_channels_last_strides_generic
|
||||
sympy_is_contiguous
|
||||
sympy_is_contiguous_generic
|
||||
```
|
||||
|
||||
## torch.fx.experimental.symbolic_shapes
|
||||
|
||||
```{eval-rst}
|
||||
@ -100,25 +69,6 @@ These APIs are experimental and subject to change without notice.
|
||||
rebind_unbacked
|
||||
resolve_unbacked_bindings
|
||||
is_accessor_node
|
||||
cast_symbool_to_symint_guardless
|
||||
create_contiguous
|
||||
error
|
||||
eval_guards
|
||||
eval_is_non_overlapping_and_dense
|
||||
find_symbol_binding_fx_nodes
|
||||
free_symbols
|
||||
free_unbacked_symbols
|
||||
fx_placeholder_targets
|
||||
fx_placeholder_vals
|
||||
guard_bool
|
||||
guard_float
|
||||
guard_int
|
||||
guard_scalar
|
||||
has_hint
|
||||
has_symbolic_sizes_strides
|
||||
is_nested_int
|
||||
is_symbol_binding_fx_node
|
||||
is_symbolic
|
||||
```
|
||||
|
||||
## torch.fx.experimental.proxy_tensor
|
||||
@ -141,46 +91,4 @@ These APIs are experimental and subject to change without notice.
|
||||
get_proxy_mode
|
||||
maybe_enable_thunkify
|
||||
maybe_disable_thunkify
|
||||
decompose
|
||||
disable_autocast_cache
|
||||
disable_proxy_modes_tracing
|
||||
extract_val
|
||||
fake_signature
|
||||
fetch_object_proxy
|
||||
fetch_sym_proxy
|
||||
has_proxy_slot
|
||||
is_sym_node
|
||||
maybe_handle_decomp
|
||||
proxy_call
|
||||
set_meta
|
||||
set_original_aten_op
|
||||
set_proxy_slot
|
||||
snapshot_fake
|
||||
```
|
||||
|
||||
## torch.fx.experimental.unification.unification_tools
|
||||
|
||||
```{eval-rst}
|
||||
.. currentmodule:: torch.fx.experimental.unification.unification_tools
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. automodule:: torch.fx.experimental.unification.unification_tools
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
|
||||
assoc
|
||||
assoc_in
|
||||
dissoc
|
||||
first
|
||||
keyfilter
|
||||
keymap
|
||||
merge
|
||||
merge_with
|
||||
update_in
|
||||
valfilter
|
||||
valmap
|
||||
|
||||
@ -1134,6 +1134,7 @@ The set of leaf modules can be customized by overriding
|
||||
.. py:module:: torch.fx.experimental.refinement_types
|
||||
.. py:module:: torch.fx.experimental.rewriter
|
||||
.. py:module:: torch.fx.experimental.schema_type_annotation
|
||||
.. py:module:: torch.fx.experimental.sym_node
|
||||
.. py:module:: torch.fx.experimental.unification.core
|
||||
.. py:module:: torch.fx.experimental.unification.dispatch
|
||||
.. py:module:: torch.fx.experimental.unification.match
|
||||
@ -1143,6 +1144,7 @@ The set of leaf modules can be customized by overriding
|
||||
.. py:module:: torch.fx.experimental.unification.multipledispatch.dispatcher
|
||||
.. py:module:: torch.fx.experimental.unification.multipledispatch.utils
|
||||
.. py:module:: torch.fx.experimental.unification.multipledispatch.variadic
|
||||
.. py:module:: torch.fx.experimental.unification.unification_tools
|
||||
.. py:module:: torch.fx.experimental.unification.utils
|
||||
.. py:module:: torch.fx.experimental.unification.variable
|
||||
.. py:module:: torch.fx.experimental.unify_refinements
|
||||
|
||||
@ -134,23 +134,6 @@ Quantization to work with this as well.
|
||||
ObservationType
|
||||
```
|
||||
|
||||
## torch.ao.quantization.backend_config.utils
|
||||
```{eval-rst}
|
||||
.. currentmodule:: torch.ao.quantization.backend_config.utils
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
:template: classtemplate.rst
|
||||
|
||||
entry_to_pretty_str
|
||||
pattern_to_human_readable
|
||||
remove_boolean_dispatch_from_name
|
||||
|
||||
```
|
||||
|
||||
## torch.ao.quantization.fx.custom_config
|
||||
|
||||
This module contains a few CustomConfig classes that's used in both eager mode and FX graph mode quantization
|
||||
@ -171,30 +154,6 @@ This module contains a few CustomConfig classes that's used in both eager mode a
|
||||
StandaloneModuleConfigEntry
|
||||
```
|
||||
|
||||
## torch.ao.quantization.fx.utils
|
||||
|
||||
```{eval-rst}
|
||||
.. currentmodule:: torch.ao.quantization.fx.utils
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
:template: classtemplate.rst
|
||||
|
||||
all_node_args_except_first
|
||||
all_node_args_have_no_tensors
|
||||
collect_producer_nodes
|
||||
create_getattr_from_value
|
||||
create_node_from_old_node_preserve_meta
|
||||
graph_module_from_producer_nodes
|
||||
maybe_get_next_module
|
||||
node_arg_is_bias
|
||||
node_arg_is_weight
|
||||
return_arg_list
|
||||
```
|
||||
|
||||
## torch.ao.quantization.quantizer
|
||||
|
||||
```{eval-rst}
|
||||
|
||||
@ -19,91 +19,6 @@
|
||||
swap_tensors
|
||||
```
|
||||
|
||||
# torch.utils.collect_env
|
||||
```{eval-rst}
|
||||
.. automodule:: torch.utils.collect_env
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. currentmodule:: torch.utils.collect_env
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
|
||||
check_release_file
|
||||
is_xnnpack_available
|
||||
pretty_str
|
||||
```
|
||||
|
||||
# torch.utils.flop_counter
|
||||
```{eval-rst}
|
||||
.. automodule:: torch.utils.flop_counter
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. currentmodule:: torch.utils.flop_counter
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
|
||||
baddbmm_flop
|
||||
bmm_flop
|
||||
conv_backward_flop
|
||||
conv_flop
|
||||
conv_flop_count
|
||||
register_flop_formula
|
||||
sdpa_backward_flop
|
||||
sdpa_backward_flop_count
|
||||
sdpa_flop
|
||||
sdpa_flop_count
|
||||
shape_wrapper
|
||||
```
|
||||
|
||||
# torch.utils.hipify.hipify_python
|
||||
```{eval-rst}
|
||||
.. automodule:: torch.utils.hipify.hipify_python
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. currentmodule:: torch.utils.hipify.hipify_python
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autosummary::
|
||||
:toctree: generated
|
||||
:nosignatures:
|
||||
|
||||
compute_stats
|
||||
extract_arguments
|
||||
file_add_header
|
||||
file_specific_replacement
|
||||
find_bracket_group
|
||||
find_closure_group
|
||||
find_parentheses_group
|
||||
fix_static_global_kernels
|
||||
hip_header_magic
|
||||
hipify
|
||||
is_caffe2_gpu_file
|
||||
is_cusparse_file
|
||||
is_out_of_place
|
||||
is_pytorch_file
|
||||
is_special_file
|
||||
openf
|
||||
preprocess_file_and_save_result
|
||||
preprocessor
|
||||
processKernelLaunches
|
||||
replace_extern_shared
|
||||
replace_math_functions
|
||||
str2bool
|
||||
```
|
||||
|
||||
|
||||
<!-- This module needs to be documented. Adding here in the meantime
|
||||
for tracking purposes -->
|
||||
```{eval-rst}
|
||||
@ -128,6 +43,7 @@ for tracking purposes -->
|
||||
.. py:module:: torch.utils.benchmark.utils.valgrind_wrapper.timer_interface
|
||||
.. py:module:: torch.utils.bundled_inputs
|
||||
.. py:module:: torch.utils.checkpoint
|
||||
.. py:module:: torch.utils.collect_env
|
||||
.. py:module:: torch.utils.cpp_backtrace
|
||||
.. py:module:: torch.utils.cpp_extension
|
||||
.. py:module:: torch.utils.data.backward_compatibility
|
||||
@ -164,8 +80,10 @@ for tracking purposes -->
|
||||
.. py:module:: torch.utils.data.sampler
|
||||
.. py:module:: torch.utils.dlpack
|
||||
.. py:module:: torch.utils.file_baton
|
||||
.. py:module:: torch.utils.flop_counter
|
||||
.. py:module:: torch.utils.hipify.constants
|
||||
.. py:module:: torch.utils.hipify.cuda_to_hip_mappings
|
||||
.. py:module:: torch.utils.hipify.hipify_python
|
||||
.. py:module:: torch.utils.hipify.version
|
||||
.. py:module:: torch.utils.hooks
|
||||
.. py:module:: torch.utils.jit.log_extract
|
||||
|
||||
setup.py (33 changed lines)

@@ -630,37 +630,6 @@ def mirror_files_into_torchgen() -> None:
         raise RuntimeError("Check the file paths in `mirror_files_into_torchgen()`")


-def mirror_inductor_external_kernels() -> None:
-    """
-    Copy external kernels into Inductor so they are importable.
-    """
-    paths = [
-        (
-            CWD / "torch/_inductor/kernel/vendored_templates/cutedsl_grouped_gemm.py",
-            CWD
-            / "third_party/cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py",
-        ),
-    ]
-    for new_path, orig_path in paths:
-        # Create the dirs involved in new_path if they don't exist
-        if not new_path.exists():
-            new_path.parent.mkdir(parents=True, exist_ok=True)
-
-        # Copy the files from the orig location to the new location
-        if orig_path.is_file():
-            shutil.copyfile(orig_path, new_path)
-            continue
-        if orig_path.is_dir():
-            if new_path.exists():
-                # copytree fails if the tree exists already, so remove it.
-                shutil.rmtree(new_path)
-            shutil.copytree(orig_path, new_path)
-            continue
-        raise RuntimeError(
-            "Check the file paths in `mirror_inductor_external_kernels()`"
-        )
-
-
 # ATTENTION: THIS IS AI SLOP
 def extract_variant_from_version(version: str) -> str:
     """Extract variant from version string, defaulting to 'cpu'."""

@@ -1646,7 +1615,6 @@ def main() -> None:
     mirror_files_into_torchgen()
     if RUN_BUILD_DEPS:
         build_deps()
-    mirror_inductor_external_kernels()

     (
         ext_modules,

@@ -1681,7 +1649,6 @@ def main() -> None:
         "_inductor/codegen/aoti_runtime/*.cpp",
         "_inductor/script.ld",
         "_inductor/kernel/flex/templates/*.jinja",
-        "_inductor/kernel/templates/*.jinja",
         "_export/serde/*.yaml",
         "_export/serde/*.thrift",
         "share/cmake/ATen/*.cmake",
@@ -204,16 +204,14 @@ class DistConvolutionOpsTest(DTensorTestBase):
         self.assertTrue(b_dt.grad is not None)
         self.assertTrue(x_dt.grad is None)

-    def _run_single_arg_fwd(
-        self, model, arg, placements=None
-    ) -> tuple[torch.Tensor, torch.Tensor]:
+    def _run_single_arg_fwd(self, model, arg) -> tuple[torch.Tensor, torch.Tensor]:
         """Given model and arg, runs fwd model local and distbuted given device_mesh"""
         device_mesh = self.build_device_mesh()
         model_copy = copy.deepcopy(model).to(device=self.device_type)
         dist_model = distribute_module(model, device_mesh, _conv_fn)
-        arg_dt = DTensor.from_local(arg, device_mesh, placements)
+        arg_dt = DTensor.from_local(arg, device_mesh, [Replicate()])
         out_dt = dist_model(arg_dt.to(device=self.device_type))
-        out = model_copy(arg_dt.full_tensor())
+        out = model_copy(arg)
         return (out_dt.full_tensor(), out)

     @with_comms

@@ -221,20 +219,22 @@ class DistConvolutionOpsTest(DTensorTestBase):
         model = nn.Conv1d(64, 64, 3, padding=1)
         x = torch.randn(1, 64, 8, device=self.device_type)
         out_dt, out = self._run_single_arg_fwd(model, x)
-        self.assertEqual(out_dt, out)
+        self.assertEqual(out_dt.shape, out.shape)

     @with_comms
     def test_conv3d(self):
         model = nn.Conv3d(64, 64, 3, padding=1)
         x = torch.randn(1, 64, 8, 8, 8, device=self.device_type)
-        out_dt, out = self._run_single_arg_fwd(model, x, [Shard(0)])
-        self.assertEqual(out_dt, out)
+        out_dt, out = self._run_single_arg_fwd(model, x)
+        self.assertEqual(out_dt.shape, out.shape)


 DistConvolutionOpsTestWithLocalTensor = create_local_tensor_test_class(
     DistConvolutionOpsTest,
     # Send / recv ops are not supported
     skipped_tests=[
         "test_conv1d",
         "test_conv3d",
         "test_conv_backward_none_grad_inp",
         "test_depthwise_convolution",
         "test_downsampling_convolution",
@ -2,6 +2,7 @@
|
||||
# Owner(s): ["oncall: distributed"]
|
||||
|
||||
import contextlib
|
||||
import copy
|
||||
import itertools
|
||||
import unittest
|
||||
|
||||
@ -21,8 +22,9 @@ from torch.distributed.tensor import (
|
||||
)
|
||||
from torch.distributed.tensor._collective_utils import shard_dim_alltoall
|
||||
from torch.distributed.tensor._dtensor_spec import ShardOrderEntry
|
||||
from torch.distributed.tensor._redistribute import redistribute_local_tensor
|
||||
from torch.distributed.tensor.debug import CommDebugMode
|
||||
from torch.distributed.tensor.placement_types import _StridedShard, MaskPartial
|
||||
from torch.distributed.tensor.placement_types import _StridedShard
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
@ -33,11 +35,7 @@ from torch.testing._internal.common_utils import (
|
||||
from torch.testing._internal.distributed._tensor.common_dtensor import (
|
||||
create_local_tensor_test_class,
|
||||
DTensorTestBase,
|
||||
generate_shard_orders,
|
||||
make_full_tensor,
|
||||
map_local_tensor_for_rank,
|
||||
patched_distribute_tensor as _distribute_tensor,
|
||||
redistribute,
|
||||
with_comms,
|
||||
)
|
||||
from torch.utils._debug_mode import DebugMode
|
||||
@ -787,6 +785,88 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
|
||||
else:
|
||||
return ""
|
||||
|
||||
# TODO(zpcore): remove once the native redistribute supports shard_order arg
|
||||
def redistribute(
|
||||
self,
|
||||
dtensor_input,
|
||||
device_mesh,
|
||||
placements,
|
||||
shard_order,
|
||||
use_graph_based_transform=True,
|
||||
):
|
||||
"""
|
||||
wrapper function to support shard_order for redistribution
|
||||
This is a simpler version of Redistribute, only considers the forward.
|
||||
"""
|
||||
if placements is None:
|
||||
placements = self._shard_order_to_placement(shard_order, device_mesh)
|
||||
placements = tuple(placements)
|
||||
old_spec = dtensor_input._spec
|
||||
new_spec = copy.deepcopy(old_spec)
|
||||
new_spec.placements = placements
|
||||
if shard_order is not None:
|
||||
new_spec.shard_order = shard_order
|
||||
else:
|
||||
new_spec.shard_order = ()
|
||||
if old_spec == new_spec:
|
||||
return dtensor_input
|
||||
dtensor_input = DTensor.from_local(
|
||||
redistribute_local_tensor(
|
||||
dtensor_input.to_local(),
|
||||
old_spec,
|
||||
new_spec,
|
||||
use_graph_based_transform=use_graph_based_transform,
|
||||
),
|
||||
device_mesh,
|
||||
)
|
||||
dtensor_input._spec = copy.deepcopy(new_spec)
|
||||
return dtensor_input # returns DTensor
|
||||
|
||||
# TODO(zpcore): remove once the native distribute_tensor supports
|
||||
# shard_order arg
|
||||
def distribute_tensor(
|
||||
self,
|
||||
input_tensor,
|
||||
device_mesh,
|
||||
placements,
|
||||
shard_order,
|
||||
use_graph_based_transform=True,
|
||||
):
|
||||
"""wrapper function to support shard_order for tensor distribution"""
|
||||
if placements is None:
|
||||
placements = self._shard_order_to_placement(shard_order, device_mesh)
|
||||
placements = tuple(placements)
|
||||
tensor_dt = distribute_tensor(input_tensor, device_mesh, placements)
|
||||
# fix the shard order
|
||||
return self.redistribute(
|
||||
tensor_dt, device_mesh, placements, shard_order, use_graph_based_transform
|
||||
)
|
||||
|
||||
# TODO(zpcore): remove once the native redistribute supports shard_order arg
|
||||
def full_tensor(self, dtensor_input):
|
||||
"""wrapper function to support DTensor.full_tensor"""
|
||||
return self.redistribute(
|
||||
dtensor_input, dtensor_input.device_mesh, placements=None, shard_order=()
|
||||
).to_local()
|
||||
|
||||
def _shard_order_to_placement(self, shard_order, mesh):
|
||||
"""convert shard_order to placement with only Replicate() and Shard()"""
|
||||
placements = [Replicate() for _ in range(mesh.ndim)]
|
||||
if shard_order is not None:
|
||||
for entry in shard_order:
|
||||
tensor_dim = entry.tensor_dim
|
||||
mesh_dims = entry.mesh_dims
|
||||
for mesh_dim in mesh_dims:
|
||||
placements[mesh_dim] = Shard(tensor_dim)
|
||||
return tuple(placements)
|
||||
|
||||
def _convert_shard_order_dict_to_ShardOrder(self, shard_order):
|
||||
"""Convert shard_order dict to ShardOrder"""
|
||||
return tuple(
|
||||
ShardOrderEntry(tensor_dim=tensor_dim, mesh_dims=tuple(mesh_dims))
|
||||
for tensor_dim, mesh_dims in shard_order.items()
|
||||
)
|
||||
|
||||
@with_comms
|
||||
def test_ordered_redistribute(self):
|
||||
"""Test ordered redistribution with various sharding syntaxes"""
|
||||
@ -847,11 +927,13 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
|
||||
for idx, ((src_placement, src_order), (dst_placement, dst_order)) in enumerate(
|
||||
sharding_src_dst_pairs_with_expected_trace
|
||||
):
|
||||
sharded_dt = _distribute_tensor(
|
||||
sharded_dt = self.distribute_tensor(
|
||||
input_data.clone(), mesh, src_placement, shard_order=src_order
|
||||
)
|
||||
with DebugMode(record_torchfunction=False) as debug_mode:
|
||||
sharded_dt = redistribute(sharded_dt, mesh, dst_placement, dst_order)
|
||||
sharded_dt = self.redistribute(
|
||||
sharded_dt, mesh, dst_placement, dst_order
|
||||
)
|
||||
trace_str = self._extract_redistribute_trace_from_debug_mode(
|
||||
debug_mode.debug_string()
|
||||
)
|
||||
@ -875,11 +957,49 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
|
||||
trace_str,
|
||||
"""S(0)[0]S(0)[1]R->S(0)S(1)R->RS(1)R->RS(1)S(0)""",
|
||||
)
|
||||
expected_dt = _distribute_tensor(
|
||||
expected_dt = self.distribute_tensor(
|
||||
input_data.clone(), mesh, dst_placement, shard_order=dst_order
|
||||
)
|
||||
self.assertEqual(sharded_dt.to_local(), expected_dt.to_local())
|
||||
|
||||
    def generate_shard_orders(self, mesh, tensor_rank):
        # Generate all possible sharding placement of tensor with rank
        # `tensor_rank` over mesh.
        def _split_list(lst: list, N: int):
            def compositions(n, k):
                if k == 1:
                    yield [n]
                else:
                    for i in range(1, n - k + 2):
                        for tail in compositions(n - i, k - 1):
                            yield [i] + tail

            length = len(lst)
            for comp in compositions(length, N):
                result = []
                start = 0
                for size in comp:
                    result.append(lst[start : start + size])
                    start += size
                yield result

        all_mesh = list(range(mesh.ndim))
        all_device_order = list(itertools.permutations(all_mesh))
        for device_order in all_device_order:
            # split on device orders, and assign each device order segment to a tensor dim
            for num_split in range(1, mesh.ndim + 1):
                for splitted_list in _split_list(list(range(mesh.ndim)), num_split):
                    for tensor_dims in itertools.combinations(
                        range(tensor_rank), len(splitted_list)
                    ):
                        shard_order = {}
                        assert len(tensor_dims) == len(splitted_list)
                        for tensor_dim, mesh_dims in zip(tensor_dims, splitted_list):
                            shard_order[tensor_dim] = device_order[
                                mesh_dims[0] : mesh_dims[-1] + 1
                            ]
                        yield self._convert_shard_order_dict_to_ShardOrder(shard_order)

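The enumeration above hinges on splitting the ordered list of mesh dims into contiguous non-empty segments. Here is a standalone sketch of just that combinatorial piece, runnable without any distributed setup; the function names mirror the helpers above but are local to this example.

```python
# All ways to write n as an ordered sum of k positive parts, and the matching
# contiguous splits of a list: the core of generate_shard_orders above.
def compositions(n, k):
    if k == 1:
        yield [n]
    else:
        for i in range(1, n - k + 2):
            for tail in compositions(n - i, k - 1):
                yield [i] + tail

def split_list(lst, n_segments):
    for comp in compositions(len(lst), n_segments):
        out, start = [], 0
        for size in comp:
            out.append(lst[start:start + size])
            start += size
        yield out

# Three mesh dims split into two ordered segments:
print(list(split_list([0, 1, 2], 2)))  # [[[0], [1, 2]], [[0, 1], [2]]]
```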
@with_comms
|
||||
def test_generate_shard_orders(self):
|
||||
"""Check if `generate_shard_orders` generates unique sharding combinations"""
|
||||
@ -892,7 +1012,7 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
|
||||
]
|
||||
for test_input in test_inputs:
|
||||
all_combinations = []
|
||||
for shard_order in generate_shard_orders(
|
||||
for shard_order in self.generate_shard_orders(
|
||||
test_input["mesh"], test_input["tensor_rank"]
|
||||
):
|
||||
all_combinations.append(shard_order) # noqa: PERF402
|
||||
@ -942,12 +1062,12 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
|
||||
input_data = torch.randn(tensor_shape, device=self.device_type)
|
||||
tensor_rank = input_data.ndim
|
||||
with maybe_disable_local_tensor_mode():
|
||||
shard_orders = generate_shard_orders(mesh, tensor_rank)
|
||||
shard_orders = self.generate_shard_orders(mesh, tensor_rank)
|
||||
for shard_order in shard_orders:
|
||||
sharded_dt = _distribute_tensor(
|
||||
sharded_dt = self.distribute_tensor(
|
||||
input_data.clone(), mesh, placements=None, shard_order=shard_order
|
||||
)
|
||||
self.assertEqual(make_full_tensor(sharded_dt), input_data)
|
||||
self.assertEqual(self.full_tensor(sharded_dt), input_data)
|
||||
|
||||
# 2. Verify the correctness of redistribution from DTensor to DTensor.
|
||||
# This test repeatedly redistributes a DTensor to various ordered
|
||||
@ -958,20 +1078,20 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
|
||||
tensor_rank = input_data.ndim
|
||||
prev_sharded_dt = None
|
||||
with maybe_disable_local_tensor_mode():
|
||||
shard_orders = generate_shard_orders(mesh, tensor_rank)
|
||||
shard_orders = self.generate_shard_orders(mesh, tensor_rank)
|
||||
for shard_order in shard_orders:
|
||||
if prev_sharded_dt is None:
|
||||
prev_sharded_dt = _distribute_tensor(
|
||||
prev_sharded_dt = self.distribute_tensor(
|
||||
input_data.clone(),
|
||||
mesh,
|
||||
placements=None,
|
||||
shard_order=shard_order,
|
||||
)
|
||||
else:
|
||||
sharded_dt = redistribute(
|
||||
sharded_dt = self.redistribute(
|
||||
prev_sharded_dt, mesh, placements=None, shard_order=shard_order
|
||||
)
|
||||
self.assertEqual(make_full_tensor(sharded_dt), input_data)
|
||||
self.assertEqual(self.full_tensor(sharded_dt), input_data)
|
||||
prev_sharded_dt = sharded_dt
|
||||
|
||||
@with_comms
|
||||
@ -1016,13 +1136,13 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
|
||||
local_tensor = torch.randn(shape, device=self.device_type)
|
||||
full_tensor = DTensor.from_local(local_tensor, mesh, placements)
|
||||
with maybe_disable_local_tensor_mode():
|
||||
shard_orders = generate_shard_orders(mesh, len(shape))
|
||||
shard_orders = self.generate_shard_orders(mesh, len(shape))
|
||||
for shard_order in shard_orders:
|
||||
sharded_dt = redistribute(
|
||||
sharded_dt = self.redistribute(
|
||||
full_tensor, mesh, placements=None, shard_order=shard_order
|
||||
)
|
||||
self.assertEqual(
|
||||
make_full_tensor(sharded_dt), make_full_tensor(full_tensor)
|
||||
self.full_tensor(sharded_dt), self.full_tensor(full_tensor)
|
||||
)
|
||||
|
||||
@unittest.skip(
|
||||
@ -1032,20 +1152,24 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
|
||||
@with_comms
|
||||
def test_ordered_redistribute_for_special_placement(self):
|
||||
"""Test ordered redistribution with special placement"""
|
||||
from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
|
||||
|
||||
torch.manual_seed(21)
|
||||
mesh = init_device_mesh(self.device_type, (8,))
|
||||
input_data = torch.randn((8, 8), device=self.device_type)
|
||||
src_placement = [Shard(1)]
|
||||
tgt_placement = [
|
||||
(MaskPartial(offset_shape=torch.Size([10, 20]), offset_dim=0),)
|
||||
(_MaskPartial(offset_shape=torch.Size([10, 20]), offset_dim=0),)
|
||||
]
|
||||
sharded_dt = _distribute_tensor(
|
||||
sharded_dt = self.distribute_tensor(
|
||||
input_data.clone(),
|
||||
mesh,
|
||||
src_placement,
|
||||
shard_order=(ShardOrderEntry(tensor_dim=1, mesh_dims=(0,)),),
|
||||
)
|
||||
sharded_dt = redistribute(sharded_dt, mesh, tgt_placement, shard_order=None)
|
||||
sharded_dt = self.redistribute(
|
||||
sharded_dt, mesh, tgt_placement, shard_order=None
|
||||
)
|
||||
|
||||
@with_comms
|
||||
def test_shard_order_same_data_as_strided_shard(self):
|
||||
@ -1055,7 +1179,7 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
|
||||
strided_placement = [_StridedShard(-2, split_factor=2), Shard(-2)]
|
||||
x_strided_dt = distribute_tensor(x, device_mesh, strided_placement)
|
||||
# specify right-to-left order use ordered shard
|
||||
x_ordered_dt = _distribute_tensor(
|
||||
x_ordered_dt = self.distribute_tensor(
|
||||
x,
|
||||
device_mesh,
|
||||
placements=[Shard(0), Shard(0)],
|
||||
|
||||
@ -34,10 +34,6 @@ from torch.distributed.tensor.placement_types import (
|
||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
from torch.testing._internal.distributed._tensor.common_dtensor import (
|
||||
DTensorTestBase,
|
||||
generate_shard_orders,
|
||||
LocalDTensorTestBase,
|
||||
patched_distribute_tensor as _distribute_tensor,
|
||||
shard_order_to_placement,
|
||||
with_comms,
|
||||
)
|
||||
|
||||
@ -778,63 +774,6 @@ class TestStridedSharding(DTensorTestBase):
        self.assertEqual(dtensor.full_tensor(), tensor)


class Test_StridedShard_with_shard_order(LocalDTensorTestBase):
    @property
    def world_size(self) -> int:
        return 32

    @with_comms
    def test_StridedShard_to_shard_order(self):
        with LocalTensorMode(ranks=self.world_size):
            mesh = DeviceMesh("cpu", torch.arange(self.world_size).view(2, 2, 2, 2, 2))
            shard_iter = generate_shard_orders(mesh, 3)
            # It takes ~4.8h to complete total 2520 shard order combinations here
            # using LocalTensor. So we only randomly pick 25 shard orders to test.
            all_shard_order = list(shard_iter)
            import random

            random.seed(42)
            shard_order_choices = random.sample(
                all_shard_order, min(25, len(all_shard_order))
            )

            x = torch.randn(32, 32, 32)
            for shard_order in shard_order_choices:
                a = _distribute_tensor(x, mesh, None, shard_order)

                placement_without_stridedshard = shard_order_to_placement(
                    shard_order, mesh
                )
                placements_with_stridedshard = (
                    DTensorSpec._convert_shard_order_to_StridedShard(
                        shard_order, placement_without_stridedshard, mesh
                    )
                )
                b = distribute_tensor(x, mesh, placements_with_stridedshard)
                shard_order_from_stridedshard = (
                    DTensorSpec._maybe_convert_StridedShard_to_shard_order(
                        placements_with_stridedshard, mesh
                    )
                )
                self.assertEqual(shard_order, shard_order_from_stridedshard)
                self.assertEqual(a.to_local(), b.to_local())

    @with_comms
    def test_StridedShard_not_convertible_to_shard_order(self):
        with LocalTensorMode(ranks=self.world_size):
            mesh = DeviceMesh("cpu", torch.arange(self.world_size).view(4, 8))
            unconvertible_placements_list = [
                [_StridedShard(0, split_factor=2), _StridedShard(1, split_factor=2)],
                [_StridedShard(0, split_factor=2), Shard(1)],
                [_StridedShard(1, split_factor=16), Shard(1)],
            ]
            for placements in unconvertible_placements_list:
                shard_order = DTensorSpec._maybe_convert_StridedShard_to_shard_order(
                    tuple(placements), mesh
                )
                self.assertIsNone(shard_order)


class Test2DStridedLocalShard(DTensorTestBase):
    @property
    def world_size(self):

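Editorial note: the Test_StridedShard_with_shard_order test above caps its runtime by sampling 25 of the 2520 possible shard orders with a fixed seed instead of sweeping them all. The same subsampling pattern on a generic combination space, as a standalone sketch (make_cases is a hypothetical stand-in for generate_shard_orders):

import itertools
import random


def make_cases(n_mesh_dims: int, n_tensor_dims: int):
    # Hypothetical stand-in: one case per ordered assignment of mesh dims to tensor dims.
    return itertools.permutations(range(n_mesh_dims), n_tensor_dims)


all_cases = list(make_cases(5, 3))                           # 5P3 = 60 cases here; 2520 in the real test
random.seed(42)                                              # fixed seed keeps CI runs reproducible
sampled = random.sample(all_cases, min(25, len(all_cases)))
for case in sampled:
    pass  # run the expensive round-trip check only on the sampled subset
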
@ -206,10 +206,6 @@ class TestPyCodeCache(TestCase):
            .decode()
            .strip()
        )
        # XPU have extra lines, so get the last line, refer https://github.com/intel/torch-xpu-ops/issues/2261
        if torch.xpu.is_available():
            wrapper_path = wrapper_path.splitlines()[-1]
            hit = hit.splitlines()[-1]
        self.assertEqual(hit, "1")

        with open(wrapper_path) as f:

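Editorial note: the TestPyCodeCache hunk above works around XPU runs printing extra lines to stdout (see the linked torch-xpu-ops issue) by keeping only the last line of each captured value. The trick in isolation, on a hypothetical capture:

captured = "extra XPU banner line\n/tmp/torchinductor/wrapper.py"  # hypothetical subprocess output
wrapper_path = captured.splitlines()[-1]  # keep only the final line, which carries the real value
assert wrapper_path == "/tmp/torchinductor/wrapper.py"
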
@ -1,154 +0,0 @@
# Owner(s): ["module: inductor"]


import unittest

import torch
from torch import Tensor
from torch._inductor import config
from torch._inductor.codegen.cuda.cuda_env import is_datacenter_blackwell_arch
from torch._inductor.test_case import run_tests, TestCase as InductorTestCase
from torch._inductor.utils import ensure_cute_available
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
)


@unittest.skipIf(
    not (ensure_cute_available() and is_datacenter_blackwell_arch()),
    "CuTeDSL library or Blackwell device not available",
)
@instantiate_parametrized_tests
class TestCuTeDSLGroupedGemm(InductorTestCase):
    def _get_inputs(
        self,
        group_size: int,
        M_hint: int,
        K: int,
        N: int,
        device: str,
        dtype: torch.dtype,
        alignment: int = 16,
    ) -> tuple[Tensor, Tensor, Tensor]:
        # --- Random, tile-aligned M sizes ---
        M_sizes = (
            torch.randint(1, (M_hint // alignment) + 1, (group_size,), dtype=torch.int)
            * alignment
        )

        M_total = torch.sum(M_sizes).item()

        # --- Construct input tensors ---
        A = torch.randn(int(M_total), K, dtype=dtype, device=device) * 0.1
        B = torch.randn((group_size, K, N), dtype=dtype, device=device) * 0.01

        # --- Build offsets (no leading zero, strictly increasing) ---
        offsets = torch.cumsum(M_sizes, dim=0).to(dtype=torch.int32, device=device)

        return (A, B, offsets)

    @parametrize("group_size", (2, 8))
    @parametrize("M_hint", (256, 1024))
    @parametrize("K", (64, 128))
    @parametrize("N", (128, 256))
    def test_grouped_gemm_basic(self, group_size: int, M_hint: int, K: int, N: int):
        device = "cuda"
        dtype = torch.bfloat16

        A, B, offsets = self._get_inputs(group_size, M_hint, K, N, device, dtype)

        def grouped_gemm_fn(A_packed, B_batched, offs):
            return torch._grouped_mm(A_packed, B_batched, offs=offs)

        # Eager execution
        c_eager = grouped_gemm_fn(A, B, offsets)

        # Test with Cute backend
        with config.patch(
            {
                "max_autotune": True,
                "max_autotune_gemm_backends": "CUTEDSL",
                "test_configs.autotune_choice_name_regex": "cutedsl",
                "autotune_fallback_to_aten": False,
            }
        ):
            grouped_gemm_compiled = torch.compile(
                grouped_gemm_fn, backend="inductor", dynamic=False
            )
            c_compiled = grouped_gemm_compiled(A, B, offsets)

        self.assertEqual(c_eager.dtype, dtype)
        self.assertEqual(c_compiled.dtype, dtype)
        torch.testing.assert_close(c_eager, c_compiled)

    @parametrize("layout_A", ("contiguous", "offset", "padded", "view"))
    @parametrize("layout_B", ("contiguous", "broadcasted"))
    def test_grouped_gemm_assorted_layouts(
        self,
        layout_A: str,
        layout_B: str,
    ):
        device = "cuda"
        dtype = torch.bfloat16

        G, K, N = 8, 64, 128
        M_sizes = [128] * G
        sum_M = sum(M_sizes)
        offsets = torch.tensor(
            [sum(M_sizes[: i + 1]) for i in range(G)], dtype=torch.int32, device=device
        )

        A_base = torch.randn(sum_M, K, device=device, dtype=dtype)
        A = A_base

        if layout_A == "offset":
            # allocate bigger buffer than needed, use nonzero storage offset
            storage = torch.randn(sum_M * K + 512, device=device, dtype=dtype)
            offset = 128 # skip first 128 elements
            A = torch.as_strided(storage[offset:], (sum_M, K), (K, 1))
        elif layout_A == "padded":
            # simulate row pitch > K (row_stride = K + pad)
            row_pitch = K + 8
            storage = torch.randn(sum_M * row_pitch, device=device, dtype=dtype)
            A = torch.as_strided(storage, (sum_M, K), (row_pitch, 1))
        elif layout_A == "view":
            A_storage = torch.randn(sum_M * K, device=device, dtype=dtype)
            A = A_storage.view(sum_M, K)
            assert A._base is not None
            assert A.shape == (sum_M, K)

        B = torch.randn((G, K, N), dtype=dtype, device=device) * 0.01

        if layout_B == "broadcasted":
            # Broadcast B across groups (zero stride along G)
            B = B[0].expand(G, K, N)
            assert B.stride(0) == 0

        def grouped_gemm_fn(A_packed, B_batched, offs):
            return torch._grouped_mm(A_packed, B_batched, offs=offs)

        # --- eager ---
        c_eager = grouped_gemm_fn(A, B, offsets)

        # --- compiled (CUTE backend) ---
        with config.patch(
            {
                "max_autotune": True,
                "max_autotune_gemm_backends": "CUTEDSL",
                "test_configs.autotune_choice_name_regex": "cutedsl",
                "autotune_fallback_to_aten": False,
            }
        ):
            grouped_gemm_compiled = torch.compile(
                grouped_gemm_fn, backend="inductor", dynamic=False
            )
            c_compiled = grouped_gemm_compiled(A, B, offsets)

        self.assertEqual(c_eager.dtype, dtype)
        self.assertEqual(c_compiled.dtype, dtype)
        torch.testing.assert_close(c_eager, c_compiled)


if __name__ == "__main__":
    run_tests()

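Editorial note: the deleted CuTeDSL grouped-GEMM test above encodes group boundaries as a strictly increasing cumulative sum with no leading zero, so group i owns rows offsets[i-1]:offsets[i] of the packed A matrix. A minimal sketch of that offsets convention in plain torch (hypothetical sizes, no CuTeDSL or _grouped_mm involved):

import torch

M_sizes = torch.tensor([128, 64, 192])                   # per-group row counts (hypothetical)
offsets = torch.cumsum(M_sizes, dim=0).to(torch.int32)   # tensor([128, 192, 384]); no leading zero
A = torch.randn(int(M_sizes.sum()), 16)                  # groups packed along dim 0

start = 0
for i, end in enumerate(offsets.tolist()):
    group_rows = A[start:end]                            # rows owned by group i
    assert group_rows.shape[0] == int(M_sizes[i])
    start = end
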
@ -482,8 +482,8 @@ class TestExecutionTrace(TestCase):

    @unittest.skipIf(IS_WINDOWS, "torch.compile does not support WINDOWS")
    @unittest.skipIf(
        (not has_triton()) or (not TEST_CUDA),
        "need triton and device CUDA availability to run",
        (not has_triton()) or (not TEST_CUDA and not TEST_XPU),
        "need triton and device(CUDA or XPU) availability to run",
    )
    @skipCPUIf(True, "skip CPU device for testing profiling triton")
    def test_triton_fx_graph_with_et(self, device):

@ -2005,10 +2005,6 @@ assert KinetoStepTracker.current_step() == initial_step + 2 * niters
            report = json.load(f)
            self._validate_basic_json(report["traceEvents"], with_cuda)

    @unittest.skipIf(
        torch.xpu.is_available(),
        "XPU Trace event ends too late! Refer https://github.com/intel/torch-xpu-ops/issues/2263",
    )
    @unittest.skipIf(not kineto_available(), "Kineto is required")
    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
    def test_basic_chrome_trace(self):
@ -2162,10 +2158,7 @@ assert KinetoStepTracker.current_step() == initial_step + 2 * niters
    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
    def test_basic_profile(self):
        # test a really basic profile to make sure no erroneous aten ops are run
        acc = torch.accelerator.current_accelerator()
        self.assertIsNotNone(acc)
        device = acc.type
        x = torch.randn(4, device=device)
        x = torch.randn(4, device="cuda")
        with torch.profiler.profile(with_stack=True) as p:
            x *= 2
        names = [e.name for e in p.events()]
@ -2232,7 +2225,6 @@ assert KinetoStepTracker.current_step() == initial_step + 2 * niters
    @unittest.skipIf(
        torch.cuda.is_available(), "CUDA complains about forking after init"
    )
    @unittest.skipIf(torch.xpu.is_available(), "XPU complains about forking after init")
    @unittest.skipIf(IS_WINDOWS, "can't use os.fork() on Windows")
    def test_forked_process(self):
        # Induce a pid cache by running the profiler with payload

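Editorial note: the test_basic_profile hunk above swaps a hard-coded device="cuda" for a lookup through torch.accelerator, so the same test body runs on CUDA and XPU workers. A minimal sketch of that device-agnostic pattern, assuming a PyTorch build new enough to ship the torch.accelerator API:

import torch

acc = torch.accelerator.current_accelerator()    # torch.device of the active backend, or None
device = acc.type if acc is not None else "cpu"  # fall back on CPU-only machines

x = torch.randn(4, device=device)
with torch.profiler.profile(with_stack=True) as p:
    x *= 2
print([e.name for e in p.events()][:5])
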
@ -263,7 +263,13 @@ S390X_BLOCKLIST = [

XPU_BLOCKLIST = [
    "test_autograd",
    "profiler/test_cpp_thread",
    "profiler/test_execution_trace",
    "profiler/test_memory_profiler",
    "profiler/test_profiler",
    "profiler/test_profiler_tree",
    "profiler/test_record_function",
    "profiler/test_torch_tidy",
    "test_openreg",
]


@ -1,236 +1,245 @@
|
||||
{
|
||||
"EndToEndLSTM (__main__.RNNTest)": 190.48799641927084,
|
||||
"MultiheadAttention (__main__.ModulesTest)": 141.2663370768229,
|
||||
"test__adaptive_avg_pool2d (__main__.CPUReproTests)": 82.87333234151204,
|
||||
"test_after_aot_cpu_runtime_error (__main__.MinifierIsolateTests)": 70.6538565499442,
|
||||
"test_aot_autograd_disable_functionalization_symbolic_exhaustive_nn_functional_max_pool1d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 123.34033711751302,
|
||||
"test_aot_autograd_disable_functionalization_symbolic_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 171.25450134277344,
|
||||
"test_aot_autograd_disable_functionalization_symbolic_exhaustive_nn_functional_max_pool3d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 119.71899922688802,
|
||||
"test_aot_autograd_disable_functionalization_symbolic_exhaustive_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 69.35733322870163,
|
||||
"test_aot_autograd_symbolic_exhaustive_linalg_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 63.64533233642578,
|
||||
"test_aot_autograd_symbolic_exhaustive_masked_norm_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 63.672952016194664,
|
||||
"test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool1d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 138.04000091552734,
|
||||
"test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 172.1344985961914,
|
||||
"test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool3d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 114.02050018310547,
|
||||
"test_aot_autograd_symbolic_exhaustive_ormqr_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 67.25642830984933,
|
||||
"test_aot_autograd_symbolic_exhaustive_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 65.3350003560384,
|
||||
"test_aot_autograd_symbolic_module_exhaustive_nn_TransformerDecoderLayer_cpu_float32 (__main__.TestEagerFusionModuleInfoCPU)": 120.95249938964844,
|
||||
"test_associative_scan_partial_grad_combine_mode_generic_compile_mode_compile_dynamic_shape_reverse_False_cpu (__main__.AssociativeScanTests)": 86.97774887084961,
|
||||
"test_associative_scan_partial_grad_combine_mode_generic_compile_mode_compile_dynamic_shape_reverse_True_cpu (__main__.AssociativeScanTests)": 100.90774917602539,
|
||||
"test_avg_pool3d_backward2_cpu (__main__.CpuTests)": 1144.3935089111328,
|
||||
"test_avg_pool3d_backward2_cuda (__main__.GPUTests)": 222.58500061035156,
|
||||
"test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 501.10033162434894,
|
||||
"test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 517.1875050862631,
|
||||
"test_avg_pool3d_backward2_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 113.88125228881836,
|
||||
"test_avg_pool3d_backward2_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 235.77350616455078,
|
||||
"test_backward_nn_functional_multi_head_attention_forward_cpu_float32 (__main__.TestCompositeComplianceCPU)": 74.6155014038086,
|
||||
"test_backward_nn_functional_multi_head_attention_forward_cuda_float32 (__main__.TestCompositeComplianceCUDA)": 66.63325119018555,
|
||||
"test_basic_cpu (__main__.EfficientConvBNEvalCpuTests)": 216.2968317667643,
|
||||
"test_basic_cuda (__main__.EfficientConvBNEvalGpuTests)": 153.0915012359619,
|
||||
"test_cat_2k_args (__main__.TestTEFuserDynamic)": 108.80471753561869,
|
||||
"test_cat_2k_args (__main__.TestTEFuserStatic)": 102.20949847949669,
|
||||
"test_checkpointing_without_reentrant_input_requires_grad_False (__main__.TestAutogradWithCompiledAutograd)": 311.7026621500651,
|
||||
"test_checkpointing_without_reentrant_input_requires_grad_True (__main__.TestAutogradWithCompiledAutograd)": 395.0001729329427,
|
||||
"test_collect_callgrind (__main__.TestBenchmarkUtils)": 348.6218566894531,
|
||||
"test_comprehensive_diff_cuda_complex128 (__main__.TestDecompCUDA)": 98.71574974060059,
|
||||
"test_comprehensive_diff_cuda_complex64 (__main__.TestDecompCUDA)": 97.68499946594238,
|
||||
"test_comprehensive_diff_cuda_float32 (__main__.TestDecompCUDA)": 65.0557508468628,
|
||||
"test_comprehensive_diff_cuda_float64 (__main__.TestDecompCUDA)": 65.86899948120117,
|
||||
"test_comprehensive_gradient_cuda_complex64 (__main__.TestDecompCUDA)": 97.15880012512207,
|
||||
"test_comprehensive_grid_sampler_2d_cpu_bfloat16 (__main__.TestDecompCPU)": 103.20700073242188,
|
||||
"test_comprehensive_grid_sampler_2d_cpu_float16 (__main__.TestDecompCPU)": 102.74033610026042,
|
||||
"test_comprehensive_grid_sampler_2d_cpu_float32 (__main__.TestDecompCPU)": 460.4286702473958,
|
||||
"test_comprehensive_grid_sampler_2d_cpu_float64 (__main__.TestDecompCPU)": 435.62066650390625,
|
||||
"test_comprehensive_grid_sampler_2d_cuda_bfloat16 (__main__.TestDecompCUDA)": 287.3090057373047,
|
||||
"test_comprehensive_grid_sampler_2d_cuda_float16 (__main__.TestDecompCUDA)": 265.1860008239746,
|
||||
"test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestDecompCUDA)": 1235.7365112304688,
|
||||
"test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 68.20825004577637,
|
||||
"test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestDecompCUDA)": 1281.2615051269531,
|
||||
"test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 71.90750026702881,
|
||||
"test_comprehensive_linalg_householder_product_cuda_complex64 (__main__.TestDecompCUDA)": 79.04633331298828,
|
||||
"test_comprehensive_linalg_lu_factor_ex_cuda_complex128 (__main__.TestDecompCUDA)": 68.10879821777344,
|
||||
"test_comprehensive_linalg_lu_solve_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 71.43025207519531,
|
||||
"test_comprehensive_linalg_lu_solve_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 68.94575023651123,
|
||||
"test_comprehensive_linalg_solve_triangular_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 72.93649864196777,
|
||||
"test_comprehensive_linalg_solve_triangular_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 72.46275043487549,
|
||||
"test_comprehensive_linalg_svd_cuda_complex128 (__main__.TestDecompCUDA)": 64.10650062561035,
|
||||
"test_comprehensive_linalg_svd_cuda_complex64 (__main__.TestDecompCUDA)": 67.03124904632568,
|
||||
"test_comprehensive_linalg_svd_cuda_float64 (__main__.TestDecompCUDA)": 64.32800025939942,
|
||||
"test_comprehensive_linalg_vector_norm_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 96.41353665865384,
|
||||
"test_comprehensive_linalg_vector_norm_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 100.17661388103778,
|
||||
"test_comprehensive_masked_norm_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 110.95025062561035,
|
||||
"test_comprehensive_masked_norm_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 108.06550025939941,
|
||||
"test_comprehensive_masked_norm_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 104.24150085449219,
|
||||
"test_comprehensive_nn_functional_conv_transpose3d_cuda_complex128 (__main__.TestDecompCUDA)": 63.453749656677246,
|
||||
"test_comprehensive_nn_functional_conv_transpose3d_cuda_complex64 (__main__.TestDecompCUDA)": 61.739999771118164,
|
||||
"test_comprehensive_nn_functional_gaussian_nll_loss_cpu_float32 (__main__.TestDecompCPU)": 69.96549987792969,
|
||||
"test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestDecompCUDA)": 113.65749931335449,
|
||||
"test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float64 (__main__.TestDecompCUDA)": 106.57500076293945,
|
||||
"test_comprehensive_nn_functional_grid_sample_cpu_float32 (__main__.TestDecompCPU)": 117.54049682617188,
|
||||
"test_comprehensive_nn_functional_grid_sample_cpu_float64 (__main__.TestDecompCPU)": 116.19766489664714,
|
||||
"test_comprehensive_nn_functional_grid_sample_cuda_float32 (__main__.TestDecompCUDA)": 272.48475646972656,
|
||||
"test_comprehensive_nn_functional_grid_sample_cuda_float64 (__main__.TestDecompCUDA)": 248.12175369262695,
|
||||
"test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 (__main__.TestDecompCUDA)": 79.66900062561035,
|
||||
"test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 81.52649879455566,
|
||||
"test_comprehensive_nn_functional_interpolate_bicubic_cuda_float64 (__main__.TestDecompCUDA)": 79.29400062561035,
|
||||
"test_comprehensive_nn_functional_interpolate_bicubic_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 82.40349960327148,
|
||||
"test_comprehensive_nn_functional_interpolate_trilinear_cuda_float32 (__main__.TestDecompCUDA)": 128.42924880981445,
|
||||
"test_comprehensive_nn_functional_interpolate_trilinear_cuda_float64 (__main__.TestDecompCUDA)": 125.03675079345703,
|
||||
"test_comprehensive_nn_functional_max_pool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 1264.9732360839844,
|
||||
"test_comprehensive_nn_functional_max_pool2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 1250.7332458496094,
|
||||
"test_comprehensive_nn_functional_max_pool2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 1255.0684814453125,
|
||||
"test_comprehensive_nn_functional_max_pool3d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 574.4627532958984,
|
||||
"test_comprehensive_nn_functional_max_pool3d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 581.7282485961914,
|
||||
"test_comprehensive_nn_functional_max_unpool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 65.052001953125,
|
||||
"test_comprehensive_nn_functional_max_unpool2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 61.19200134277344,
|
||||
"test_comprehensive_nn_functional_max_unpool2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 63.16874885559082,
|
||||
"test_comprehensive_ormqr_cpu_complex64 (__main__.TestDecompCPU)": 62.39250183105469,
|
||||
"test_comprehensive_ormqr_cuda_complex128 (__main__.TestDecompCUDA)": 113.32574844360352,
|
||||
"test_comprehensive_ormqr_cuda_complex64 (__main__.TestDecompCUDA)": 113.91499900817871,
|
||||
"test_comprehensive_ormqr_cuda_float32 (__main__.TestDecompCUDA)": 74.42549800872803,
|
||||
"test_comprehensive_ormqr_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 76.1560001373291,
|
||||
"test_comprehensive_ormqr_cuda_float64 (__main__.TestDecompCUDA)": 66.76750087738037,
|
||||
"test_comprehensive_svd_cuda_complex128 (__main__.TestDecompCUDA)": 70.69724941253662,
|
||||
"test_comprehensive_svd_cuda_complex64 (__main__.TestDecompCUDA)": 69.87625026702881,
|
||||
"test_constructor_autograd_SparseBSC_cuda (__main__.TestSparseAnyCUDA)": 80.2542495727539,
|
||||
"test_constructor_autograd_SparseBSR_cuda (__main__.TestSparseAnyCUDA)": 69.0419979095459,
|
||||
"test_conv1d_basic (__main__.TestXNNPACKConv1dTransformPass)": 117.03342655726841,
|
||||
"test_conv1d_with_relu_fc (__main__.TestXNNPACKConv1dTransformPass)": 289.50213841029574,
|
||||
"test_conv2d_binary_broadcast_shapes_cpu (__main__.TestPatternMatcherGenericCPU)": 67.38800048828125,
|
||||
"test_conv3d_binary_broadcast_shapes_cpu (__main__.TestPatternMatcherGenericCPU)": 145.27399444580078,
|
||||
"test_conv3d_binary_dynamic_shapes_cpu (__main__.TestDynamicPatternMatcherGenericCPU)": 66.9245999654134,
|
||||
"test_conv3d_cuda (__main__.AOTInductorTestABICompatibleGpu)": 151.91099548339844,
|
||||
"test_conv_bn_fuse_cpu (__main__.CpuTests)": 92.79549789428711,
|
||||
"test_conv_bn_fuse_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 64.60149955749512,
|
||||
"test_conv_transpose_with_output_size_and_no_batch_dim_ConvTranspose3d_cuda (__main__.TestConvolutionNNDeviceTypeCUDA)": 69.27724676392972,
|
||||
"test_conv_unary_fusion_nnc (__main__.TestMkldnnFusion)": 76.24971498761859,
|
||||
"test_correctness_AdamW_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 81.93449974060059,
|
||||
"test_correctness_Adam_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 78.87700080871582,
|
||||
"test_count_nonzero_all (__main__.TestBool)": 631.2585144042969,
|
||||
"test_diff_hyperparams_sharding_strategy_str_full_shard (__main__.TestFSDPUseOrigParamsMultipleParamGroups)": 61.042999267578125,
|
||||
"test_dispatch_symbolic_meta_outplace_all_strides_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestMetaCUDA)": 84.49850082397461,
|
||||
"test_dtensor_op_db_nn_functional_poisson_nll_loss_cpu_float32 (__main__.TestLocalDTensorOpsCPU)": 93.03299713134766,
|
||||
"test_eager_sequence_nr_dynamic_shapes (__main__.DynamicShapesAotAutogradFallbackTests)": 228.46711820714614,
|
||||
"test_eig_check_magma_cuda_float32 (__main__.TestLinalgCUDA)": 286.29998779296875,
|
||||
"test_fail_arithmetic_ops.py (__main__.TestTyping)": 68.43842806134906,
|
||||
"test_fail_random.py (__main__.TestTyping)": 74.83523060725285,
|
||||
"test_fn_fwgrad_bwgrad_cumprod_cuda_complex128 (__main__.TestFwdGradientsCUDA)": 72.84900093078613,
|
||||
"test_fn_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 75.86675071716309,
|
||||
"test_fuse_large_params_cpu (__main__.CpuTests)": 151.4199981689453,
|
||||
"test_fuse_large_params_cuda (__main__.GPUTests)": 60.351999282836914,
|
||||
"test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 158.3622828892299,
|
||||
"test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 149.6796646118164,
|
||||
"test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 139.97800064086914,
|
||||
"test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 114.8385009765625,
|
||||
"test_grad_nn_Transformer_cpu_float64 (__main__.TestModuleCPU)": 84.69736822027909,
|
||||
"test_grad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 84.62700080871582,
|
||||
"test_gradgrad_nn_LSTM_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 89.197998046875,
|
||||
"test_gradgrad_nn_LSTM_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 96.46900177001953,
|
||||
"test_gradgrad_nn_TransformerDecoderLayer_cuda_float64 (__main__.TestModuleCUDA)": 187.83824920654297,
|
||||
"test_gradgrad_nn_TransformerEncoder_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 110.49449920654297,
|
||||
"test_gradgrad_nn_TransformerEncoder_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 124.90424919128418,
|
||||
"test_gradgrad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 518.4157485961914,
|
||||
"test_indirect_device_assert (__main__.TritonCodeGenTests)": 304.6440022786458,
|
||||
"test_inductor_dynamic_shapes_broadcasting_dynamic_shapes (__main__.DynamicShapesReproTests)": 143.82052836698645,
|
||||
"test_inductor_no_recursionerror_on_for_loops_dynamic_shapes (__main__.DynamicShapesReproTests)": 77.4985705784389,
|
||||
"test_inplace_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 76.06225109100342,
|
||||
"test_inputs_overlapping_with_mutation_stress_dynamic_shapes (__main__.DynamicShapesAotAutogradFallbackTests)": 138.9222858973912,
|
||||
"test_jit_cuda_archflags (__main__.TestCppExtensionJIT)": 120.62233225504558,
|
||||
"test_linalg_solve_triangular_large_cuda_complex128 (__main__.TestLinalgCUDA)": 148.1219940185547,
|
||||
"test_linalg_solve_triangular_large_cuda_complex64 (__main__.TestLinalgCUDA)": 109.34200286865234,
|
||||
"test_linear_binary_cpp_wrapper (__main__.TestCppWrapper)": 119.36233266194661,
|
||||
"test_linear_binary_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 127.95700073242188,
|
||||
"test_list_clearing_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 61.64850175380707,
|
||||
"test_longformer_chunk_dynamic_shapes (__main__.DynamicShapesReproTests)": 105.3174296787807,
|
||||
"test_low_memory_max_pool_dilation_1_dim_3_cpu_halide (__main__.HalideCpuTests)": 585.9210001627604,
|
||||
"test_low_memory_max_pool_dilation_2_dim_3_cpu_halide (__main__.HalideCpuTests)": 504.3250020345052,
|
||||
"test_lstm_cpu (__main__.TestMkldnnCPU)": 86.21566645304362,
|
||||
"test_many_overlapping_inputs_does_not_explode_guards_dynamic_shapes (__main__.DynamicShapesReproTests)": 129.277715410505,
|
||||
"test_max_autotune_addmm_max_autotune_gemm_backends_CK_x_shape2 (__main__.TestCKBackend)": 64.24800109863281,
|
||||
"test_max_autotune_precompile_matmul_max_autotune_gemm_backends_CKTILE_autotune_in_subproc_False_use_aoti_False (__main__.TestCKBackend)": 77.23899841308594,
|
||||
"test_max_autotune_precompile_matmul_max_autotune_gemm_backends_CKTILE_autotune_in_subproc_False_use_aoti_True (__main__.TestCKBackend)": 65.15649795532227,
|
||||
"test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 62.579833984375,
|
||||
"test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 64.6555004119873,
|
||||
"test_pattern_matcher_multi_user_cpu (__main__.CpuTritonTests)": 142.21566772460938,
|
||||
"test_proper_exit (__main__.TestDataLoader)": 267.74214717320035,
|
||||
"test_proper_exit (__main__.TestDataLoaderPersistentWorkers)": 266.6539971487863,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 101.97100067138672,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 97.3346659342448,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True (__main__.TestPatternMatcher)": 81.50300216674805,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 104.61333465576172,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 99.41133371988933,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False (__main__.TestPatternMatcher)": 73.37100219726562,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 95.30900065104167,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 96.61750030517578,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True (__main__.TestPatternMatcher)": 79.33600234985352,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 101.2393315633138,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 103.18400192260742,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False (__main__.TestPatternMatcher)": 75.4114990234375,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 96.52833302815755,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 99.72700119018555,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 100.61966705322266,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 102.2750015258789,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 95.17449951171875,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 97.96749877929688,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 106.44049835205078,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 101.7173334757487,
|
||||
"test_quick_core_backward__unsafe_masked_index_cpu_float64 (__main__.TestDecompCPU)": 531.5236612955729,
|
||||
"test_quick_core_backward__unsafe_masked_index_cuda_float64 (__main__.TestDecompCUDA)": 1077.4210205078125,
|
||||
"test_quick_core_backward__unsafe_masked_index_put_accumulate_cpu_float64 (__main__.TestDecompCPU)": 812.0880126953125,
|
||||
"test_quick_core_backward__unsafe_masked_index_put_accumulate_cuda_float64 (__main__.TestDecompCUDA)": 1347.9365234375,
|
||||
"test_quick_core_backward_nn_functional_max_unpool3d_grad_cpu_float64 (__main__.TestDecompCPU)": 88.93533070882161,
|
||||
"test_quick_core_backward_nn_functional_max_unpool3d_grad_cuda_float64 (__main__.TestDecompCUDA)": 269.01949310302734,
|
||||
"test_quick_core_backward_roll_cpu_float64 (__main__.TestDecompCPU)": 131.99799601236978,
|
||||
"test_quick_core_backward_roll_cuda_float64 (__main__.TestDecompCUDA)": 232.36275100708008,
|
||||
"test_quick_core_backward_select_scatter_cpu_float64 (__main__.TestDecompCPU)": 69.80400085449219,
|
||||
"test_quick_core_backward_select_scatter_cuda_float64 (__main__.TestDecompCUDA)": 134.3415012359619,
|
||||
"test_quick_core_backward_split_cuda_float64 (__main__.TestDecompCUDA)": 67.51749992370605,
|
||||
"test_quick_core_backward_split_with_sizes_copy_cpu_float64 (__main__.TestDecompCPU)": 91.21066792805989,
|
||||
"test_quick_core_backward_split_with_sizes_copy_cuda_float64 (__main__.TestDecompCUDA)": 170.97775268554688,
|
||||
"test_quick_core_backward_std_cpu_float64 (__main__.TestDecompCPU)": 61.608266321818036,
|
||||
"test_quick_core_backward_std_cuda_float64 (__main__.TestDecompCUDA)": 110.62575149536133,
|
||||
"test_register_spills_cuda (__main__.BenchmarkFusionGpuTest)": 63.59499969482422,
|
||||
"test_replicatepad_64bit_indexing_cuda_float16 (__main__.TestNNDeviceTypeCUDA)": 88.68299865722656,
|
||||
"test_rnn_decomp_module_nn_LSTM_train_mode_cuda_float32 (__main__.TestDecompCUDA)": 91.50320053100586,
|
||||
"test_runtime_checks_large_cpu (__main__.AOTInductorTestABICompatibleCpu)": 66.10774898529053,
|
||||
"test_runtime_checks_large_cpu_with_stack_allocation (__main__.AOTInductorTestABICompatibleCpuWithStackAllocation)": 66.20533180236816,
|
||||
"test_runtime_checks_large_cuda (__main__.AOTInductorTestABICompatibleGpu)": 243.1092529296875,
|
||||
"test_save_load_large_string_attribute (__main__.TestSaveLoad)": 105.01200103759766,
|
||||
"test_sdpa_kernel_ctx_manager2_dynamic_shapes (__main__.DynamicShapesCtxManagerTests)": 107.93685695103237,
|
||||
"test_shuffler_iterdatapipe (__main__.IntegrationTestDataLoaderDataPipe)": 142.38899993896484,
|
||||
"test_slow_tasks (__main__.TestFunctionalAutogradBenchmark)": 119.90166600545247,
|
||||
"test_sort_bool_cpu (__main__.CpuTritonTests)": 346.2856750488281,
|
||||
"test_sort_dynamic_shape_with_check_cuda (__main__.TestInductorDynamicCUDA)": 423.09974098205566,
|
||||
"test_sort_stable_cuda (__main__.GPUTests)": 117.61659927368164,
|
||||
"test_sort_transpose_cpu (__main__.CpuTritonTests)": 378.31200154622394,
|
||||
"test_svd_lowrank_cuda_complex128 (__main__.TestLinalgCUDA)": 222.822007894516,
|
||||
"test_terminate_handler_on_crash (__main__.TestTorch)": 143.31728431156702,
|
||||
"test_terminate_signal (__main__.ForkTest)": 168.20485967184817,
|
||||
"test_terminate_signal (__main__.ParallelForkServerShouldWorkTest)": 168.19242484867573,
|
||||
"test_terminate_signal (__main__.SpawnTest)": 172.16428443363733,
|
||||
"test_thnn_conv_strided_padded_dilated (__main__.TestConvolutionNN)": 93.30639710426331,
|
||||
"test_train_parity_multi_group (__main__.TestFullyShard1DTrainingCore)": 163.89743041992188,
|
||||
"test_train_parity_with_activation_checkpointing (__main__.TestFullyShard1DTrainingCompose)": 60.47671399797712,
|
||||
"test_triton_bsr_scatter_mm_blocksize_64_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 63.39550018310547,
|
||||
"test_triton_bsr_softmax_cuda_bfloat16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 173.53924942016602,
|
||||
"test_triton_bsr_softmax_cuda_float16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 175.3212537765503,
|
||||
"test_triton_bsr_softmax_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 122.20649909973145,
|
||||
"test_variant_consistency_jit_nn_functional_max_pool2d_cpu_float32 (__main__.TestJitCPU)": 99.9885025024414,
|
||||
"test_variant_consistency_jit_nn_functional_max_pool2d_cuda_float32 (__main__.TestJitCUDA)": 71.64024829864502,
|
||||
"test_view_ops (__main__.TestViewOpsWithLocalTensor)": 73.45887422561646,
|
||||
"test_vmapjvpvjp_linalg_lstsq_grad_oriented_cpu_float32 (__main__.TestOperatorsCPU)": 95.75249862670898,
|
||||
"test_vmapjvpvjp_linalg_lstsq_grad_oriented_cuda_float32 (__main__.TestOperatorsCUDA)": 61.858001708984375,
|
||||
"test_vmapjvpvjp_linalg_lu_solve_cpu_float32 (__main__.TestOperatorsCPU)": 65.11023766653878,
|
||||
"test_vmapjvpvjp_linalg_lu_solve_cuda_float32 (__main__.TestOperatorsCUDA)": 66.35274982452393,
|
||||
"test_vmapjvpvjp_linalg_svd_cuda_float32 (__main__.TestOperatorsCUDA)": 61.196499824523926,
|
||||
"test_vmapjvpvjp_max_pool2d_with_indices_backward_cpu_float32 (__main__.TestOperatorsCPU)": 73.75380906604585,
|
||||
"test_vmapjvpvjp_max_pool2d_with_indices_backward_cuda_float32 (__main__.TestOperatorsCUDA)": 73.64649868011475,
|
||||
"test_vmapjvpvjp_nn_functional_max_pool2d_cpu_float32 (__main__.TestOperatorsCPU)": 75.09799966358003,
|
||||
"test_vmapjvpvjp_nn_functional_max_pool2d_cuda_float32 (__main__.TestOperatorsCUDA)": 70.51450157165527,
|
||||
"test_vmapjvpvjp_unbind_cpu_float32 (__main__.TestOperatorsCPU)": 66.21433276221866,
|
||||
"test_vmapjvpvjp_unbind_cuda_float32 (__main__.TestOperatorsCUDA)": 73.20024871826172,
|
||||
"test_vmapvjpvjp_linalg_lstsq_cuda_float32 (__main__.TestOperatorsCUDA)": 88.1349983215332,
|
||||
"test_vmapvjpvjp_meshgrid_list_of_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 76.89924907684326,
|
||||
"test_vmapvjpvjp_meshgrid_variadic_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 77.32975196838379,
|
||||
"test_vmapvjpvjp_nn_functional_bilinear_cuda_float32 (__main__.TestOperatorsCUDA)": 120.09600067138672
|
||||
"EndToEndLSTM (__main__.RNNTest)": 207.89400227864584,
|
||||
"MultiheadAttention (__main__.ModulesTest)": 141.1396687825521,
|
||||
"test_AllenaiLongformerBase_repro_cpu_halide (__main__.HalideCpuTests)": 214.02366638183594,
|
||||
"test__adaptive_avg_pool2d (__main__.CPUReproTests)": 77.26125049591064,
|
||||
"test_adaptive_max_pool2d1_cpu_halide (__main__.HalideCpuTests)": 116.37000020345052,
|
||||
"test_after_aot_cpu_runtime_error (__main__.MinifierIsolateTests)": 69.25722334120009,
|
||||
"test_after_aot_gpu_runtime_error (__main__.MinifierIsolateTests)": 65.84466807047527,
|
||||
"test_alexnet_prefix_cpu_halide (__main__.HalideCpuTests)": 178.41399637858072,
|
||||
"test_aot_autograd_disable_functionalization_symbolic_exhaustive_linalg_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 63.55014337812151,
|
||||
"test_aot_autograd_disable_functionalization_symbolic_exhaustive_nn_functional_max_pool1d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 122.18047623407273,
|
||||
"test_aot_autograd_disable_functionalization_symbolic_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 192.6405719575428,
|
||||
"test_aot_autograd_disable_functionalization_symbolic_exhaustive_nn_functional_max_pool3d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 111.27904801141648,
|
||||
"test_aot_autograd_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 60.906999588012695,
|
||||
"test_aot_autograd_symbolic_exhaustive_linalg_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 62.244998931884766,
|
||||
"test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool1d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 150.04100036621094,
|
||||
"test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 191.85050201416016,
|
||||
"test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool3d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 111.9276631673177,
|
||||
"test_aot_autograd_symbolic_exhaustive_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 67.31450271606445,
|
||||
"test_aot_autograd_symbolic_module_exhaustive_nn_TransformerDecoderLayer_cpu_float32 (__main__.TestEagerFusionModuleInfoCPU)": 125.24066416422527,
|
||||
"test_associative_scan_partial_grad_combine_mode_generic_compile_mode_compile_dynamic_shape_reverse_False_cpu (__main__.AssociativeScanTests)": 86.47783279418945,
|
||||
"test_associative_scan_partial_grad_combine_mode_generic_compile_mode_compile_dynamic_shape_reverse_True_cpu (__main__.AssociativeScanTests)": 100.46250025431316,
|
||||
"test_avg_pool3d_backward2_cpu (__main__.CpuTests)": 1031.0534973144531,
|
||||
"test_avg_pool3d_backward2_cuda (__main__.GPUTests)": 239.67400105794272,
|
||||
"test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 495.0447726779514,
|
||||
"test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 490.18524169921875,
|
||||
"test_avg_pool3d_backward2_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 144.06477737426758,
|
||||
"test_avg_pool3d_backward2_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 342.20416259765625,
|
||||
"test_avg_pool3d_backward_cpu_halide (__main__.HalideCpuTests)": 62.01366678873698,
|
||||
"test_backward_nn_functional_multi_head_attention_forward_cpu_float32 (__main__.TestCompositeComplianceCPU)": 71.07200050354004,
|
||||
"test_backward_nn_functional_multi_head_attention_forward_cuda_float32 (__main__.TestCompositeComplianceCUDA)": 73.9221674601237,
|
||||
"test_basic_cpu (__main__.EfficientConvBNEvalCpuTests)": 226.0122528076172,
|
||||
"test_basic_cuda (__main__.EfficientConvBNEvalGpuTests)": 144.97249857584634,
|
||||
"test_checkpointing_without_reentrant_input_requires_grad_False (__main__.TestAutogradWithCompiledAutograd)": 303.20537185668945,
|
||||
"test_checkpointing_without_reentrant_input_requires_grad_True (__main__.TestAutogradWithCompiledAutograd)": 386.0518798828125,
|
||||
"test_collect_callgrind (__main__.TestBenchmarkUtils)": 291.2442270914714,
|
||||
"test_comprehensive_diff_cuda_complex128 (__main__.TestDecompCUDA)": 95.87866719563802,
|
||||
"test_comprehensive_diff_cuda_complex64 (__main__.TestDecompCUDA)": 98.38716634114583,
|
||||
"test_comprehensive_diff_cuda_float32 (__main__.TestDecompCUDA)": 69.08016649881999,
|
||||
"test_comprehensive_diff_cuda_float64 (__main__.TestDecompCUDA)": 69.88233311971028,
|
||||
"test_comprehensive_grid_sampler_2d_cpu_bfloat16 (__main__.TestDecompCPU)": 104.17599995930989,
|
||||
"test_comprehensive_grid_sampler_2d_cpu_float16 (__main__.TestDecompCPU)": 97.41800308227539,
|
||||
"test_comprehensive_grid_sampler_2d_cpu_float32 (__main__.TestDecompCPU)": 474.6719970703125,
|
||||
"test_comprehensive_grid_sampler_2d_cpu_float64 (__main__.TestDecompCPU)": 440.4375,
|
||||
"test_comprehensive_grid_sampler_2d_cuda_bfloat16 (__main__.TestDecompCUDA)": 293.3983332316081,
|
||||
"test_comprehensive_grid_sampler_2d_cuda_float16 (__main__.TestDecompCUDA)": 238.7328338623047,
|
||||
"test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestDecompCUDA)": 1218.4906717936199,
|
||||
"test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 68.73516782124837,
|
||||
"test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestDecompCUDA)": 1156.0123494466145,
|
||||
"test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 72.13916714986165,
|
||||
"test_comprehensive_linalg_lu_solve_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 74.90450032552083,
|
||||
"test_comprehensive_linalg_lu_solve_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 70.42100016276042,
|
||||
"test_comprehensive_linalg_solve_triangular_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 72.98883310953777,
|
||||
"test_comprehensive_linalg_solve_triangular_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 73.34433364868164,
|
||||
"test_comprehensive_linalg_svd_cuda_complex128 (__main__.TestDecompCUDA)": 61.38016573588053,
|
||||
"test_comprehensive_linalg_svd_cuda_complex64 (__main__.TestDecompCUDA)": 67.52783330281575,
|
||||
"test_comprehensive_masked_norm_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 111.06333287556966,
|
||||
"test_comprehensive_masked_norm_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 110.19833374023438,
|
||||
"test_comprehensive_masked_norm_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 113.10083134969075,
|
||||
"test_comprehensive_nn_functional_conv_transpose3d_cuda_complex128 (__main__.TestDecompCUDA)": 63.23766644795736,
|
||||
"test_comprehensive_nn_functional_conv_transpose3d_cuda_complex64 (__main__.TestDecompCUDA)": 70.18666712443034,
|
||||
"test_comprehensive_nn_functional_gaussian_nll_loss_cpu_float32 (__main__.TestDecompCPU)": 62.61399841308594,
|
||||
"test_comprehensive_nn_functional_gaussian_nll_loss_cpu_float64 (__main__.TestDecompCPU)": 67.7816670735677,
|
||||
"test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestDecompCUDA)": 121.6183344523112,
|
||||
"test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float64 (__main__.TestDecompCUDA)": 107.30266698201497,
|
||||
"test_comprehensive_nn_functional_grid_sample_cpu_float32 (__main__.TestDecompCPU)": 130.8143310546875,
|
||||
"test_comprehensive_nn_functional_grid_sample_cpu_float64 (__main__.TestDecompCPU)": 127.27633412679036,
|
||||
"test_comprehensive_nn_functional_grid_sample_cuda_float32 (__main__.TestDecompCUDA)": 303.55183664957684,
|
||||
"test_comprehensive_nn_functional_grid_sample_cuda_float64 (__main__.TestDecompCUDA)": 234.41216532389322,
|
||||
"test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 (__main__.TestDecompCUDA)": 85.3436673482259,
|
||||
"test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 80.9688326517741,
|
||||
"test_comprehensive_nn_functional_interpolate_bicubic_cuda_float64 (__main__.TestDecompCUDA)": 82.55149968465169,
|
||||
"test_comprehensive_nn_functional_interpolate_bicubic_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 82.37966791788737,
|
||||
"test_comprehensive_nn_functional_interpolate_trilinear_cuda_float32 (__main__.TestDecompCUDA)": 129.88233184814453,
|
||||
"test_comprehensive_nn_functional_interpolate_trilinear_cuda_float64 (__main__.TestDecompCUDA)": 129.4015007019043,
|
||||
"test_comprehensive_nn_functional_max_pool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 1282.3826497395833,
|
||||
"test_comprehensive_nn_functional_max_pool2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 1270.64599609375,
|
||||
"test_comprehensive_nn_functional_max_pool2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 1297.9046630859375,
|
||||
"test_comprehensive_nn_functional_max_pool3d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 545.2034962972006,
|
||||
"test_comprehensive_nn_functional_max_pool3d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 572.5616760253906,
|
||||
"test_comprehensive_nn_functional_max_unpool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 64.40316645304362,
|
||||
"test_comprehensive_nn_functional_max_unpool2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 64.68383344014485,
|
||||
"test_comprehensive_nn_functional_max_unpool2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 61.48333422342936,
|
||||
"test_comprehensive_ormqr_cpu_complex64 (__main__.TestDecompCPU)": 61.959999084472656,
|
||||
"test_comprehensive_ormqr_cuda_complex128 (__main__.TestDecompCUDA)": 105.79100036621094,
|
||||
"test_comprehensive_ormqr_cuda_complex64 (__main__.TestDecompCUDA)": 122.34666570027669,
|
||||
"test_comprehensive_ormqr_cuda_float32 (__main__.TestDecompCUDA)": 68.7205015818278,
|
||||
"test_comprehensive_ormqr_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 74.2183329264323,
|
||||
"test_comprehensive_ormqr_cuda_float64 (__main__.TestDecompCUDA)": 66.86883227030437,
|
||||
"test_comprehensive_svd_cuda_complex128 (__main__.TestDecompCUDA)": 77.48183314005534,
|
||||
"test_comprehensive_svd_cuda_complex64 (__main__.TestDecompCUDA)": 79.1564998626709,
|
||||
"test_constructor_autograd_SparseBSC_cuda (__main__.TestSparseAnyCUDA)": 160.41250228881836,
|
||||
"test_constructor_autograd_SparseBSR_cuda (__main__.TestSparseAnyCUDA)": 79.10633341471355,
|
||||
"test_constructor_autograd_SparseCSC_cuda (__main__.TestSparseAnyCUDA)": 60.106833140055336,
|
||||
"test_conv1d_basic (__main__.TestXNNPACKConv1dTransformPass)": 221.3586196899414,
|
||||
"test_conv1d_with_relu_fc (__main__.TestXNNPACKConv1dTransformPass)": 504.3203754425049,
|
||||
"test_conv2d_binary_broadcast_shapes_cpu (__main__.TestPatternMatcherGenericCPU)": 78.03233337402344,
|
||||
"test_conv3d_binary_broadcast_shapes_cpu (__main__.TestPatternMatcherGenericCPU)": 152.302001953125,
|
||||
"test_conv3d_cuda (__main__.AOTInductorTestABICompatibleGpu)": 152.99433390299478,
|
||||
"test_conv_bn_fuse_cpu (__main__.CpuTests)": 96.25399971008301,
|
||||
"test_conv_bn_fuse_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 75.70275068283081,
|
||||
"test_conv_transpose_with_output_size_and_no_batch_dim_ConvTranspose3d_cuda (__main__.TestConvolutionNNDeviceTypeCUDA)": 139.14399747674665,
|
||||
"test_conv_unary_fusion_nnc (__main__.TestMkldnnFusion)": 72.7847490310669,
|
||||
"test_correctness_AdamW_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 91.59966786702473,
|
||||
"test_correctness_Adam_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 87.57833353678386,
|
||||
"test_count_nonzero_all (__main__.TestBool)": 664.9986343383789,
|
||||
"test_cp_flex_attention_document_mask (__main__.CPFlexAttentionTest)": 78.31500244140625,
|
||||
"test_ddp_uneven_inputs (__main__.TestDistBackendWithSpawn)": 385.24249792099,
|
||||
"test_dispatch_symbolic_meta_outplace_all_strides_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestMetaCUDA)": 84.70466740926106,
|
||||
"test_dtensor_op_db_nn_functional_gaussian_nll_loss_cpu_float32 (__main__.TestLocalDTensorOpsCPU)": 685.0679931640625,
|
||||
"test_dtensor_op_db_nn_functional_gaussian_nll_loss_cpu_float32 (__main__.TestMultiThreadedDTensorOpsCPU)": 86.26266733805339,
|
||||
"test_eig_check_magma_cuda_float32 (__main__.TestLinalgCUDA)": 292.93699645996094,
|
||||
"test_error_detection_and_propagation (__main__.NcclErrorHandlingTest)": 66.84199905395508,
|
||||
"test_fail_arithmetic_ops.py (__main__.TestTyping)": 69.56212568283081,
|
||||
"test_fail_creation_ops.py (__main__.TestTyping)": 69.80560022989908,
|
||||
"test_fn_fwgrad_bwgrad_cumprod_cuda_complex128 (__main__.TestFwdGradientsCUDA)": 73.36666552225749,
|
||||
"test_fn_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 90.40366744995117,
|
||||
"test_fuse_large_params_cpu (__main__.CpuTests)": 132.73199844360352,
|
||||
"test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 150.16662406921387,
|
||||
"test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 159.28499794006348,
|
||||
"test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 165.19283294677734,
|
||||
"test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 151.12366739908853,
|
||||
"test_grad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 84.61699930826823,
|
||||
"test_gradgrad_nn_LSTM_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 110.00600179036458,
|
||||
"test_gradgrad_nn_LSTM_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 122.3759994506836,
|
||||
"test_gradgrad_nn_TransformerDecoderLayer_cuda_float64 (__main__.TestModuleCUDA)": 190.89249674479166,
|
||||
"test_gradgrad_nn_TransformerEncoder_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 149.6598358154297,
|
||||
"test_gradgrad_nn_TransformerEncoder_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 146.07766723632812,
|
||||
"test_gradgrad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 532.8139902750651,
|
||||
"test_graph_partition_refcount_cuda (__main__.GPUTests)": 69.78400001525878,
|
||||
"test_graph_partition_refcount_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 267.04988850487604,
|
||||
"test_graph_partition_refcount_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 273.54955800374347,
|
||||
"test_grid_sampler_2d_cpu_halide (__main__.HalideCpuTests)": 195.84733072916666,
|
||||
"test_indirect_device_assert (__main__.TritonCodeGenTests)": 326.0143330891927,
|
||||
"test_inductor_no_recursionerror_on_for_loops_dynamic_shapes (__main__.DynamicShapesReproTests)": 66.96037435531616,
|
||||
"test_inplace_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 77.44933319091797,
|
||||
"test_inputs_overlapping_with_mutation_stress_dynamic_shapes (__main__.DynamicShapesAotAutogradFallbackTests)": 126.81488884819879,
|
||||
"test_jit_cuda_archflags (__main__.TestCppExtensionJIT)": 118.70199839274089,
|
||||
"test_linalg_solve_triangular_large_cuda_complex128 (__main__.TestLinalgCUDA)": 129.20266723632812,
|
||||
"test_linalg_solve_triangular_large_cuda_complex64 (__main__.TestLinalgCUDA)": 97.18800099690755,
|
||||
"test_linear_binary_cpp_wrapper (__main__.TestCppWrapper)": 130.3183339436849,
|
||||
"test_linear_binary_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 140.43233235677084,
|
||||
"test_list_clearing_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 293.122774971856,
|
||||
"test_lobpcg_ortho_cuda_float64 (__main__.TestLinalgCUDA)": 63.835832277933754,
|
||||
"test_longformer_chunk_dynamic_shapes (__main__.DynamicShapesReproTests)": 106.77049922943115,
|
||||
"test_lstm_cpu (__main__.TestMkldnnCPU)": 100.89649963378906,
|
||||
"test_many_overlapping_inputs_does_not_explode_guards_dynamic_shapes (__main__.DynamicShapesReproTests)": 140.07424926757812,
|
||||
"test_max_autotune_addmm_max_autotune_gemm_backends_CK_x_shape2 (__main__.TestCKBackend)": 72.90299733479817,
|
||||
"test_max_autotune_addmm_search_space_EXHAUSTIVE_dynamic_True (__main__.TestMaxAutotuneSubproc)": 82.62433369954427,
|
||||
"test_max_autotune_precompile_matmul_max_autotune_gemm_backends_CKTILE_autotune_in_subproc_False_use_aoti_False (__main__.TestCKBackend)": 87.51499938964844,
|
||||
"test_max_autotune_precompile_matmul_max_autotune_gemm_backends_CKTILE_autotune_in_subproc_True_use_aoti_True (__main__.TestCKBackend)": 71.22416591644287,
|
||||
"test_max_pool2d2_cpu_halide (__main__.HalideCpuTests)": 424.50966389973956,
|
||||
"test_max_pool2d3_cpu_halide (__main__.HalideCpuTests)": 134.14600626627603,
|
||||
"test_max_pool2d5_cpu_halide (__main__.HalideCpuTests)": 358.88099161783856,
|
||||
"test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 63.58866712782118,
|
||||
"test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 62.68674945831299,
|
||||
"test_memory_format_operators_cuda (__main__.TestTorchDeviceTypeCUDA)": 65.85794713936355,
|
||||
"test_ordered_distribute_all_combination (__main__.DistributeWithDeviceOrderTest)": 103.6923344930013,
|
||||
"test_ordered_redistribute_with_partial (__main__.DistributeWithDeviceOrderTest)": 187.6953328450521,
|
||||
"test_ordered_redistribute_with_partial (__main__.DistributeWithDeviceOrderTestWithLocalTensor)": 370.27442932128906,
|
||||
"test_proper_exit (__main__.TestDataLoader)": 227.83111148410373,
|
||||
"test_proper_exit (__main__.TestDataLoaderPersistentWorkers)": 227.1901126437717,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 105.52099990844727,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 106.50249862670898,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True (__main__.TestPatternMatcher)": 92.52400207519531,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 111.75499725341797,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 107.40500259399414,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False (__main__.TestPatternMatcher)": 83.80450057983398,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 107.46599833170573,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 96.65650177001953,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True (__main__.TestPatternMatcher)": 83.4114990234375,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 107.47100067138672,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 108.55533345540364,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False (__main__.TestPatternMatcher)": 89.23666381835938,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 105.13900375366211,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 100.14550018310547,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 107.33649826049805,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 102.08150100708008,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 97.59600067138672,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 104.82933553059895,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 114.43099721272786,
|
||||
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 110.40333302815755,
|
||||
"test_quick_core_backward__unsafe_masked_index_cpu_float64 (__main__.TestDecompCPU)": 567.2765197753906,
|
||||
"test_quick_core_backward__unsafe_masked_index_cuda_float64 (__main__.TestDecompCUDA)": 1032.5083312988281,
|
||||
"test_quick_core_backward__unsafe_masked_index_put_accumulate_cpu_float64 (__main__.TestDecompCPU)": 852.7170003255209,
|
||||
"test_quick_core_backward__unsafe_masked_index_put_accumulate_cuda_float64 (__main__.TestDecompCUDA)": 1361.954854329427,
|
||||
"test_quick_core_backward_nn_functional_max_unpool3d_grad_cpu_float64 (__main__.TestDecompCPU)": 77.385498046875,
|
||||
"test_quick_core_backward_nn_functional_max_unpool3d_grad_cuda_float64 (__main__.TestDecompCUDA)": 265.0193354288737,
|
||||
"test_quick_core_backward_roll_cpu_float64 (__main__.TestDecompCPU)": 115.31749725341797,
|
||||
"test_quick_core_backward_roll_cuda_float64 (__main__.TestDecompCUDA)": 245.27666727701822,
|
||||
"test_quick_core_backward_select_scatter_cpu_float64 (__main__.TestDecompCPU)": 71.75300216674805,
|
||||
"test_quick_core_backward_select_scatter_cuda_float64 (__main__.TestDecompCUDA)": 141.8895009358724,
|
||||
"test_quick_core_backward_split_cuda_float64 (__main__.TestDecompCUDA)": 71.15749994913737,
|
||||
"test_quick_core_backward_split_with_sizes_copy_cpu_float64 (__main__.TestDecompCPU)": 90.59066772460938,
|
||||
"test_quick_core_backward_split_with_sizes_copy_cuda_float64 (__main__.TestDecompCUDA)": 173.73916625976562,
|
||||
"test_quick_core_backward_std_cuda_float64 (__main__.TestDecompCUDA)": 110.65066655476888,
|
||||
"test_register_spills_cuda (__main__.BenchmarkFusionCudaTest)": 99.21799850463867,
|
||||
"test_replicatepad_64bit_indexing_cuda_float16 (__main__.TestNNDeviceTypeCUDA)": 90.86299896240234,
|
||||
"test_rosenbrock_sparse_with_lrsched_False_SGD_cuda_float64 (__main__.TestOptimRenewedCUDA)": 66.57050196329753,
|
||||
"test_rosenbrock_sparse_with_lrsched_True_SGD_cuda_float64 (__main__.TestOptimRenewedCUDA)": 69.65149958928426,
|
||||
"test_runtime_checks_large_cpu (__main__.AOTInductorTestABICompatibleCpu)": 78.13350168863933,
|
||||
"test_runtime_checks_large_cpu_with_stack_allocation (__main__.AOTInductorTestABICompatibleCpuWithStackAllocation)": 76.85255601671007,
|
||||
"test_runtime_checks_large_cuda (__main__.AOTInductorTestABICompatibleGpu)": 333.04866282145184,
|
||||
"test_save_load_large_string_attribute (__main__.TestSaveLoad)": 146.96599833170572,
|
||||
"test_sdpa_kernel_ctx_manager2_dynamic_shapes (__main__.DynamicShapesCtxManagerTests)": 160.4881100124783,
|
||||
"test_shuffler_iterdatapipe (__main__.IntegrationTestDataLoaderDataPipe)": 124.10055626763238,
|
||||
"test_slow_tasks (__main__.TestFunctionalAutogradBenchmark)": 117.38410907321506,
|
||||
"test_sort_dynamic_shape_with_check_cuda (__main__.TestInductorDynamicCUDA)": 710.2327779134115,
|
||||
"test_sort_stable_cpu (__main__.CpuTritonTests)": 1324.4399820963542,
|
||||
"test_sort_stable_cuda (__main__.GPUTests)": 76.83109970092774,
|
||||
"test_split_cumsum_cpu (__main__.CpuTritonTests)": 88.58433532714844,
|
||||
"test_svd_lowrank_cuda_complex128 (__main__.TestLinalgCUDA)": 160.1271684964498,
|
||||
"test_tensor_split (__main__.TestVmapOperators)": 79.18955569393519,
|
||||
"test_terminate_handler_on_crash (__main__.TestTorch)": 111.30388899644215,
|
||||
"test_terminate_signal (__main__.ForkTest)": 132.3458870516883,
|
||||
"test_terminate_signal (__main__.ParallelForkServerShouldWorkTest)": 132.2043343567186,
|
||||
"test_terminate_signal (__main__.SpawnTest)": 136.1005539894104,
|
||||
"test_torchvision_smoke (__main__.TestTensorBoardPytorchGraph)": 76.20899939537048,
|
||||
"test_train_parity_multi_group_unshard_async_op (__main__.TestFullyShard1DTrainingCore)": 63.82099969046457,
|
||||
"test_triton_bsr_scatter_mm_blocksize_64_cuda_bfloat16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 61.925000508626304,
|
||||
"test_triton_bsr_scatter_mm_blocksize_64_cuda_float16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 60.89849980672201,
|
||||
"test_triton_bsr_scatter_mm_blocksize_64_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 66.88233375549316,
|
||||
"test_triton_bsr_softmax_cuda_bfloat16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 144.9854990641276,
|
||||
"test_triton_bsr_softmax_cuda_float16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 144.4044977823893,
|
||||
"test_triton_bsr_softmax_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 108.19166437784831,
|
||||
"test_unary_ops (__main__.TestTEFuserDynamic)": 96.32655514611139,
|
||||
"test_unary_ops (__main__.TestTEFuserStatic)": 105.33362591266632,
|
||||
"test_upsample_bicubic2d_cpu_halide (__main__.HalideCpuTests)": 97.8336664835612,
|
||||
"test_variant_consistency_jit_nn_functional_max_pool2d_cpu_float32 (__main__.TestJitCPU)": 82.86566925048828,
|
||||
"test_variant_consistency_jit_nn_functional_max_pool2d_cuda_float32 (__main__.TestJitCUDA)": 68.26500002543132,
|
||||
"test_views1_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 97.1120007832845,
|
||||
"test_vmapjvpvjp_linalg_lstsq_grad_oriented_cpu_float32 (__main__.TestOperatorsCPU)": 88.24766794840495,
|
||||
"test_vmapjvpvjp_linalg_lstsq_grad_oriented_cuda_float32 (__main__.TestOperatorsCUDA)": 65.41266759236653,
|
||||
"test_vmapjvpvjp_linalg_lu_solve_cuda_float32 (__main__.TestOperatorsCUDA)": 74.75533294677734,
|
||||
"test_vmapjvpvjp_linalg_multi_dot_cuda_float32 (__main__.TestOperatorsCUDA)": 73.52500089009602,
|
||||
"test_vmapjvpvjp_linalg_svd_cuda_float32 (__main__.TestOperatorsCUDA)": 73.85466639200847,
|
||||
"test_vmapjvpvjp_max_pool2d_with_indices_backward_cuda_float32 (__main__.TestOperatorsCUDA)": 98.39650090535481,
|
||||
"test_vmapjvpvjp_nn_functional_conv2d_cpu_float32 (__main__.TestOperatorsCPU)": 61.39695285615467,
|
||||
"test_vmapjvpvjp_nn_functional_max_pool2d_cuda_float32 (__main__.TestOperatorsCUDA)": 77.88249842325847,
|
||||
"test_vmapjvpvjp_svd_cuda_float32 (__main__.TestOperatorsCUDA)": 73.0695006052653,
|
||||
"test_vmapjvpvjp_unbind_cuda_float32 (__main__.TestOperatorsCUDA)": 81.86250114440918,
|
||||
"test_vmapvjpvjp_meshgrid_list_of_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 98.63116455078125,
|
||||
"test_vmapvjpvjp_meshgrid_variadic_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 94.85683314005534,
|
||||
"test_vmapvjpvjp_nn_functional_bilinear_cuda_float32 (__main__.TestOperatorsCUDA)": 173.00183614095053
|
||||
}
|
||||
@@ -239,12 +239,6 @@ class TestAccelerator(TestCase):
self.assertEqual(torch.accelerator.max_memory_allocated(), prev_max_allocated)
self.assertEqual(torch.accelerator.max_memory_reserved(), prev_max_reserved)

@unittest.skipIf(TEST_MPS, "MPS doesn't support torch.accelerator memory API!")
def test_get_memory_info(self):
free_bytes, total_bytes = torch.accelerator.get_memory_info()
self.assertGreaterEqual(free_bytes, 0)
self.assertGreaterEqual(total_bytes, 0)


if __name__ == "__main__":
run_tests()

@@ -3728,6 +3728,7 @@ class TestSparse(TestSparseBase):
|
||||
@coalescedonoff
|
||||
@dtypes(*floating_and_complex_types())
|
||||
@dtypesIfMPS(*all_mps_types())
|
||||
@expectedFailureMPS
|
||||
@dtypesIfCUDA(*floating_types_and(*[torch.half] if SM53OrLater and not TEST_WITH_ROCM else [],
|
||||
*[torch.bfloat16] if SM80OrLater and not TEST_WITH_ROCM else [],
|
||||
torch.complex64,
|
||||
@@ -3824,9 +3825,9 @@ class TestSparse(TestSparseBase):
|
||||
def different_dtypes():
|
||||
a, i_a, v_a = self._gen_sparse(2, 10, [2, 2], dtype, device, coalesced)
|
||||
b, i_b, v_b = self._gen_sparse(2, 10, [2, 2], dtype, device, coalesced)
|
||||
r2 = torch.sparse.mm(a.to(torch.float32), a.to(torch.float16))
|
||||
r2 = torch.sparse.mm(a.to(torch.float64), a.to(torch.float32))
|
||||
|
||||
self.assertRaisesRegex(RuntimeError, 'mat1 dtype Float does not match mat2 dtype Half', different_dtypes)
|
||||
self.assertRaisesRegex(RuntimeError, 'mat1 dtype Double does not match mat2 dtype Float', different_dtypes)
|
||||
|
||||
def test_backward_noncontiguous():
|
||||
# Sparse.mm backward used to be wrong with non-contiguous grads,
|
||||
|
||||
@@ -206,8 +206,7 @@ if __name__ == "__main__":
|
||||
test_multi_process(model, input)
|
||||
print(torch.xpu.device_count())
|
||||
"""
|
||||
# XPU output has extra lines, so take the last line; see https://github.com/intel/torch-xpu-ops/issues/2261
|
||||
rc = check_output(test_script).splitlines()[-1]
|
||||
rc = check_output(test_script)
|
||||
self.assertEqual(rc, str(torch.xpu.device_count()))
|
||||
|
||||
def test_streams(self):
|
||||
|
||||
@@ -2491,7 +2491,6 @@ def _accelerator_emptyCache() -> None: ...
|
||||
def _accelerator_getDeviceStats(device_index: _int) -> dict[str, Any]: ...
|
||||
def _accelerator_resetAccumulatedStats(device_index: _int) -> None: ...
|
||||
def _accelerator_resetPeakStats(device_index: _int) -> None: ...
|
||||
def _accelerator_getMemoryInfo(device_index: _int) -> tuple[_int, _int]: ...
|
||||
def _accelerator_setAllocatorSettings(env: str) -> None: ...
|
||||
|
||||
# Defined in torch/csrc/jit/python/python_tracer.cpp
|
||||
|
||||
@@ -550,10 +550,6 @@ max_autotune_flex_search_space: Literal["DEFAULT", "EXHAUSTIVE"] = os.environ.ge
|
||||
"TORCHINDUCTOR_MAX_AUTOTUNE_FLEX_SEARCH_SPACE", "DEFAULT"
|
||||
).upper() # type: ignore[assignment]
|
||||
|
||||
cutedsl_enable_autotuning: bool = (
|
||||
os.environ.get("CUTEDSL_ENABLE_AUTOTUNING", "0") == "1"
|
||||
)
|
||||
|
||||
# DEPRECATED. This setting is ignored.
|
||||
autotune_fallback_to_aten = False
|
||||
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
@@ -14,7 +12,6 @@ from torch.fx.experimental.symbolic_shapes import has_free_unbacked_symbols
|
||||
from .. import config
|
||||
from ..codegen.wrapper import PythonWrapperCodegen
|
||||
from ..ir import _IntLike, Layout, TensorBox
|
||||
from ..utils import load_template
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
@@ -257,7 +254,3 @@ def is_batch_stride_largest_or_zero(mat1, mat2, layout) -> bool:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
_KERNEL_TEMPLATE_DIR = Path(__file__).parent / "templates"
|
||||
load_kernel_template = partial(load_template, template_dir=_KERNEL_TEMPLATE_DIR)
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import logging
|
||||
from dataclasses import asdict, dataclass
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch._dynamo.utils import counters
|
||||
from torch._inductor.codegen.cutedsl.cutedsl_template import CuteDSLTemplate
|
||||
from torch._inductor.runtime.triton_compat import tl
|
||||
from torch._inductor.template_heuristics.cutedsl import get_groupgemm_configs
|
||||
from torch._inductor.virtualized import V
|
||||
from torch.utils._triton import has_triton
|
||||
|
||||
@@ -24,13 +22,11 @@ from ..utils import (
|
||||
get_num_sms,
|
||||
has_free_symbols,
|
||||
use_aten_gemm_kernels,
|
||||
use_blackwell_cutedsl_grouped_mm,
|
||||
use_triton_template,
|
||||
)
|
||||
from .mm_common import (
|
||||
_is_static_problem,
|
||||
check_supported_striding,
|
||||
load_kernel_template,
|
||||
persistent_grouped_mm_grid,
|
||||
)
|
||||
|
||||
@@ -517,11 +513,6 @@ triton_scaled_grouped_mm_template = TritonTemplate(
|
||||
source=triton_grouped_mm_source,
|
||||
)
|
||||
|
||||
cutedsl_grouped_mm_template = CuteDSLTemplate(
|
||||
name="grouped_gemm_cutedsl",
|
||||
source=load_kernel_template("cutedsl_mm_grouped"),
|
||||
)
|
||||
|
||||
|
||||
def grouped_mm_args(
|
||||
mat1: TensorBox,
|
||||
@@ -723,44 +714,43 @@ def _tuned_grouped_mm_common(
|
||||
# Checking only for the equality of corresponding dims of
|
||||
# multiplicands here, relying on meta function checks for
|
||||
# everything else.
|
||||
if len(m1_size) == 2:
|
||||
if len(m2_size) == 2:
|
||||
m, k1 = m1_size
|
||||
k2, _ = m2_size
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
g = offs.get_size()[0]
|
||||
V.graph.sizevars.check_equals(k1, k2)
|
||||
a_is_2d, b_is_2d = True, True
|
||||
else:
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
g1 = offs.layout.size[0]
|
||||
m, k1 = m1_size
|
||||
g2, k2, _ = m2_size
|
||||
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
|
||||
V.graph.sizevars.check_equals(k1, k2)
|
||||
a_is_2d, b_is_2d = True, False
|
||||
else:
|
||||
if len(m2_size) == 2:
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
g1 = offs.layout.size[0]
|
||||
g2, m, k1 = m1_size
|
||||
k2, _ = m2_size
|
||||
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
|
||||
V.graph.sizevars.check_equals(k1, k2)
|
||||
a_is_2d, b_is_2d = False, True
|
||||
else:
|
||||
g1, m, k1 = m1_size
|
||||
g2, k2, _ = m2_size
|
||||
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
|
||||
V.graph.sizevars.check_equals(k1, k2)
|
||||
a_is_2d, b_is_2d = False, False
|
||||
|
||||
if (
|
||||
is_nonzero
|
||||
and use_triton_template(layout)
|
||||
and can_use_triton_kernel(mat_a, mat_b, offs, bias, scale_result)
|
||||
):
|
||||
scaled = scale_a is not None
|
||||
if len(m1_size) == 2:
|
||||
if len(m2_size) == 2:
|
||||
m, k1 = m1_size
|
||||
k2, _ = m2_size
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
g = offs.get_size()[0]
|
||||
V.graph.sizevars.check_equals(k1, k2)
|
||||
a_is_2d, b_is_2d = True, True
|
||||
else:
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
g1 = offs.layout.size[0]
|
||||
m, k1 = m1_size
|
||||
g2, k2, _ = m2_size
|
||||
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
|
||||
V.graph.sizevars.check_equals(k1, k2)
|
||||
a_is_2d, b_is_2d = True, False
|
||||
else:
|
||||
if len(m2_size) == 2:
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
g1 = offs.layout.size[0]
|
||||
g2, m, k1 = m1_size
|
||||
k2, _ = m2_size
|
||||
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
|
||||
V.graph.sizevars.check_equals(k1, k2)
|
||||
a_is_2d, b_is_2d = False, True
|
||||
else:
|
||||
g1, m, k1 = m1_size
|
||||
g2, k2, _ = m2_size
|
||||
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
|
||||
V.graph.sizevars.check_equals(k1, k2)
|
||||
a_is_2d, b_is_2d = False, False
|
||||
|
||||
a_is_k_major = mat_a.get_stride()[-1] == 1
|
||||
b_is_k_major = mat_b.get_stride()[-2] == 1
|
||||
@@ -798,22 +788,6 @@ def _tuned_grouped_mm_common(
|
||||
**config.kwargs,
|
||||
)
|
||||
|
||||
if use_blackwell_cutedsl_grouped_mm(
|
||||
mat_a, mat_b, layout, a_is_2d, b_is_2d, offs, bias, scale_result
|
||||
):
|
||||
for config in get_groupgemm_configs():
|
||||
kwargs = dict(
|
||||
ACC_DTYPE="cutlass.Float32",
|
||||
)
|
||||
|
||||
cutedsl_grouped_mm_template.maybe_append_choice(
|
||||
choices,
|
||||
input_nodes=input_nodes,
|
||||
layout=layout,
|
||||
**kwargs,
|
||||
**asdict(config),
|
||||
)
|
||||
|
||||
input_gen_fns = {
|
||||
4: lambda x: create_offsets(
|
||||
x, m1_size, m2_size, offs.get_size() if offs is not None else None
|
||||
|
||||
@@ -1,333 +0,0 @@
|
||||
import functools
|
||||
from torch._inductor.runtime.runtime_utils import ceildiv
|
||||
from cutlass.utils import TensorMapUpdateMode
|
||||
{{gen_defines()}}
|
||||
# ---- Import GroupedGemm implementation, copied on PyTorch build from Cutlass repository: cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py ----
|
||||
from torch._inductor.kernel.vendored_templates.cutedsl_grouped_gemm import (
|
||||
GroupedGemmKernel,
|
||||
)
|
||||
|
||||
|
||||
# Note about caching:
# Each instantiated CuTeDSL grouped GEMM kernel file generated by Inductor
# maintains its own local caching system. At this stage, all compile-time
# constexprs (e.g., TILE_M, TILE_N, CLUSTER_M/N, USE_2_CTA) and the kernel
# name itself ({{kernel_name}}) are permanently baked into the file, so they
# do not need to be included in any cache key.
#
# The caching mechanism is split into two levels:
#
# 1. prep_cache
# Caches the compiled executor for build_group_ptrs_from_bases(). This
# kernel depends only on the tensor shapes, strides, and dtypes of A/B/C,
# and can therefore be safely reused across runs with different group
# partitioning (`offs`).
#
# 2. gemm_cache
# Caches the compiled Grouped GEMM executor. Its key extends the prep
# cache key with hardware- and grid-specific parameters:
# (prep_cache_key, max_active_clusters, total_num_clusters).
# This is necessary because different `offs` tensors can change the
# per-group problem sizes and thus alter `total_num_clusters`, which in
# turn changes the grid shape and persistent scheduler configuration.
# Kernels compiled for one grid cannot be safely reused for another.
#
#
# Additionally, note the @lru_cache decorator on get_hardware_info(). Empirically,
# hw.get_max_active_clusters() triggers significant MLIR recompilation overhead,
# despite depending only on the GPU type. We cache this function to mitigate
# redundant recompiles even when shape/stride/dtype cache misses force kernel
# regeneration. A follow-up study will investigate the root cause.
|
||||
prep_cache = {}
|
||||
gemm_cache = {}
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
def get_hardware_info():
|
||||
hw = cutlass.utils.HardwareInfo()
|
||||
sm_count = hw.get_max_active_clusters(1)
|
||||
max_active_clusters = hw.get_max_active_clusters(CLUSTER_M * CLUSTER_N)
|
||||
|
||||
return (sm_count, max_active_clusters)
|
||||
|
||||
|
||||
def get_prep_cache_key(input_a, input_b, output):
|
||||
"""
|
||||
Returns a tuple key for caching the preprocessing kernel executor based on kernel name,
|
||||
shapes, strides, and dtypes of input/output tensors.
|
||||
"""
|
||||
return (
|
||||
tuple(input_a.shape),
|
||||
tuple(input_a.stride()),
|
||||
input_a.dtype,
|
||||
tuple(input_b.shape),
|
||||
tuple(input_b.stride()),
|
||||
input_b.dtype,
|
||||
tuple(output.shape),
|
||||
tuple(output.stride()),
|
||||
output.dtype,
|
||||
)
|
||||
|
||||
|
||||
def get_gemm_cache_key(prep_cache_key, max_active_clusters, total_num_clusters):
|
||||
"""
|
||||
Returns a tuple key for caching the gemm kernel executor by extending the
|
||||
prep cache key with hardware- and grid-specific parameters.
|
||||
"""
|
||||
return (
|
||||
prep_cache_key,
|
||||
max_active_clusters,
|
||||
total_num_clusters,
|
||||
)
|
||||
|
||||
|
||||
@cute.kernel
|
||||
def build_group_ptrs_from_bases_kernel(
|
||||
base_A_u64: cutlass.Int64, # device addr of input_a (bytes)
|
||||
base_B_u64: cutlass.Int64, # device addr of input_b (bytes)
|
||||
base_C_u64: cutlass.Int64, # device addr of Output (bytes)
|
||||
offs: cute.Tensor, # [G], cutlass.Int32/64 cumulative
|
||||
K: cutlass.Constexpr,
|
||||
N: cutlass.Constexpr,
|
||||
sizeof_element: cutlass.Int32, # bytes
|
||||
# -------- STRIDES (in ELEMENTS) --------
|
||||
stride_A_m_elems: cutlass.Constexpr, # A.stride(0)
|
||||
stride_A_k_elems: cutlass.Constexpr, # A.stride(1)
|
||||
stride_B0_elems: cutlass.Constexpr, # B.stride(0)
|
||||
stride_Bk_elems: cutlass.Constexpr, # B.stride(1)
|
||||
stride_Bn_elems: cutlass.Constexpr, # B.stride(2)
|
||||
stride_C_m_elems: cutlass.Constexpr, # C.stride(0)
|
||||
stride_C_n_elems: cutlass.Constexpr, # C.stride(1)
|
||||
# -------- OUTPUTS --------
|
||||
out_ptrs: cute.Tensor, # [G,3] cutlass.Int64: (A_ptr, B_ptr, C_ptr)
|
||||
out_problem: cute.Tensor, # [G,4] cutlass.Int32: (m_g, n, k, 1)
|
||||
out_strides_abc: cute.Tensor, # [G,3,2] cutlass.Int32 [[A_m,A_k],[B_n,B_k],[C_m,C_n]]
|
||||
):
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
g = tidx
|
||||
|
||||
m_beg_i32 = 0
|
||||
if g > 0:
|
||||
m_beg_i32 = offs[g - 1]
|
||||
m_end_i32 = offs[g]
|
||||
m_g_i32 = m_end_i32 - m_beg_i32
|
||||
|
||||
a_byte_off = (
|
||||
cutlass.Int64(m_beg_i32) * stride_A_m_elems * cutlass.Int64(sizeof_element)
|
||||
)
|
||||
c_byte_off = (
|
||||
cutlass.Int64(m_beg_i32) * stride_C_m_elems * cutlass.Int64(sizeof_element)
|
||||
)
|
||||
b_byte_off = cutlass.Int64(g) * stride_B0_elems * cutlass.Int64(sizeof_element)
|
||||
|
||||
# ---- pointers ----
|
||||
out_ptrs[g, 0] = base_A_u64 + a_byte_off
|
||||
out_ptrs[g, 1] = base_B_u64 + b_byte_off
|
||||
out_ptrs[g, 2] = base_C_u64 + c_byte_off
|
||||
|
||||
# ---- (m, n, k, 1) ----
|
||||
out_problem[g, 0] = m_g_i32
|
||||
out_problem[g, 1] = N
|
||||
out_problem[g, 2] = K
|
||||
out_problem[g, 3] = cutlass.Int32(1)
|
||||
|
||||
# ---- strides ----
|
||||
out_strides_abc[g, 0, 0] = cutlass.Int32(stride_A_m_elems)
|
||||
out_strides_abc[g, 0, 1] = cutlass.Int32(stride_A_k_elems)
|
||||
out_strides_abc[g, 1, 0] = cutlass.Int32(stride_Bn_elems)
|
||||
out_strides_abc[g, 1, 1] = cutlass.Int32(stride_Bk_elems)
|
||||
out_strides_abc[g, 2, 0] = cutlass.Int32(stride_C_m_elems)
|
||||
out_strides_abc[g, 2, 1] = cutlass.Int32(stride_C_n_elems)
|
||||
|
||||
|
||||
@cute.jit
|
||||
def launch_build_group_ptrs_from_bases(
|
||||
base_A_u64: cutlass.Int64,
|
||||
base_B_u64: cutlass.Int64,
|
||||
base_C_u64: cutlass.Int64,
|
||||
offs: cute.Tensor,
|
||||
G: cutlass.Constexpr,
|
||||
K: cutlass.Constexpr,
|
||||
N: cutlass.Constexpr,
|
||||
sizeof_element: cutlass.Constexpr,
|
||||
stride_A_m_elems: cutlass.Constexpr,
|
||||
stride_A_k_elems: cutlass.Constexpr,
|
||||
stride_B0_elems: cutlass.Constexpr,
|
||||
stride_Bk_elems: cutlass.Constexpr,
|
||||
stride_Bn_elems: cutlass.Constexpr,
|
||||
stride_C_m_elems: cutlass.Constexpr,
|
||||
stride_C_n_elems: cutlass.Constexpr,
|
||||
out_ptrs: cute.Tensor, # [G,3] cutlass.Int64
|
||||
out_problem: cute.Tensor, # [G,4] cutlass.Int32
|
||||
out_strides_abc: cute.Tensor, # [3,2] cutlass.Int32
|
||||
stream: cuda.CUstream,
|
||||
):
|
||||
build_group_ptrs_from_bases_kernel(
|
||||
base_A_u64,
|
||||
base_B_u64,
|
||||
base_C_u64,
|
||||
offs,
|
||||
K,
|
||||
N,
|
||||
sizeof_element,
|
||||
stride_A_m_elems,
|
||||
stride_A_k_elems,
|
||||
stride_B0_elems,
|
||||
stride_Bk_elems,
|
||||
stride_Bn_elems,
|
||||
stride_C_m_elems,
|
||||
stride_C_n_elems,
|
||||
out_ptrs,
|
||||
out_problem,
|
||||
out_strides_abc,
|
||||
).launch(grid=(1, 1, 1), block=(G, 1, 1), stream=stream)
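# Note on the launch above (a reading of the code, not a change): the prep step
# runs as a single thread block of G threads, so each thread g fills the
# pointer, problem-size, and stride rows for exactly one group; G is assumed
# to be small enough to fit within one block.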
|
||||
|
||||
|
||||
{{def_kernel("input_a", "input_b", "input_a_offs")}}
|
||||
stream = cuda.CUstream(stream)
|
||||
|
||||
input_b = input_b.transpose(1, 2)
|
||||
|
||||
sumM, K = input_a.shape
|
||||
G, N, Kb = input_b.shape
|
||||
|
||||
dev = input_a.device
|
||||
|
||||
base_A_u64 = int(input_a.data_ptr())
|
||||
base_B_u64 = int(input_b.data_ptr())
|
||||
base_C_u64 = int({{get_output()}}.data_ptr())
|
||||
|
||||
ptrs_t = torch.empty((G, 3), device=dev, dtype=torch.int64)
|
||||
probs_t = torch.empty((G, 4), device=dev, dtype=torch.int32)
|
||||
strides_t = torch.empty((G, 3, 2), device=dev, dtype=torch.int32)
|
||||
ptrs = from_dlpack(ptrs_t)
|
||||
probs = from_dlpack(probs_t)
|
||||
strides = from_dlpack(strides_t)
|
||||
|
||||
prep_cache_key = get_prep_cache_key(input_a, input_b, {{get_output()}})
|
||||
prep_executor = prep_cache.get(prep_cache_key)
|
||||
|
||||
if prep_executor is None:
|
||||
sizeof_element = int(input_a.element_size())
|
||||
sA_m, sA_k = map(int, input_a.stride())
|
||||
sB_0, sB_n, sB_k = map(int, input_b.stride())
|
||||
sC_m, sC_n = map(int, {{get_output()}}.stride())
|
||||
|
||||
prep_executor = cute.compile(
|
||||
launch_build_group_ptrs_from_bases,
|
||||
base_A_u64=base_A_u64,
|
||||
base_B_u64=base_B_u64,
|
||||
base_C_u64=base_C_u64,
|
||||
offs=from_dlpack(input_a_offs),
|
||||
G=int(G),
|
||||
K=int(K),
|
||||
N=int(N),
|
||||
sizeof_element=sizeof_element,
|
||||
stride_A_m_elems=sA_m,
|
||||
stride_A_k_elems=sA_k,
|
||||
stride_B0_elems=sB_0,
|
||||
stride_Bk_elems=sB_k,
|
||||
stride_Bn_elems=sB_n,
|
||||
stride_C_m_elems=sC_m,
|
||||
stride_C_n_elems=sC_n,
|
||||
out_ptrs=ptrs,
|
||||
out_problem=probs,
|
||||
out_strides_abc=strides,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
prep_cache[prep_cache_key] = prep_executor
|
||||
|
||||
prep_executor(
|
||||
base_A_u64=base_A_u64,
|
||||
base_B_u64=base_B_u64,
|
||||
base_C_u64=base_C_u64,
|
||||
offs=from_dlpack(input_a_offs),
|
||||
out_ptrs=ptrs,
|
||||
out_problem=probs,
|
||||
out_strides_abc=strides,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
# --- Tensormap workspace per SM ---
|
||||
num_tensormap_buffers, max_active_clusters = get_hardware_info()
|
||||
tensormap_shape = (
|
||||
num_tensormap_buffers,
|
||||
GroupedGemmKernel.num_tensormaps,
|
||||
GroupedGemmKernel.bytes_per_tensormap // 8,
|
||||
)
|
||||
tensormap_workspace_t = torch.empty(tensormap_shape, device=dev, dtype=torch.int64)
|
||||
tensormap_workspace = from_dlpack(tensormap_workspace_t)
|
||||
|
||||
# --- Total clusters ---
|
||||
def compute_total_num_clusters(
|
||||
problem_sizes_mnkl,
|
||||
cluster_tile_shape_mn,
|
||||
):
|
||||
total_num_clusters = 0
|
||||
for m, n, _, _ in problem_sizes_mnkl:
|
||||
num_clusters_mn = tuple(
|
||||
ceildiv(x, y) for x, y in zip((m, n), cluster_tile_shape_mn)
|
||||
)
|
||||
total_num_clusters += functools.reduce(lambda x, y: x * y, num_clusters_mn)
|
||||
return total_num_clusters
|
||||
|
||||
# Compute cluster tile shape
|
||||
def compute_cluster_tile_shape(
|
||||
mma_tiler_mn,
|
||||
cluster_shape_mn,
|
||||
use_2cta_instrs,
|
||||
):
|
||||
cta_tile_shape_mn = list(mma_tiler_mn)
|
||||
if use_2cta_instrs:
|
||||
cta_tile_shape_mn[0] = cta_tile_shape_mn[0] // 2
|
||||
return tuple(x * y for x, y in zip(cta_tile_shape_mn, cluster_shape_mn))
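# Hedged worked example, using the default CuTeGemmConfig values from the
# CuTeDSL heuristics file shown later in this diff (TILE_M=128, TILE_N=192,
# CLUSTER_M=2, CLUSTER_N=1, USE_2_CTA=False): the CTA tile stays (128, 192)
# because 2-CTA mode is off, and the cluster tile becomes
# (128 * 2, 192 * 1) = (256, 192); each group then contributes
# ceildiv(m, 256) * ceildiv(n, 192) clusters to the total computed above.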
|
||||
|
||||
cluster_tile_shape_mn = compute_cluster_tile_shape(
|
||||
(TILE_M, TILE_N), (CLUSTER_M, CLUSTER_N), bool(USE_2_CTA)
|
||||
)
|
||||
|
||||
total_num_clusters = int(compute_total_num_clusters(probs_t, cluster_tile_shape_mn))
|
||||
|
||||
gemm_cache_key = get_gemm_cache_key(
|
||||
prep_cache_key, max_active_clusters, total_num_clusters
|
||||
)
|
||||
gemm_executor = gemm_cache.get(gemm_cache_key)
|
||||
|
||||
if gemm_executor is None:
|
||||
grouped_gemm = GroupedGemmKernel(
|
||||
acc_dtype=ACC_DTYPE,
|
||||
use_2cta_instrs=USE_2_CTA,
|
||||
mma_tiler_mn=(TILE_M, TILE_N),
|
||||
cluster_shape_mn=(CLUSTER_M, CLUSTER_N),
|
||||
tensormap_update_mode=TENSORMAP_UPDATE_MODE,
|
||||
)
|
||||
|
||||
gemm_executor = cute.compile(
|
||||
grouped_gemm,
|
||||
from_dlpack(input_a.unsqueeze(-1), assumed_align=16),
|
||||
from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16),
|
||||
from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16),
|
||||
G,
|
||||
probs,
|
||||
strides,
|
||||
ptrs,
|
||||
total_num_clusters,
|
||||
tensormap_workspace,
|
||||
max_active_clusters,
|
||||
stream,
|
||||
)
|
||||
|
||||
gemm_cache[gemm_cache_key] = gemm_executor
|
||||
|
||||
gemm_executor(
|
||||
from_dlpack(input_a.unsqueeze(-1), assumed_align=16),
|
||||
from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16),
|
||||
from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16),
|
||||
probs,
|
||||
strides,
|
||||
ptrs,
|
||||
tensormap_workspace,
|
||||
stream,
|
||||
)
|
||||
@@ -1,141 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import auto, Enum
|
||||
from itertools import product
|
||||
|
||||
import torch._inductor.config as config
|
||||
|
||||
|
||||
class TensorMapUpdateMode(Enum):
|
||||
"""Enum mirroring cutlass.utils.TensorMapUpdateMode to decouple this file from a cutlass dependency."""
|
||||
|
||||
SMEM = auto()
|
||||
GMEM = auto()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CuTeGemmConfig:
|
||||
TILE_M: int = 128
|
||||
TILE_N: int = 192
|
||||
CLUSTER_M: int = 2
|
||||
CLUSTER_N: int = 1
|
||||
USE_2_CTA: bool = False
|
||||
TENSORMAP_UPDATE_MODE: TensorMapUpdateMode = TensorMapUpdateMode.SMEM
|
||||
|
||||
|
||||
def get_exhaustive_groupgemm_configs() -> list[CuTeGemmConfig]:
|
||||
"""
|
||||
Returns the exhaustive configuration set for the Blackwell CuTeDSL Grouped GEMM kernel.
|
||||
For information regarding valid config sets, see:
|
||||
https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/grouped_gemm.py
|
||||
"""
|
||||
|
||||
# Tile_n is always the same regardless of 2cta
|
||||
tile_n_vals = [32, 64, 96, 128, 160, 192, 224, 256]
|
||||
|
||||
# Valid clusters
|
||||
clusters_no_2cta = [
|
||||
(1, 1),
|
||||
(1, 2),
|
||||
(1, 4),
|
||||
(1, 8),
|
||||
(1, 16),
|
||||
(2, 1),
|
||||
(2, 2),
|
||||
(2, 4),
|
||||
(2, 8),
|
||||
(4, 1),
|
||||
(4, 2),
|
||||
(4, 4),
|
||||
(8, 1),
|
||||
(8, 2),
|
||||
(16, 1),
|
||||
]
|
||||
clusters_2cta = [
|
||||
(2, 1),
|
||||
(2, 2),
|
||||
(2, 4),
|
||||
(2, 8),
|
||||
(4, 1),
|
||||
(4, 2),
|
||||
(4, 4),
|
||||
(8, 1),
|
||||
(8, 2),
|
||||
(16, 1),
|
||||
]
|
||||
|
||||
configs: list[CuTeGemmConfig] = []
|
||||
|
||||
for use_2cta, cluster_set, tile_m_range in [
|
||||
(False, clusters_no_2cta, [64, 128]),
|
||||
(True, clusters_2cta, [128, 256]),
|
||||
]:
|
||||
for tensormap_update_mode, tile_m, tile_n, (cluster_m, cluster_n) in product(
|
||||
[TensorMapUpdateMode.SMEM, TensorMapUpdateMode.GMEM],
|
||||
tile_m_range,
|
||||
tile_n_vals,
|
||||
cluster_set,
|
||||
):
|
||||
configs.append(
|
||||
CuTeGemmConfig(
|
||||
tile_m,
|
||||
tile_n,
|
||||
cluster_m,
|
||||
cluster_n,
|
||||
USE_2_CTA=use_2cta,
|
||||
TENSORMAP_UPDATE_MODE=tensormap_update_mode,
|
||||
)
|
||||
)
|
||||
|
||||
return configs
|
||||
|
||||
|
||||
def get_default_groupgemm_configs() -> list[CuTeGemmConfig]:
|
||||
"""
|
||||
Returns the default configuration set for the Blackwell CuTeDSL Grouped GEMM kernel.
|
||||
"""
|
||||
|
||||
config_tuples = [
|
||||
(128, 256, 2, 1, False, TensorMapUpdateMode.SMEM),
|
||||
(256, 160, 2, 1, True, TensorMapUpdateMode.GMEM),
|
||||
(256, 256, 2, 1, True, TensorMapUpdateMode.GMEM),
|
||||
(64, 32, 1, 1, False, TensorMapUpdateMode.GMEM),
|
||||
(64, 256, 1, 2, False, TensorMapUpdateMode.SMEM),
|
||||
(128, 256, 1, 2, False, TensorMapUpdateMode.SMEM),
|
||||
(256, 256, 2, 2, True, TensorMapUpdateMode.GMEM),
|
||||
(128, 256, 1, 2, False, TensorMapUpdateMode.GMEM),
|
||||
(64, 32, 1, 1, False, TensorMapUpdateMode.SMEM),
|
||||
(256, 256, 2, 1, True, TensorMapUpdateMode.SMEM),
|
||||
(128, 256, 1, 1, False, TensorMapUpdateMode.GMEM),
|
||||
(256, 256, 8, 1, True, TensorMapUpdateMode.GMEM),
|
||||
(64, 32, 1, 2, False, TensorMapUpdateMode.SMEM),
|
||||
(256, 192, 2, 1, True, TensorMapUpdateMode.GMEM),
|
||||
(256, 256, 2, 2, True, TensorMapUpdateMode.SMEM),
|
||||
(128, 96, 1, 2, False, TensorMapUpdateMode.SMEM),
|
||||
(64, 192, 1, 1, False, TensorMapUpdateMode.SMEM),
|
||||
(64, 64, 1, 1, False, TensorMapUpdateMode.GMEM),
|
||||
(64, 192, 1, 1, False, TensorMapUpdateMode.GMEM),
|
||||
(128, 64, 1, 1, False, TensorMapUpdateMode.GMEM),
|
||||
(64, 160, 1, 1, False, TensorMapUpdateMode.GMEM),
|
||||
(64, 256, 1, 1, False, TensorMapUpdateMode.GMEM),
|
||||
]
|
||||
|
||||
return [CuTeGemmConfig(*args) for args in config_tuples]
|
||||
|
||||
|
||||
def get_groupgemm_configs() -> list[CuTeGemmConfig]:
|
||||
"""
|
||||
Returns the configuration set for the Blackwell CuTeDSL Grouped GEMM kernel.
|
||||
|
||||
Note: CuTeDSL autotuning is still experimental — enabling it may trigger kernel launch failures
|
||||
or unstable results. By default, autotuning is disabled and we return only
|
||||
a single baseline config.
|
||||
"""
|
||||
if (
|
||||
config.cutedsl_enable_autotuning
|
||||
and config.max_autotune_gemm_search_space == "EXHAUSTIVE"
|
||||
):
|
||||
return get_exhaustive_groupgemm_configs()
|
||||
elif config.cutedsl_enable_autotuning:
|
||||
return get_default_groupgemm_configs()
|
||||
else:
|
||||
return [get_default_groupgemm_configs()[0]]
|
||||
@@ -1911,84 +1911,6 @@ def use_triton_blackwell_tma_template(
|
||||
return has_triton_tensor_descriptor_host_tma() and is_datacenter_blackwell_arch()
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def ensure_cute_available() -> bool:
|
||||
"""Check if CuTeDSL is importable; cache the result for reuse.
|
||||
|
||||
Call ensure_cute_available.cache_clear() after installing CuTeDSL
|
||||
in the same interpreter to retry the import.
|
||||
"""
|
||||
try:
|
||||
return importlib.util.find_spec("cutlass.cute") is not None
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
def use_blackwell_cutedsl_grouped_mm(
|
||||
mat_a: Any,
|
||||
mat_b: Any,
|
||||
layout: Layout,
|
||||
a_is_2d: bool,
|
||||
b_is_2d: bool,
|
||||
offs: Optional[Any],
|
||||
bias: Optional[Any],
|
||||
scale_result: Optional[Any],
|
||||
) -> bool:
|
||||
"""
|
||||
Returns True if we can use the blackwell kernel for grouped mm.
|
||||
Required conditions:
|
||||
1. CuTeDSL backend is enabled
|
||||
2. CuTeDSL is available
|
||||
3. We are on a blackwell arch
|
||||
4. The dtype is bf16
|
||||
5. Max autotune or max autotune gemm is enabled
|
||||
6. A, B, and the output are 16B aligned
|
||||
7. We are not using dynamic shapes
|
||||
8. A is 2d
|
||||
9. B is 3d
|
||||
10. Offsets are provided
|
||||
11. Bias and Scale are not provided
|
||||
"""
|
||||
if not ensure_cute_available():
|
||||
return False
|
||||
|
||||
if not _use_autotune_backend("CUTEDSL"):
|
||||
return False
|
||||
|
||||
from .codegen.cuda.cuda_env import is_datacenter_blackwell_arch
|
||||
|
||||
if not is_gpu(layout.device.type):
|
||||
return False
|
||||
|
||||
if not is_datacenter_blackwell_arch():
|
||||
return False
|
||||
|
||||
layout_dtypes = [torch.bfloat16]
|
||||
if not _use_template_for_gpu(layout, layout_dtypes):
|
||||
return False
|
||||
|
||||
if not (config.max_autotune or config.max_autotune_gemm):
|
||||
return False
|
||||
|
||||
# Checks for 16B ptr and stride alignment
|
||||
if not can_use_tma(mat_a, mat_b, output_layout=layout):
|
||||
return False
|
||||
|
||||
if any(is_dynamic(x) for x in [mat_a, mat_b]):
|
||||
return False
|
||||
|
||||
if not a_is_2d or b_is_2d:
|
||||
return False
|
||||
|
||||
if offs is None:
|
||||
return False
|
||||
|
||||
if bias is not None or scale_result is not None:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool:
|
||||
from .virtualized import V
|
||||
|
||||
|
||||
@@ -10,7 +10,6 @@ import torch
|
||||
from ._utils import _device_t, _get_device_index
|
||||
from .memory import (
|
||||
empty_cache,
|
||||
get_memory_info,
|
||||
max_memory_allocated,
|
||||
max_memory_reserved,
|
||||
memory_allocated,
|
||||
@@ -26,10 +25,9 @@ __all__ = [
|
||||
"current_device_idx", # deprecated
|
||||
"current_device_index",
|
||||
"current_stream",
|
||||
"empty_cache",
|
||||
"device_count",
|
||||
"device_index",
|
||||
"empty_cache",
|
||||
"get_memory_info",
|
||||
"is_available",
|
||||
"max_memory_allocated",
|
||||
"max_memory_reserved",
|
||||
|
||||
@@ -8,7 +8,6 @@ from ._utils import _device_t, _get_device_index
|
||||
|
||||
__all__ = [
|
||||
"empty_cache",
|
||||
"get_memory_info",
|
||||
"max_memory_allocated",
|
||||
"max_memory_reserved",
|
||||
"memory_allocated",
|
||||
@@ -88,9 +87,6 @@ def memory_stats(device_index: _device_t = None, /) -> OrderedDict[str, Any]:
|
||||
If not given, use :func:`torch.accelerator.current_device_index` by default.
|
||||
If a :class:`torch.device` or str is provided, its type must match the current
|
||||
:ref:`accelerator<accelerators>` device type.
|
||||
|
||||
Returns:
|
||||
OrderedDict[str, Any]: an ordered dictionary mapping statistic names to their values.
|
||||
"""
|
||||
if not torch._C._accelerator_isAllocatorInitialized():
|
||||
return OrderedDict()
|
||||
@@ -121,9 +117,6 @@ def memory_allocated(device_index: _device_t = None, /) -> int:
|
||||
If not given, use :func:`torch.accelerator.current_device_index` by default.
|
||||
If a :class:`torch.device` or str is provided, its type must match the current
|
||||
:ref:`accelerator<accelerators>` device type.
|
||||
|
||||
Returns:
|
||||
int: the current memory occupied by live tensors (in bytes) within the current process.
|
||||
"""
|
||||
return memory_stats(device_index).get("allocated_bytes.all.current", 0)
|
||||
|
||||
@@ -141,9 +134,6 @@ def max_memory_allocated(device_index: _device_t = None, /) -> int:
|
||||
If not given, use :func:`torch.accelerator.current_device_index` by default.
|
||||
If a :class:`torch.device` or str is provided, its type must match the current
|
||||
:ref:`accelerator<accelerators>` device type.
|
||||
|
||||
Returns:
|
||||
int: the peak memory occupied by live tensors (in bytes) within the current process.
|
||||
"""
|
||||
return memory_stats(device_index).get("allocated_bytes.all.peak", 0)
|
||||
|
||||
@@ -157,9 +147,6 @@ def memory_reserved(device_index: _device_t = None, /) -> int:
|
||||
If not given, use :func:`torch.accelerator.current_device_index` by default.
|
||||
If a :class:`torch.device` or str is provided, its type must match the current
|
||||
:ref:`accelerator<accelerators>` device type.
|
||||
|
||||
Returns:
|
||||
int: the current memory reserved by PyTorch (in bytes) within the current process.
|
||||
"""
|
||||
return memory_stats(device_index).get("reserved_bytes.all.current", 0)
|
||||
|
||||
@@ -177,9 +164,6 @@ def max_memory_reserved(device_index: _device_t = None, /) -> int:
|
||||
If not given, use :func:`torch.accelerator.current_device_index` by default.
|
||||
If a :class:`torch.device` or str is provided, its type must match the current
|
||||
:ref:`accelerator<accelerators>` device type.
|
||||
|
||||
Returns:
|
||||
int: the peak memory reserved by PyTorch (in bytes) within the current process.
|
||||
"""
|
||||
return memory_stats(device_index).get("reserved_bytes.all.peak", 0)
|
||||
|
||||
@@ -216,21 +200,3 @@ def reset_peak_memory_stats(device_index: _device_t = None, /) -> None:
|
||||
"""
|
||||
device_index = _get_device_index(device_index, optional=True)
|
||||
return torch._C._accelerator_resetPeakStats(device_index)
|
||||
|
||||
|
||||
def get_memory_info(device_index: _device_t = None, /) -> tuple[int, int]:
|
||||
r"""Return the current device memory information for a given device index.
|
||||
|
||||
Args:
|
||||
device_index (:class:`torch.device`, str, int, optional): the index of the device to target.
|
||||
If not given, use :func:`torch.accelerator.current_device_index` by default.
|
||||
If a :class:`torch.device` or str is provided, its type must match the current
|
||||
:ref:`accelerator<accelerators>` device type.
|
||||
|
||||
Returns:
|
||||
tuple[int, int]: a tuple of two integers (free_memory, total_memory) in bytes.
|
||||
The first value is the free memory on the device (available across all processes and applications),
|
||||
The second value is the device's total hardware memory capacity.
|
||||
"""
|
||||
device_index = _get_device_index(device_index, optional=True)
|
||||
return torch._C._accelerator_getMemoryInfo(device_index)
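# Minimal usage sketch of the API documented above (it mirrors the
# test_get_memory_info test removed from test_accelerator.py in this diff and
# assumes an accelerator backend that implements the allocator hook):
#
#     if torch.accelerator.is_available():
#         free_bytes, total_bytes = torch.accelerator.get_memory_info()
#         print(f"{free_bytes} bytes free of {total_bytes} bytes total")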
|
||||
|
||||
@@ -195,12 +195,10 @@ def get_new_attr_name_with_prefix(prefix: str) -> Callable:
|
||||
def collect_producer_nodes(node: Node) -> Optional[list[Node]]:
|
||||
r"""Starting from a target node, trace back until we hit input or
|
||||
getattr node. This is used to extract the chain of operators
|
||||
starting from getattr to the target node, for example::
|
||||
|
||||
def forward(self, x):
|
||||
observed = self.observer(self.weight)
|
||||
return F.linear(x, observed)
|
||||
|
||||
starting from getattr to the target node, for example
|
||||
def forward(self, x):
|
||||
observed = self.observer(self.weight)
|
||||
return F.linear(x, observed)
|
||||
collect_producer_nodes(observed) will either return a list of nodes that
|
||||
produces the observed node or None if we can't extract a self contained
|
||||
graph without free variables(inputs of the forward function).
|
||||
|
||||
@@ -138,13 +138,6 @@ void initModule(PyObject* module) {
|
||||
at::accelerator::resetPeakStats(device_index);
|
||||
});
|
||||
|
||||
m.def("_accelerator_getMemoryInfo", [](c10::DeviceIndex device_index) {
|
||||
const auto device_type = at::accelerator::getAccelerator(true).value();
|
||||
torch::utils::maybe_initialize_device(device_type);
|
||||
py::gil_scoped_release no_gil;
|
||||
return at::accelerator::getMemoryInfo(device_index);
|
||||
});
|
||||
|
||||
m.def("_accelerator_setAllocatorSettings", [](std::string env) {
|
||||
c10::CachingAllocator::setAllocatorSettings(env);
|
||||
});
|
||||
|
||||
@@ -386,8 +386,23 @@ static void bindGetDeviceProperties(PyObject* module) {
|
||||
static void initXpuMethodBindings(PyObject* module) {
|
||||
auto m = py::handle(module).cast<py::module>();
|
||||
m.def("_xpu_getMemoryInfo", [](c10::DeviceIndex device_index) {
|
||||
py::gil_scoped_release no_gil;
|
||||
return at::getDeviceAllocator(at::kXPU)->getMemoryInfo(device_index);
|
||||
#if SYCL_COMPILER_VERSION >= 20250000
|
||||
auto total = at::xpu::getDeviceProperties(device_index)->global_mem_size;
|
||||
auto& device = c10::xpu::get_raw_device(device_index);
|
||||
TORCH_CHECK(
|
||||
device.has(sycl::aspect::ext_intel_free_memory),
|
||||
"The device (",
|
||||
at::xpu::getDeviceProperties(device_index)->name,
|
||||
") doesn't support querying the available free memory. ",
|
||||
"You can file an issue at https://github.com/pytorch/pytorch/issues ",
|
||||
"to help us prioritize its implementation.");
|
||||
auto free = device.get_info<sycl::ext::intel::info::device::free_memory>();
|
||||
return std::make_tuple(free, total);
|
||||
#else
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false,
|
||||
"torch.xpu.mem_get_info requires PyTorch to be built with SYCL compiler version 2025.0.0 or newer.");
|
||||
#endif
|
||||
});
|
||||
m.def(
|
||||
"_xpu_getStreamFromExternal",
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import itertools
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, cast, NamedTuple, Optional
|
||||
@@ -8,7 +7,6 @@ import torch
|
||||
from torch.distributed.device_mesh import DeviceMesh
|
||||
from torch.distributed.tensor.placement_types import (
|
||||
_StridedShard,
|
||||
MaskPartial,
|
||||
Partial,
|
||||
Placement,
|
||||
Replicate,
|
||||
@@ -129,185 +127,6 @@ class DTensorSpec:
|
||||
)
|
||||
return default_shard_order
|
||||
|
||||
@staticmethod
|
||||
def _convert_shard_order_to_StridedShard(
|
||||
shard_order: ShardOrder, placements: tuple[Placement, ...], mesh: DeviceMesh
|
||||
) -> tuple[Placement, ...]:
|
||||
"""
|
||||
Convert ShardOrder to placements with _StridedShard.
|
||||
|
||||
This function converts a ShardOrder specification into a tuple of Placement objects,
|
||||
using _StridedShard when a tensor dimension is sharded across multiple mesh dimensions
|
||||
in a non-default order. The split_factor of each _StridedShard is determined by the
|
||||
product of mesh dimension sizes that appear earlier in the shard order but later in
|
||||
the placement tuple.
|
||||
|
||||
Args:
|
||||
shard_order: ShardOrder specification indicating which tensor dimensions are
|
||||
sharded on which mesh dimensions and in what execution order.
|
||||
placements: Tuple of Placement objects that does not contain _StridedShard.
|
||||
mesh: DeviceMesh containing the size information for each mesh dimension.
|
||||
|
||||
Returns:
|
||||
Updated tuple of Placement objects with Shard or _StridedShard placements.
|
||||
|
||||
Algorithm:
|
||||
For each ShardOrderEntry in shard_order:
|
||||
- For each mesh dimension in the entry's mesh_dims (in order):
|
||||
- Calculate split_factor as the product of mesh sizes for all mesh dimensions
|
||||
that appear:
|
||||
1. Earlier in the shard order (lower index in mesh_dims), and
|
||||
2. Later in the placement tuple (higher mesh dimension index)
|
||||
- If split_factor == 1: use normal Shard
|
||||
- Otherwise: use _StridedShard with the calculated split_factor
|
||||
|
||||
Example:
|
||||
>>> # xdoctest: +SKIP("Requires DeviceMesh")
|
||||
>>> # Tensor dimension 0 sharded on mesh dims [2, 0, 1] in that order
|
||||
>>> # mesh = DeviceMesh([4, 3, 2]) # sizes: mesh[0]=4, mesh[1]=3, mesh[2]=2
|
||||
>>> shard_order = (ShardOrderEntry(tensor_dim=0, mesh_dims=(2, 0, 1)),)
|
||||
>>> placements = (Shard(0), Shard(0), Shard(0))
|
||||
>>> # For mesh_dim=2 (index 0 in mesh_dims): no earlier dims, split_factor=1
|
||||
>>> # -> placements[2] = Shard(0)
|
||||
>>> # For mesh_dim=0 (index 1 in mesh_dims): mesh_dim=2 is earlier and has index 2>0
|
||||
>>> # -> split_factor = mesh.size(2) = 2
|
||||
>>> # -> placements[0] = _StridedShard(0, split_factor=2)
|
||||
>>> # For mesh_dim=1 (index 2 in mesh_dims): mesh_dim=2 is earlier and has index 2>1
|
||||
>>> # -> split_factor = mesh.size(2) = 2
|
||||
>>> # -> placements[1] = _StridedShard(0, split_factor=2)
|
||||
>>> # Result: (_StridedShard(0, sf=2), _StridedShard(0, sf=2), Shard(0))
|
||||
"""
|
||||
placements_list = list(placements)
|
||||
for entry in shard_order:
|
||||
tensor_dim = entry.tensor_dim
|
||||
mesh_dims = entry.mesh_dims
|
||||
for idx in range(len(mesh_dims)):
|
||||
# TODO(zpcore): split_factor from `view` and `shard order`
|
||||
# should be able to be multiplied into one. Need to loosen the
|
||||
# condition here.
|
||||
mesh_dim = mesh_dims[idx]
|
||||
if type(placements[mesh_dim]) is not Shard:
|
||||
raise ValueError(
|
||||
f"Only Shard placement can be converted to _StridedShard, "
|
||||
f"found {placements[mesh_dim]} in {placements=}."
|
||||
)
|
||||
split_factor = math.prod(
|
||||
mesh.size(i) for i in mesh_dims[:idx] if i > mesh_dim
|
||||
)
|
||||
if split_factor == 1:
|
||||
# use normal Shard
|
||||
placements_list[mesh_dim] = Shard(tensor_dim)
|
||||
else:
|
||||
placements_list[mesh_dim] = _StridedShard(
|
||||
tensor_dim, split_factor=split_factor
|
||||
)
|
||||
return tuple(placements_list)
|
||||
|
||||
@staticmethod
|
||||
def _maybe_convert_StridedShard_to_shard_order(
|
||||
placements: tuple[Placement, ...], mesh: DeviceMesh
|
||||
) -> Optional[ShardOrder]:
|
||||
"""
|
||||
Try to convert _StridedShard placements to ShardOrder.
|
||||
|
||||
This is the inverse of `_convert_shard_order_to_StridedShard`. It reconstructs the shard
|
||||
order by examining the split_factor of each _StridedShard and determining its position
|
||||
in the execution order. If the _StridedShard configuration cannot be represented as a
|
||||
valid ShardOrder (i.e., there's no shard order that produces the observed split_factors),
|
||||
this function returns None.
|
||||
|
||||
Args:
|
||||
placements: Tuple of Placement objects that may contain _StridedShard.
|
||||
mesh: DeviceMesh containing the size information for each mesh dimension.
|
||||
|
||||
Returns:
|
||||
ShardOrder if conversion is possible, None otherwise. For placements without
|
||||
_StridedShard, returns the default shard order.
|
||||
|
||||
Algorithm:
|
||||
1. If no _StridedShard in placements, return default shard order
|
||||
2. Create an empty list for each tensor dimension to represent mesh dim ordering
|
||||
3. Iterate through placements in reverse order (right to left):
|
||||
- For each Shard/_StridedShard on a tensor dimension:
|
||||
- Extract its split_factor (1 for Shard, split_factor for _StridedShard)
|
||||
- Find the position in mesh_dims_order where accumulated_sf equals split_factor
|
||||
- accumulated_sf is the product of mesh sizes of mesh dimensions that appear
|
||||
earlier in mesh_dims_order (lower indices)
|
||||
- Insert mesh_dim at the found position
|
||||
4. If no valid position found for any split_factor, return None (unable to convert)
|
||||
5. Construct ShardOrderEntry for each tensor dimension from mesh_dims_order
|
||||
|
||||
Example:
|
||||
>>> # xdoctest: +SKIP("Requires DeviceMesh")
|
||||
>>> # mesh = DeviceMesh([4, 3, 2]) # sizes: mesh[0]=4, mesh[1]=3, mesh[2]=2
|
||||
>>> # placements = (_StridedShard(0, sf=2), _StridedShard(0, sf=2), Shard(0))
|
||||
>>> # Process tensor_dim=0 from right to left:
|
||||
>>> # - mesh_dim=2: Shard(0) with sf=1
|
||||
>>> # Try position 0: accumulated_sf=1, matches! Insert at position 0
|
||||
>>> # Current mesh_dims_order order: [2]
|
||||
>>> # - mesh_dim=1: _StridedShard(0, sf=2) with sf=2
|
||||
>>> # Try position 0: accumulated_sf=1, no match
|
||||
>>> # Try position 1: accumulated_sf=1*mesh.size(2)=2, matches! Insert at position 1
|
||||
>>> # Current mesh_dims_order order: [2, 1]
|
||||
>>> # - mesh_dim=0: _StridedShard(0, sf=2) with sf=2
|
||||
>>> # Try position 0: accumulated_sf=1, no match
|
||||
>>> # Try position 1: accumulated_sf=1*mesh.size(2)=2, matches! Insert at position 1
|
||||
>>> # Final mesh_dims_order order: [2, 0, 1]
|
||||
>>> # Result: ShardOrder((ShardOrderEntry(tensor_dim=0, mesh_dims=(2, 0, 1)),))
|
||||
>>> # This means: first shard on mesh_dim=2, then mesh_dim=0, then mesh_dim=1
|
||||
|
||||
Note:
|
||||
This function validates that _StridedShard can be represented as a ShardOrder.
|
||||
Not all _StridedShard configurations are valid - the split_factor must match
|
||||
the product of mesh sizes in some execution order.
|
||||
"""
|
||||
if not any(isinstance(p, _StridedShard) for p in placements):
|
||||
return DTensorSpec.compute_default_shard_order(placements)
|
||||
max_tensor_dim = (
|
||||
max([i.dim for i in placements if isinstance(i, Shard | _StridedShard)]) + 1
|
||||
)
|
||||
shard_order = []
|
||||
|
||||
tensor_dim_to_mesh_dims_order: list[list[int]] = [
|
||||
[] for i in range(max_tensor_dim)
|
||||
]
|
||||
for mesh_dim in reversed(range(len(placements))):
|
||||
cur_placement = placements[mesh_dim]
|
||||
# _StridedShard may not be a subclass of Shard in the future, so check for both types explicitly:
|
||||
if isinstance(cur_placement, Shard | _StridedShard):
|
||||
tensor_dim = cur_placement.dim
|
||||
mesh_dims_order = tensor_dim_to_mesh_dims_order[tensor_dim]
|
||||
cur_sf = 1
|
||||
if isinstance(cur_placement, _StridedShard):
|
||||
cur_sf = cur_placement.split_factor
|
||||
accumulated_sf = 1
|
||||
find_order = False
|
||||
for i in range(len(mesh_dims_order) + 1):
|
||||
if accumulated_sf == cur_sf:
|
||||
mesh_dims_order.insert(i, mesh_dim)
|
||||
find_order = True
|
||||
break
|
||||
if i < len(mesh_dims_order):
|
||||
accumulated_sf *= mesh.size(mesh_dims_order[i])
|
||||
if not find_order:
|
||||
# _StridedShard is not convertible to ShardOrder
|
||||
return None
|
||||
else:
|
||||
if not isinstance(cur_placement, Replicate | Partial | MaskPartial):
|
||||
raise ValueError(
|
||||
f"Unsupported placement type {type(cur_placement)} encountered in "
|
||||
f"{placements}; expected Replicate, Partial, or MaskPartial."
|
||||
)
|
||||
for tensor_dim in range(max_tensor_dim):
|
||||
if len(tensor_dim_to_mesh_dims_order[tensor_dim]) > 0:
|
||||
shard_order.append(
|
||||
ShardOrderEntry(
|
||||
tensor_dim=tensor_dim,
|
||||
mesh_dims=tuple(tensor_dim_to_mesh_dims_order[tensor_dim]),
|
||||
)
|
||||
)
|
||||
return tuple(shard_order)
|
||||
|
||||
def _verify_shard_order(self, shard_order: ShardOrder) -> None:
|
||||
"""Verify that the shard_order is valid and matches the placements."""
|
||||
total_shard = 0
|
||||
|
||||
@@ -11,10 +11,7 @@ import torch.distributed.tensor._api as dtensor
|
||||
aten = torch.ops.aten
|
||||
|
||||
|
||||
def _requires_data_exchange(padding, dim_map) -> bool:
|
||||
# Data exchange is not needed if the tensor is only sharded across the batch dim
|
||||
if all(x == -1 for x in dim_map[1:]):
|
||||
return False
|
||||
def _requires_data_exchange(padding):
|
||||
# TODO: whether there requires data exchange is currently determined by padding
|
||||
return padding[-1] != 0
|
||||
|
||||
@@ -110,7 +107,6 @@ def tp_convolution(
|
||||
op_call: torch._ops.OpOverload,
|
||||
local_tensor_args: tuple[object, ...],
|
||||
local_tensor_kwargs: dict[str, object],
|
||||
dim_map: list[int],
|
||||
) -> object:
|
||||
assert op_call == aten.convolution.default
|
||||
assert len(local_tensor_args) == 9
|
||||
@@ -124,7 +120,7 @@ def tp_convolution(
|
||||
assert _is_supported(in_tensor.shape, weight.shape, stride, padding, dilation)
|
||||
assert isinstance(padding, list)
|
||||
|
||||
if not _requires_data_exchange(padding, dim_map):
|
||||
if not _requires_data_exchange(padding):
|
||||
local_results = op_call(*local_tensor_args, **local_tensor_kwargs)
|
||||
return local_results
|
||||
else:
|
||||
@@ -164,7 +160,6 @@ def tp_convolution_backward(
|
||||
op_call: torch._ops.OpOverload,
|
||||
local_tensor_args: tuple[object, ...],
|
||||
local_tensor_kwargs: dict[str, object],
|
||||
dim_map: list[int],
|
||||
) -> object:
|
||||
assert op_call == aten.convolution_backward.default
|
||||
assert len(local_tensor_args) == 11
|
||||
@@ -179,7 +174,7 @@ def tp_convolution_backward(
|
||||
assert _is_supported(in_tensor.shape, weight.shape, stride, padding, dilation)
|
||||
assert isinstance(padding, list)
|
||||
|
||||
if not _requires_data_exchange(padding, dim_map):
|
||||
if not _requires_data_exchange(padding):
|
||||
local_results = op_call(*local_tensor_args, **local_tensor_kwargs)
|
||||
return local_results
|
||||
else:
|
||||
@@ -244,18 +239,15 @@ def convolution_handler(
|
||||
dtensor.DTensor._op_dispatcher.sharding_propagator.propagate(op_info)
|
||||
output_sharding = op_info.output_sharding
|
||||
assert output_sharding is not None, "output sharding should not be None"
|
||||
output_spec = output_sharding.output_spec
|
||||
assert isinstance(output_spec, dtensor.DTensorSpec)
|
||||
|
||||
# local propagation
|
||||
local_results = tp_convolution(
|
||||
op_call,
|
||||
tuple(op_info.local_args),
|
||||
op_info.local_kwargs,
|
||||
output_spec.dim_map,
|
||||
op_call, tuple(op_info.local_args), op_info.local_kwargs
|
||||
)
|
||||
|
||||
return dtensor.DTensor._op_dispatcher.wrap(local_results, output_spec)
|
||||
return dtensor.DTensor._op_dispatcher.wrap(
|
||||
local_results, output_sharding.output_spec
|
||||
)
|
||||
|
||||
|
||||
def convolution_backward_handler(
|
||||
@@ -278,14 +270,10 @@ def convolution_backward_handler(
|
||||
dtensor.DTensor._op_dispatcher.sharding_propagator.propagate(op_info)
|
||||
output_sharding = op_info.output_sharding
|
||||
assert output_sharding is not None, "output sharding should not be None"
|
||||
assert isinstance(op_info.flat_args_schema[0], dtensor.DTensorSpec)
|
||||
|
||||
# local propagation
|
||||
local_results = tp_convolution_backward(
|
||||
op_call,
|
||||
tuple(op_info.local_args),
|
||||
op_info.local_kwargs,
|
||||
op_info.flat_args_schema[0].dim_map,
|
||||
op_call, tuple(op_info.local_args), op_info.local_kwargs
|
||||
)
|
||||
|
||||
return dtensor.DTensor._op_dispatcher.wrap(
|
||||
|
||||
@@ -20320,7 +20320,6 @@ op_db: list[OpInfo] = [
|
||||
torch.float32: 1e-4}),),
|
||||
dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
|
||||
dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
|
||||
supports_sparse=True,
|
||||
supports_sparse_csr=True,
|
||||
supports_sparse_csc=True,
|
||||
supports_sparse_bsr=True,
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
|
||||
import contextlib
|
||||
import copy
|
||||
import functools
|
||||
import itertools
|
||||
import sys
|
||||
@@ -33,8 +32,6 @@ from torch.distributed.tensor import (
|
||||
Replicate,
|
||||
Shard,
|
||||
)
|
||||
from torch.distributed.tensor._dtensor_spec import ShardOrderEntry
|
||||
from torch.distributed.tensor._redistribute import redistribute_local_tensor
|
||||
from torch.distributed.tensor.parallel import (
|
||||
ColwiseParallel,
|
||||
parallelize_module,
|
||||
@@ -821,125 +818,3 @@ def map_local_for_rank(rank, func):
|
||||
|
||||
def reduce_local_int(val, func):
|
||||
return func(val.node._local_ints)
|
||||
|
||||
|
||||
def _convert_shard_order_dict_to_ShardOrder(shard_order):
|
||||
"""Convert shard_order dict to ShardOrder"""
|
||||
return tuple(
|
||||
ShardOrderEntry(tensor_dim=tensor_dim, mesh_dims=tuple(mesh_dims))
|
||||
for tensor_dim, mesh_dims in shard_order.items()
|
||||
)
|
||||
|
||||
|
||||
# TODO(zpcore): remove once the native redistribute supports shard_order arg
|
||||
def redistribute(
|
||||
dtensor_input,
|
||||
device_mesh,
|
||||
placements,
|
||||
shard_order,
|
||||
use_graph_based_transform=True,
|
||||
):
|
||||
"""
|
||||
wrapper function to support shard_order for redistribution
|
||||
This is a simpler version of Redistribute, only considers the forward.
|
||||
"""
|
||||
if placements is None:
|
||||
placements = shard_order_to_placement(shard_order, device_mesh)
|
||||
placements = tuple(placements)
|
||||
old_spec = dtensor_input._spec
|
||||
new_spec = copy.deepcopy(old_spec)
|
||||
new_spec.placements = placements
|
||||
if shard_order is not None:
|
||||
new_spec.shard_order = shard_order
|
||||
else:
|
||||
new_spec.shard_order = ()
|
||||
if old_spec == new_spec:
|
||||
return dtensor_input
|
||||
dtensor_input = DTensor.from_local(
|
||||
redistribute_local_tensor(
|
||||
dtensor_input.to_local(),
|
||||
old_spec,
|
||||
new_spec,
|
||||
use_graph_based_transform=use_graph_based_transform,
|
||||
),
|
||||
device_mesh,
|
||||
)
|
||||
dtensor_input._spec = copy.deepcopy(new_spec)
|
||||
return dtensor_input # returns DTensor
|
||||
|
||||
|
||||
# TODO(zpcore): remove once the native distribute_tensor supports
|
||||
# shard_order arg
|
||||
def patched_distribute_tensor(
|
||||
input_tensor,
|
||||
device_mesh,
|
||||
placements,
|
||||
shard_order,
|
||||
use_graph_based_transform=True,
|
||||
):
|
||||
"""wrapper function to support shard_order for tensor distribution"""
|
||||
if placements is None:
|
||||
placements = shard_order_to_placement(shard_order, device_mesh)
|
||||
placements = tuple(placements)
|
||||
tensor_dt = distribute_tensor(input_tensor, device_mesh, placements)
|
||||
# fix the shard order
|
||||
return redistribute(
|
||||
tensor_dt, device_mesh, placements, shard_order, use_graph_based_transform
|
||||
)
|
||||
|
||||
|
||||
# TODO(zpcore): remove once the native redistribute supports shard_order arg
|
||||
def make_full_tensor(dtensor_input):
|
||||
"""wrapper function to support DTensor.full_tensor"""
|
||||
return redistribute(
|
||||
dtensor_input, dtensor_input.device_mesh, placements=None, shard_order=()
|
||||
).to_local()
|
||||
|
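The three helpers above wrap the public DTensor API to add a shard_order argument. For orientation, a minimal single-process sketch of the public calls they build on (distribute_tensor, DTensor.redistribute, full_tensor) is shown below; the gloo backend, world size of 1, addresses, and mesh shape are illustrative assumptions, not part of the test file.

import os
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, Shard, distribute_tensor

# Single-process setup purely for illustration.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

mesh = init_device_mesh("cpu", (1,))
t = torch.randn(4, 4)

# distribute_tensor: shard tensor dim 0 across the (single) mesh dim.
dt = distribute_tensor(t, mesh, [Shard(0)])

# redistribute: move to a replicated layout; the wrappers above additionally
# control the order in which mesh dims shard a tensor dim (shard_order).
dt = dt.redistribute(mesh, [Replicate()])

# full_tensor: gather back to a plain local tensor, which is what
# make_full_tensor emulates via replicate-everything + to_local().
full = dt.full_tensor()

dist.destroy_process_group()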

def shard_order_to_placement(shard_order, mesh):
"""convert shard_order to placement with only Replicate() and Shard()"""
placements: list[Any] = [Replicate() for _ in range(mesh.ndim)]
if shard_order is not None:
for entry in shard_order:
tensor_dim = entry.tensor_dim
mesh_dims = entry.mesh_dims
for mesh_dim in mesh_dims:
placements[mesh_dim] = Shard(tensor_dim)
return tuple(placements)
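shard_order_to_placement only needs the mesh for its ndim, so the same mapping can be exercised standalone. A small sketch follows, with a plain namedtuple standing in for the internal ShardOrderEntry and an integer mesh_ndim in place of the mesh (both illustrative assumptions):

from collections import namedtuple

from torch.distributed.tensor import Replicate, Shard

# Stand-in for the internal ShardOrderEntry used above (same two fields).
ShardOrderEntry = namedtuple("ShardOrderEntry", ["tensor_dim", "mesh_dims"])


def shard_order_to_placement_demo(shard_order, mesh_ndim):
    # One placement per mesh dim, Replicate() by default; every mesh dim named
    # by an entry becomes Shard(tensor_dim) for that entry's tensor dim.
    placements = [Replicate() for _ in range(mesh_ndim)]
    for entry in shard_order:
        for mesh_dim in entry.mesh_dims:
            placements[mesh_dim] = Shard(entry.tensor_dim)
    return tuple(placements)


# Tensor dim 0 sharded over mesh dims (1, 0): both mesh dims become Shard(0).
order = (ShardOrderEntry(tensor_dim=0, mesh_dims=(1, 0)),)
print(shard_order_to_placement_demo(order, mesh_ndim=2))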

def generate_shard_orders(mesh, tensor_rank):
# Generate all possible sharding placements of a tensor with rank
# `tensor_rank` over the mesh.
def _split_list(lst: list, N: int):
def compositions(n: int, k: int):
# yields lists of length k, positive ints summing to n
for cuts in itertools.combinations(range(1, n), k - 1):
# add 0 and n as sentinels, then take consecutive differences
yield [b - a for a, b in itertools.pairwise((0, *cuts, n))]

length = len(lst)
for comp in compositions(length, N):
result = []
start = 0
for size in comp:
result.append(lst[start : start + size])
start += size
yield result

all_mesh = list(range(mesh.ndim))
all_device_order = list(itertools.permutations(all_mesh))
for device_order in all_device_order:
# split the device order, and assign each device-order segment to a tensor dim
for num_split in range(1, mesh.ndim + 1):
for splitted_list in _split_list(list(range(mesh.ndim)), num_split):
for tensor_dims in itertools.combinations(
range(tensor_rank), len(splitted_list)
):
shard_order = {}
assert len(tensor_dims) == len(splitted_list)
for tensor_dim, mesh_dims in zip(tensor_dims, splitted_list):
shard_order[tensor_dim] = device_order[
mesh_dims[0] : mesh_dims[-1] + 1
]
yield _convert_shard_order_dict_to_ShardOrder(shard_order)
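generate_shard_orders enumerates device orders and splits each one into contiguous groups via the nested compositions helper. A standalone toy run of that splitting idea, using only itertools (the function below mirrors the inner helper rather than importing it):

import itertools


def compositions(n, k):
    # Yield lists of k positive ints summing to n, i.e. the ways to cut a
    # length-n sequence into k non-empty contiguous pieces.
    for cuts in itertools.combinations(range(1, n), k - 1):
        yield [b - a for a, b in itertools.pairwise((0, *cuts, n))]


# Split a 3-dim device order (m0, m1, m2) into 2 contiguous groups:
print(list(compositions(3, 2)))  # [[1, 2], [2, 1]]
# i.e. (m0,)+(m1, m2) and (m0, m1)+(m2,); generate_shard_orders then assigns
# each group of mesh dims to one tensor dim.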
@ -529,14 +529,11 @@ RE_EXTERN_SHARED = re.compile(r"extern\s+([\w\(\)]+)?\s*__shared__\s+([\w:<>\s]+


def replace_extern_shared(input_string):
"""
Match 'extern __shared__ type foo[];' syntax and use HIP_DYNAMIC_SHARED() MACRO instead.
See: https://github.com/ROCm/hip/blob/master/docs/markdown/hip_kernel_language.md#__shared__
Examples:
"extern __shared__ char smemChar[];"
=> "HIP_DYNAMIC_SHARED( char, smemChar)"
"extern __shared__ unsigned char smem[];"
=> "HIP_DYNAMIC_SHARED( unsigned char, my_smem)"
"""Match extern __shared__ type foo[]; syntax and use HIP_DYNAMIC_SHARED() MACRO instead.
https://github.com/ROCm/hip/blob/master/docs/markdown/hip_kernel_language.md#__shared__
Example:
"extern __shared__ char smemChar[];" => "HIP_DYNAMIC_SHARED( char, smemChar)"
"extern __shared__ unsigned char smem[];" => "HIP_DYNAMIC_SHARED( unsigned char, my_smem)"
"""
output_string = input_string
output_string = RE_EXTERN_SHARED.sub(
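The two docstring variants above describe the same rewrite. A standalone sketch of that substitution follows; the regex is deliberately simplified and is not the (truncated) RE_EXTERN_SHARED pattern shown in the hunk header:

import re

# Simplified pattern: __shared__, a (possibly multi-word) type, a name, and '[];'.
SIMPLE_EXTERN_SHARED = re.compile(
    r"extern\s+__shared__\s+([\w\s]+?)\s+(\w+)\s*\[\s*\]\s*;"
)


def replace_extern_shared_demo(source):
    # Rewrite 'extern __shared__ T name[];' as 'HIP_DYNAMIC_SHARED(T, name);'.
    return SIMPLE_EXTERN_SHARED.sub(r"HIP_DYNAMIC_SHARED(\1, \2);", source)


print(replace_extern_shared_demo("extern __shared__ char smemChar[];"))
# HIP_DYNAMIC_SHARED(char, smemChar);
print(replace_extern_shared_demo("extern __shared__ unsigned char smem[];"))
# HIP_DYNAMIC_SHARED(unsigned char, smem);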
@ -1046,17 +1043,14 @@ RE_INCLUDE = re.compile(r"#include .*\n")


def extract_arguments(start, string):
"""
Return the list of arguments in the upcoming function parameter closure.
Example:
""" Return the list of arguments in the upcoming function parameter closure.
Example:
string (input): '(blocks, threads, 0, THCState_getCurrentStream(state))'
arguments (output):
[
{'start': 1, 'end': 7},
{'start': 8, 'end': 16},
{'start': 17, 'end': 19},
{'start': 20, 'end': 53}
]
'[{'start': 1, 'end': 7},
{'start': 8, 'end': 16},
{'start': 17, 'end': 19},
{'start': 20, 'end': 53}]'
"""

arguments = []
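The docstring's input/output pair can be reproduced with a simple parenthesis-depth scan. The sketch below is illustrative and is not the hipify implementation; it only handles nested parentheses:

def extract_arguments_demo(start, string):
    # Walk the closure that begins at 'start' (the index of '('), splitting on
    # top-level commas and recording start/end offsets of each argument.
    arguments = []
    depth = 0
    arg_start = start + 1
    for pos in range(start, len(string)):
        ch = string[pos]
        if ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
            if depth == 0:
                arguments.append({"start": arg_start, "end": pos})
                break
        elif ch == "," and depth == 1:
            arguments.append({"start": arg_start, "end": pos})
            arg_start = pos + 1
    return arguments


print(extract_arguments_demo(0, "(blocks, threads, 0, THCState_getCurrentStream(state))"))
# [{'start': 1, 'end': 7}, {'start': 8, 'end': 16},
#  {'start': 17, 'end': 19}, {'start': 20, 'end': 53}]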
@ -190,7 +190,6 @@ def mem_get_info(device: _device_t = None) -> tuple[int, int]:
int: the memory available on the device in units of bytes.
int: the total memory on the device in units of bytes
"""
_lazy_init()
device = _get_device_index(device, optional=True)
return torch._C._xpu_getMemoryInfo(device)
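For reference, the documented return values can be exercised through the public API. The guard below assumes a PyTorch build with XPU support and at least one visible device; the device index 0 is illustrative:

import torch

if torch.xpu.is_available():
    # Returns (free, total) device memory in bytes for the given device.
    free, total = torch.xpu.mem_get_info(0)
    print(f"free={free / 1e9:.2f} GB, total={total / 1e9:.2f} GB")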