Compare commits

..

6 Commits

Author SHA1 Message Date
ebf9f17e32 Update base for Update on "[Inductor] Prevent large fusion for aten::cat"
Same context as in https://github.com/pytorch/pytorch/pull/166275 :
MTIA triton currently has a limit that it can't support the cases when there are too many input/output buffers. 

This PR adds the limitation to prevent large fusion with many input/output buffer, specifically for aten::cat operator.


in addition to D84571005

Differential Revision: [D84569512](https://our.internmc.facebook.com/intern/diff/D84569512/)

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-11-10 16:45:14 -08:00
68298c8473 Update base for Update on "[Inductor] Prevent large fusion for aten::cat"
Same context as in https://github.com/pytorch/pytorch/pull/166275 :
MTIA triton currently has a limit that it can't support the cases when there are too many input/output buffers. 

This PR adds the limitation to prevent large fusion with many input/output buffer, specifically for aten::cat operator.


in addition to D84571005

Differential Revision: [D84569512](https://our.internmc.facebook.com/intern/diff/D84569512/)

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-11-10 12:06:45 -08:00
e9860da56e Update base for Update on "[Inductor] Prevent large fusion for aten::cat"
Same context as in https://github.com/pytorch/pytorch/pull/166275 :
MTIA triton currently has a limit that it can't support the cases when there are too many input/output buffers. 

This PR adds the limitation to prevent large fusion with many input/output buffer, specifically for aten::cat operator.


in addition to D84571005

Differential Revision: [D84569512](https://our.internmc.facebook.com/intern/diff/D84569512/)

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-10-30 10:46:52 -07:00
023c7f344e Update base for Update on "[Inductor] Prevent large fusion for aten::cat"
Same context as in https://github.com/pytorch/pytorch/pull/166275 :
MTIA triton currently has a limit that it can't support the cases when there are too many input/output buffers. 

This PR adds the limitation to prevent large fusion with many input/output buffer, specifically for aten::cat operator.


in addition to D84571005

Differential Revision: [D84569512](https://our.internmc.facebook.com/intern/diff/D84569512/)

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-10-27 15:56:31 -07:00
ff752966ad [Inductor] Prevent kernel fusion with too many unique inputs and outputs
Differential Revision: [D85509351](https://our.internmc.facebook.com/intern/diff/D85509351/)

[ghstack-poisoned]
2025-10-26 15:08:20 -07:00
8e29b50854 [Inductor] Make combo kernel MAX_NUM_ARGS configurable
Differential Revision: [D85509352](https://our.internmc.facebook.com/intern/diff/D85509352/)

[ghstack-poisoned]
2025-10-26 15:08:17 -07:00
23 changed files with 968 additions and 1053 deletions

View File

@ -1 +1 @@
e4d25697f9dc5eedaf8f0a5bf085c62c5455a53a
c8b09f5f77d6bf6fb7ed7a9aa83e5d8156b3a5e9

View File

@ -4292,7 +4292,6 @@
dispatch:
SparseCPU: sparse_sparse_matmul_cpu
SparseCUDA: sparse_sparse_matmul_cuda
SparseMPS: sparse_sparse_matmul_mps
autogen: _sparse_sparse_matmul.out
- func: mode(Tensor self, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices)
@ -9833,7 +9832,7 @@
structured_delegate: erfinv.out
variants: method, function
dispatch:
SparseCPU, SparseCUDA, SparseMPS: erfinv_sparse
SparseCPU, SparseCUDA: erfinv_sparse
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: erfinv_sparse_csr
tags: pointwise
@ -9842,7 +9841,7 @@
structured_delegate: erfinv.out
variants: method
dispatch:
SparseCPU, SparseCUDA, SparseMPS: erfinv_sparse_
SparseCPU, SparseCUDA: erfinv_sparse_
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: erfinv_sparse_csr_
tags: pointwise
@ -9852,7 +9851,7 @@
structured_inherits: TensorIteratorBase
dispatch:
CPU, CUDA, MPS: erfinv_out
SparseCPU, SparseCUDA, SparseMPS: erfinv_sparse_out
SparseCPU, SparseCUDA: erfinv_sparse_out
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: erfinv_sparse_csr_out
tags: pointwise

View File

@ -10,10 +10,6 @@
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_coalesce_native.h>
#include <ATen/ops/repeat_interleave_native.h>
#include <ATen/ops/cumsum.h>
#include <ATen/ops/_sparse_sparse_matmul_native.h>
#include <ATen/ops/_sparse_coo_tensor_unsafe.h>
#include <ATen/ops/_sparse_coo_tensor_unsafe_native.h>
#include <ATen/ops/cat.h>
#include <ATen/ops/add_native.h>
@ -892,114 +888,5 @@ static void sparse_mask_intersection_out_mps_kernel(
/*coalesce_mask=*/false);
}
Tensor sparse_sparse_matmul_mps(const Tensor& mat1_, const Tensor& mat2_) {
TORCH_CHECK(mat1_.is_sparse() && mat2_.is_sparse(),
"sparse_sparse_matmul_mps: both inputs must be sparse COO tensors");
TORCH_CHECK(mat1_.is_mps() && mat2_.is_mps(),
"sparse_sparse_matmul_mps: both inputs must be on MPS device");
TORCH_CHECK(mat1_.dim() == 2 && mat2_.dim() == 2,
"sparse_sparse_matmul_mps: both inputs must be 2D matrices");
TORCH_CHECK(mat1_.dense_dim() == 0 && mat2_.dense_dim() == 0,
"sparse_sparse_matmul_mps: only scalar values supported (dense_dim == 0)");
TORCH_CHECK(mat1_.size(1) == mat2_.size(0),
"mat1 and mat2 shapes cannot be multiplied (", mat1_.size(0), "x", mat1_.size(1), " and ", mat2_.size(0), "x", mat2_.size(1), ")");
TORCH_CHECK(mat1_.scalar_type() == mat2_.scalar_type(),
"sparse_sparse_matmul_mps: mat1 dtype ", mat1_.scalar_type(),
" does not match mat2 dtype ", mat2_.scalar_type());
const auto device = mat1_.device();
auto A = mat1_.coalesce();
auto B = mat2_.coalesce();
const auto I = A.size(0);
const auto K = A.size(1);
const auto N = B.size(1);
const auto nnzA = A._nnz();
const auto nnzB = B._nnz();
// Early empty result, return an empty, coalesced tensor
if (I == 0 || N == 0 || K == 0 || nnzA == 0 || nnzB == 0) {
auto empty_idx = at::empty({2, 0}, at::device(device).dtype(at::kLong));
auto empty_val = at::empty({0}, at::device(device).dtype(mat1_.scalar_type()));
auto out = _sparse_coo_tensor_unsafe(empty_idx, empty_val, {I, N}, mat1_.options());
out._coalesced_(true);
return out;
}
const auto computeDtype = at::result_type(mat1_, mat2_);
auto A_idx = A._indices().contiguous();
auto A_val = A._values().to(computeDtype).contiguous();
auto A_i = A_idx.select(0, 0).contiguous();
auto A_k = A_idx.select(0, 1).contiguous();
auto B_idx = B._indices().contiguous();
auto B_val = B._values().to(computeDtype).contiguous();
auto B_k = B_idx.select(0, 0).contiguous();
auto B_j = B_idx.select(0, 1).contiguous();
// csr-style row pointers for B by k (the shared dimension)
Tensor row_ptr_B;
{
auto batch_ptr = at::tensor({0LL, nnzB}, at::device(device).dtype(at::kLong));
row_ptr_B = at::empty({K + 1}, at::device(device).dtype(at::kLong));
build_row_ptr_per_batch_mps(B_k, batch_ptr, /*B=*/1, /*I=*/K, row_ptr_B);
}
auto row_ptr_B_lo = row_ptr_B.narrow(0, 0, K);
auto row_ptr_B_hi = row_ptr_B.narrow(0, 1, K);
auto deg_B = row_ptr_B_hi.sub(row_ptr_B_lo);
auto counts = deg_B.index_select(0, A_k);
const int64_t P = counts.sum().item<int64_t>();
if (P == 0) {
auto empty_idx = at::empty({2, 0}, at::device(device).dtype(at::kLong));
auto empty_val = at::empty({0}, at::device(device).dtype(mat1_.scalar_type()));
auto out = _sparse_coo_tensor_unsafe(empty_idx, empty_val, {I, N}, mat1_.options());
out._coalesced_(true);
return out;
}
auto group_ids = repeat_interleave_mps(counts);
// exclusive cumsum of counts
auto offsets = cumsum(counts, /*dim=*/0).sub(counts);
auto offsets_gather = offsets.index_select(0, group_ids);
auto within = at::arange(P, at::device(device).dtype(at::kLong)).sub(offsets_gather);
// Map each output element to its source B row and position
auto k_per_out = A_k.index_select(0, group_ids);
auto start_in_B = row_ptr_B.index_select(0, k_per_out);
auto seg_index = start_in_B.add(within);
// Assemble candidate coo pairs and values
auto i_out = A_i.index_select(0, group_ids).contiguous();
auto j_out = B_j.index_select(0, seg_index).contiguous();
auto vA_out = A_val.index_select(0, group_ids).contiguous();
auto vB_out = B_val.index_select(0, seg_index).contiguous();
auto v_out = vA_out.mul(vB_out);
// build (2, P) indices
auto out_indices = at::empty({2, P}, at::device(device).dtype(at::kLong)).contiguous();
out_indices.select(0, 0).copy_(i_out);
out_indices.select(0, 1).copy_(j_out);
auto result = _sparse_coo_tensor_unsafe(
out_indices, v_out, {I, N}, mat1_.options().dtype(computeDtype));
result = result.coalesce();
if (result.scalar_type() != mat1_.scalar_type()) {
auto cast_vals = result._values().to(mat1_.scalar_type());
auto out = _sparse_coo_tensor_unsafe(result._indices(), cast_vals, {I, N}, mat1_.options());
out._coalesced_(true);
return out;
}
return result;
}
REGISTER_MPS_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_mps_kernel);
} // namespace at::native

View File

@ -926,14 +926,15 @@ class DeviceCachingAllocator {
(release_cached_blocks() && alloc_block(params, true));
}
if (!block_found) {
const auto& raw_device = c10::xpu::get_raw_device(device);
const auto device_total =
raw_device.get_info<sycl::info::device::global_mem_size>();
c10::xpu::DeviceProp device_prop;
c10::xpu::get_device_properties(&device_prop, device);
auto device_total = device_prop.global_mem_size;
// Estimate the available device memory when the SYCL runtime does not
// support the corresponding aspect (ext_intel_free_memory).
size_t device_free = device_total -
size_t device_free = device_prop.global_mem_size -
stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)]
.current;
auto& raw_device = c10::xpu::get_raw_device(device);
// TODO: Remove the aspect check once the SYCL runtime bug is fixed on
// affected devices.
if (raw_device.has(sycl::aspect::ext_intel_free_memory)) {
@ -1051,37 +1052,21 @@ class DeviceCachingAllocator {
}
}
std::pair<size_t, size_t> getMemoryInfo() {
const auto& device = c10::xpu::get_raw_device(device_index);
const size_t total = device.get_info<sycl::info::device::global_mem_size>();
TORCH_CHECK(
device.has(sycl::aspect::ext_intel_free_memory),
"The device (",
device.get_info<sycl::info::device::name>(),
") doesn't support querying the available free memory. ",
"You can file an issue at https://github.com/pytorch/pytorch/issues ",
"to help us prioritize its implementation.");
const size_t free =
device.get_info<sycl::ext::intel::info::device::free_memory>();
return {free, total};
}
double getMemoryFraction() {
if (!set_fraction) {
return 1.0;
}
const auto device_total =
xpu::get_raw_device(device_index)
.get_info<sycl::info::device::global_mem_size>();
c10::xpu::DeviceProp device_prop;
c10::xpu::get_device_properties(&device_prop, device_index);
return static_cast<double>(allowed_memory_maximum) /
static_cast<double>(device_total);
static_cast<double>(device_prop.global_mem_size);
}
void setMemoryFraction(double fraction) {
const auto device_total =
xpu::get_raw_device(device_index)
.get_info<sycl::info::device::global_mem_size>();
c10::xpu::DeviceProp device_prop;
c10::xpu::get_device_properties(&device_prop, device_index);
auto device_total = device_prop.global_mem_size;
allowed_memory_maximum = static_cast<size_t>(fraction * device_total);
set_fraction = true;
}
@ -1256,8 +1241,7 @@ class XPUAllocator : public DeviceAllocator {
}
std::pair<size_t, size_t> getMemoryInfo(DeviceIndex device) override {
assertValidDevice(device);
return device_allocators[device]->getMemoryInfo();
TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented yet.");
}
double getMemoryFraction(DeviceIndex device) {

View File

@ -382,6 +382,20 @@ coverage_ignore_functions = [
# torch.ao.quantization.backend_config.tensorrt
"get_tensorrt_backend_config",
"get_tensorrt_backend_config_dict",
# torch.ao.quantization.backend_config.utils
"entry_to_pretty_str",
"get_fused_module_classes",
"get_fuser_method_mapping",
"get_fusion_pattern_to_extra_inputs_getter",
"get_fusion_pattern_to_root_node_getter",
"get_module_to_qat_module",
"get_pattern_to_dtype_configs",
"get_pattern_to_input_type_to_index",
"get_qat_module_classes",
"get_root_module_to_quantized_reference_module",
"pattern_to_human_readable",
"remove_boolean_dispatch_from_name",
# torch.ao.quantization.backend_config.x86
"get_x86_backend_config",
# torch.ao.quantization.fuse_modules
"fuse_known_modules",
@ -412,6 +426,25 @@ coverage_ignore_functions = [
"insert_observers_for_model",
"prepare",
"propagate_dtypes_for_known_nodes",
# torch.ao.quantization.fx.utils
"all_node_args_except_first",
"all_node_args_have_no_tensors",
"assert_and_get_unique_device",
"collect_producer_nodes",
"create_getattr_from_value",
"create_node_from_old_node_preserve_meta",
"get_custom_module_class_keys",
"get_linear_prepack_op_for_dtype",
"get_new_attr_name_with_prefix",
"get_non_observable_arg_indexes_and_types",
"get_qconv_prepack_op",
"get_skipped_module_name_and_classes",
"graph_module_from_producer_nodes",
"maybe_get_next_module",
"node_arg_is_bias",
"node_arg_is_weight",
"return_arg_list",
# torch.ao.quantization.pt2e.graph_utils
"bfs_trace_with_node_process",
"find_sequential_partitions",
"get_equivalent_types",
@ -827,10 +860,80 @@ coverage_ignore_functions = [
"get_latency_of_one_partition",
"get_latency_of_partitioned_graph",
"get_partition_to_latency_mapping",
# torch.fx.experimental.proxy_tensor
"decompose",
"disable_autocast_cache",
"disable_proxy_modes_tracing",
"dispatch_trace",
"extract_val",
"fake_signature",
"fetch_sym_proxy",
"fetch_object_proxy",
"get_innermost_proxy_mode",
"get_isolated_graphmodule",
"get_proxy_slot",
"get_torch_dispatch_modes",
"has_proxy_slot",
"is_sym_node",
"maybe_handle_decomp",
"proxy_call",
"set_meta",
"set_original_aten_op",
"set_proxy_slot",
"snapshot_fake",
"thunkify",
"track_tensor",
"track_tensor_tree",
"wrap_key",
"wrapper_and_args_for_make_fx",
# torch.fx.experimental.recording
"record_shapeenv_event",
"replay_shape_env_events",
"shape_env_check_state_equal",
# torch.fx.experimental.sym_node
"ceil_impl",
"floor_ceil_helper",
"floor_impl",
"method_to_operator",
"sympy_is_channels_last_contiguous_2d",
"sympy_is_channels_last_contiguous_3d",
"sympy_is_channels_last_strides_2d",
"sympy_is_channels_last_strides_3d",
"sympy_is_channels_last_strides_generic",
"sympy_is_contiguous",
"sympy_is_contiguous_generic",
"to_node",
"wrap_node",
"sym_sqrt",
# torch.fx.experimental.symbolic_shapes
"bind_symbols",
"cast_symbool_to_symint_guardless",
"create_contiguous",
"error",
"eval_guards",
"eval_is_non_overlapping_and_dense",
"expect_true",
"find_symbol_binding_fx_nodes",
"free_symbols",
"free_unbacked_symbols",
"fx_placeholder_targets",
"fx_placeholder_vals",
"guard_bool",
"guard_float",
"guard_int",
"guard_scalar",
"has_hint",
"has_symbolic_sizes_strides",
"is_channels_last_contiguous_2d",
"is_channels_last_contiguous_3d",
"is_channels_last_strides_2d",
"is_channels_last_strides_3d",
"is_contiguous",
"is_non_overlapping_and_dense_indicator",
"is_nested_int",
"is_symbol_binding_fx_node",
"is_symbolic",
# torch.fx.experimental.unification.core
"reify",
# torch.fx.experimental.unification.match
"edge",
@ -868,6 +971,24 @@ coverage_ignore_functions = [
"reverse_dict",
# torch.fx.experimental.unification.multipledispatch.variadic
"isvariadic",
# torch.fx.experimental.unification.unification_tools
"assoc",
"assoc_in",
"dissoc",
"first",
"get_in",
"getter",
"groupby",
"itemfilter",
"itemmap",
"keyfilter",
"keymap",
"merge",
"merge_with",
"update_in",
"valfilter",
"valmap",
# torch.fx.experimental.unification.utils
"freeze",
"hashable",
"raises",
@ -1308,8 +1429,319 @@ coverage_ignore_functions = [
# torch.onnx.symbolic_opset7
"max",
"min",
# torch.onnx.symbolic_opset8
"addmm",
"bmm",
"empty",
"empty_like",
"flatten",
"full",
"full_like",
"gt",
"lt",
"matmul",
"mm",
"ones",
"ones_like",
"prelu",
"repeat",
"zeros",
"zeros_like",
# torch.onnx.symbolic_opset9
"abs",
"acos",
"adaptive_avg_pool1d",
"adaptive_avg_pool2d",
"adaptive_avg_pool3d",
"adaptive_max_pool1d",
"adaptive_max_pool2d",
"adaptive_max_pool3d",
"add",
"addcmul",
"addmm",
"alias",
"amax",
"amin",
"aminmax",
"arange",
"argmax",
"argmin",
"as_strided",
"as_tensor",
"asin",
"atan",
"atan2",
"avg_pool1d",
"avg_pool2d",
"avg_pool3d",
"baddbmm",
"batch_norm",
"bernoulli",
"bitwise_not",
"bitwise_or",
"bmm",
"broadcast_tensors",
"broadcast_to",
"bucketize",
"cat",
"cdist",
"ceil",
"clamp",
"clamp_max",
"clamp_min",
"clone",
"constant_pad_nd",
"contiguous",
"conv1d",
"conv2d",
"conv3d",
"conv_tbc",
"conv_transpose1d",
"conv_transpose2d",
"conv_transpose3d",
"convert_element_type",
"convolution",
"cos",
"cosine_similarity",
"cross",
"cumsum",
"detach",
"dim",
"div",
"dot",
"dropout",
"elu",
"embedding",
"embedding_bag",
"empty",
"empty_like",
"eq",
"erf",
"exp",
"expand",
"expand_as",
"eye",
"fill",
"flatten",
"floor",
"floor_divide",
"floordiv",
"frobenius_norm",
"full",
"full_like",
"gather",
"ge",
"gelu",
"get_pool_ceil_padding",
"glu",
"group_norm",
"gru",
"gt",
"hann_window",
"hardshrink",
"hardsigmoid",
"hardswish",
"hardtanh",
"index",
"index_add",
"index_copy",
"index_fill",
"index_put",
"index_select",
"instance_norm",
"is_floating_point",
"is_pinned",
"isnan",
"item",
"kl_div",
"layer_norm",
"le",
"leaky_relu",
"lerp",
"lift",
"linalg_cross",
"linalg_matrix_norm",
"linalg_norm",
"linalg_vector_norm",
"linear",
"linspace",
"log",
"log10",
"log1p",
"log2",
"log_sigmoid",
"log_softmax",
"logical_and",
"logical_not",
"logical_or",
"logical_xor",
"logit",
"logsumexp",
"lstm",
"lstm_cell",
"lt",
"masked_fill",
"masked_fill_",
"matmul",
"max",
"max_pool1d",
"max_pool1d_with_indices",
"max_pool2d",
"max_pool2d_with_indices",
"max_pool3d",
"max_pool3d_with_indices",
"maximum",
"meshgrid",
"min",
"minimum",
"mish",
"mm",
"movedim",
"mse_loss",
"mul",
"multinomial",
"mv",
"narrow",
"native_layer_norm",
"ne",
"neg",
"new_empty",
"new_full",
"new_ones",
"new_zeros",
"nonzero",
"nonzero_numpy",
"noop_complex_operators",
"norm",
"numel",
"numpy_T",
"one_hot",
"ones",
"ones_like",
"onnx_placeholder",
"overload_by_arg_count",
"pad",
"pairwise_distance",
"permute",
"pixel_shuffle",
"pixel_unshuffle",
"pow",
"prelu",
"prim_constant",
"prim_constant_chunk",
"prim_constant_split",
"prim_data",
"prim_device",
"prim_dtype",
"prim_if",
"prim_layout",
"prim_list_construct",
"prim_list_unpack",
"prim_loop",
"prim_max",
"prim_min",
"prim_shape",
"prim_tolist",
"prim_tuple_construct",
"prim_type",
"prim_unchecked_cast",
"prim_uninitialized",
"rand",
"rand_like",
"randint",
"randint_like",
"randn",
"randn_like",
"reciprocal",
"reflection_pad",
"relu",
"relu6",
"remainder",
"repeat",
"repeat_interleave",
"replication_pad",
"reshape",
"reshape_as",
"rnn_relu",
"rnn_tanh",
"roll",
"rrelu",
"rsqrt",
"rsub",
"scalar_tensor",
"scatter",
"scatter_add",
"select",
"selu",
"sigmoid",
"sign",
"silu",
"sin",
"size",
"slice",
"softmax",
"softplus",
"softshrink",
"sort",
"split",
"split_with_sizes",
"sqrt",
"square",
"squeeze",
"stack",
"std",
"std_mean",
"sub",
"t",
"take",
"tan",
"tanh",
"tanhshrink",
"tensor",
"threshold",
"to",
"topk",
"transpose",
"true_divide",
"type_as",
"unbind",
"unfold",
"unsafe_chunk",
"unsafe_split",
"unsafe_split_with_sizes",
"unsqueeze",
"unsupported_complex_operators",
"unused",
"upsample_bilinear2d",
"upsample_linear1d",
"upsample_nearest1d",
"upsample_nearest2d",
"upsample_nearest3d",
"upsample_trilinear3d",
"var",
"var_mean",
"view",
"view_as",
"where",
"wrap_logical_op_with_cast_to",
"wrap_logical_op_with_negation",
"zero",
"zeros",
"zeros_like",
# torch.onnx.utils
"disable_apex_o2_state_dict_hook",
"export",
"export_to_pretty_string",
"exporter_context",
"is_in_onnx_export",
"model_signature",
"register_custom_op_symbolic",
"select_model_mode_for_export",
"setup_onnx_logging",
"unconvertible_ops",
"unpack_quantized_tensor",
"warn_on_static_input_change",
# torch.onnx.verification
"check_export_model_diff",
"verify",
"verify_aten_graph",
@ -1400,6 +1832,32 @@ coverage_ignore_functions = [
"noop_context_fn",
"set_checkpoint_early_stop",
"set_device_states",
# torch.utils.collect_env
"check_release_file",
"get_cachingallocator_config",
"get_clang_version",
"get_cmake_version",
"get_conda_packages",
"get_cpu_info",
"get_cuda_module_loading_config",
"get_cudnn_version",
"get_env_info",
"get_gcc_version",
"get_gpu_info",
"get_libc_version",
"get_lsb_version",
"get_mac_version",
"get_nvidia_driver_version",
"get_nvidia_smi",
"get_os",
"get_pip_packages",
"get_platform",
"get_pretty_env_info",
"get_python_platform",
"get_running_cuda_version",
"get_windows_version",
"is_xnnpack_available",
"pretty_str",
# torch.utils.cpp_backtrace
"get_cpp_backtrace",
# torch.utils.cpp_extension
@ -1463,6 +1921,52 @@ coverage_ignore_functions = [
"apply_shuffle_seed",
"apply_shuffle_settings",
"get_all_graph_pipes",
# torch.utils.flop_counter
"addmm_flop",
"baddbmm_flop",
"bmm_flop",
"conv_backward_flop",
"conv_flop",
"conv_flop_count",
"convert_num_with_suffix",
"get_shape",
"get_suffix_str",
"mm_flop",
"normalize_tuple",
"register_flop_formula",
"sdpa_backward_flop",
"sdpa_backward_flop_count",
"sdpa_flop",
"sdpa_flop_count",
"shape_wrapper",
"transpose_shape",
# torch.utils.hipify.hipify_python
"add_dim3",
"compute_stats",
"extract_arguments",
"file_add_header",
"file_specific_replacement",
"find_bracket_group",
"find_closure_group",
"find_parentheses_group",
"fix_static_global_kernels",
"get_hip_file_path",
"hip_header_magic",
"hipify",
"is_caffe2_gpu_file",
"is_cusparse_file",
"is_out_of_place",
"is_pytorch_file",
"is_special_file",
"match_extensions",
"matched_files_iter",
"openf",
"preprocess_file_and_save_result",
"preprocessor",
"processKernelLaunches",
"replace_extern_shared",
"replace_math_functions",
"str2bool",
# torch.utils.hooks
"unserializable_hook",
"warn_if_has_hooks",

View File

@ -12,37 +12,6 @@ These APIs are experimental and subject to change without notice.
.. autoclass:: torch.fx.experimental.sym_node.DynamicInt
```
## torch.fx.experimental.sym_node
```{eval-rst}
.. currentmodule:: torch.fx.experimental.sym_node
```
```{eval-rst}
.. automodule:: torch.fx.experimental.sym_node
```
```{eval-rst}
.. autosummary::
:toctree: generated
:nosignatures:
is_channels_last_contiguous_2d
is_channels_last_contiguous_3d
is_channels_last_strides_2d
is_channels_last_strides_3d
is_contiguous
is_non_overlapping_and_dense_indicator
method_to_operator
sympy_is_channels_last_contiguous_2d
sympy_is_channels_last_contiguous_3d
sympy_is_channels_last_strides_2d
sympy_is_channels_last_strides_3d
sympy_is_channels_last_strides_generic
sympy_is_contiguous
sympy_is_contiguous_generic
```
## torch.fx.experimental.symbolic_shapes
```{eval-rst}
@ -100,25 +69,6 @@ These APIs are experimental and subject to change without notice.
rebind_unbacked
resolve_unbacked_bindings
is_accessor_node
cast_symbool_to_symint_guardless
create_contiguous
error
eval_guards
eval_is_non_overlapping_and_dense
find_symbol_binding_fx_nodes
free_symbols
free_unbacked_symbols
fx_placeholder_targets
fx_placeholder_vals
guard_bool
guard_float
guard_int
guard_scalar
has_hint
has_symbolic_sizes_strides
is_nested_int
is_symbol_binding_fx_node
is_symbolic
```
## torch.fx.experimental.proxy_tensor
@ -141,46 +91,4 @@ These APIs are experimental and subject to change without notice.
get_proxy_mode
maybe_enable_thunkify
maybe_disable_thunkify
decompose
disable_autocast_cache
disable_proxy_modes_tracing
extract_val
fake_signature
fetch_object_proxy
fetch_sym_proxy
has_proxy_slot
is_sym_node
maybe_handle_decomp
proxy_call
set_meta
set_original_aten_op
set_proxy_slot
snapshot_fake
```
## torch.fx.experimental.unification.unification_tools
```{eval-rst}
.. currentmodule:: torch.fx.experimental.unification.unification_tools
```
```{eval-rst}
.. automodule:: torch.fx.experimental.unification.unification_tools
```
```{eval-rst}
.. autosummary::
:toctree: generated
:nosignatures:
assoc
assoc_in
dissoc
first
keyfilter
keymap
merge
merge_with
update_in
valfilter
valmap

View File

@ -1134,6 +1134,7 @@ The set of leaf modules can be customized by overriding
.. py:module:: torch.fx.experimental.refinement_types
.. py:module:: torch.fx.experimental.rewriter
.. py:module:: torch.fx.experimental.schema_type_annotation
.. py:module:: torch.fx.experimental.sym_node
.. py:module:: torch.fx.experimental.unification.core
.. py:module:: torch.fx.experimental.unification.dispatch
.. py:module:: torch.fx.experimental.unification.match
@ -1143,6 +1144,7 @@ The set of leaf modules can be customized by overriding
.. py:module:: torch.fx.experimental.unification.multipledispatch.dispatcher
.. py:module:: torch.fx.experimental.unification.multipledispatch.utils
.. py:module:: torch.fx.experimental.unification.multipledispatch.variadic
.. py:module:: torch.fx.experimental.unification.unification_tools
.. py:module:: torch.fx.experimental.unification.utils
.. py:module:: torch.fx.experimental.unification.variable
.. py:module:: torch.fx.experimental.unify_refinements

View File

@ -134,23 +134,6 @@ Quantization to work with this as well.
ObservationType
```
## torch.ao.quantization.backend_config.utils
```{eval-rst}
.. currentmodule:: torch.ao.quantization.backend_config.utils
```
```{eval-rst}
.. autosummary::
:toctree: generated
:nosignatures:
:template: classtemplate.rst
entry_to_pretty_str
pattern_to_human_readable
remove_boolean_dispatch_from_name
```
## torch.ao.quantization.fx.custom_config
This module contains a few CustomConfig classes that's used in both eager mode and FX graph mode quantization
@ -171,30 +154,6 @@ This module contains a few CustomConfig classes that's used in both eager mode a
StandaloneModuleConfigEntry
```
## torch.ao.quantization.fx.utils
```{eval-rst}
.. currentmodule:: torch.ao.quantization.fx.utils
```
```{eval-rst}
.. autosummary::
:toctree: generated
:nosignatures:
:template: classtemplate.rst
all_node_args_except_first
all_node_args_have_no_tensors
collect_producer_nodes
create_getattr_from_value
create_node_from_old_node_preserve_meta
graph_module_from_producer_nodes
maybe_get_next_module
node_arg_is_bias
node_arg_is_weight
return_arg_list
```
## torch.ao.quantization.quantizer
```{eval-rst}

View File

@ -19,91 +19,6 @@
swap_tensors
```
# torch.utils.collect_env
```{eval-rst}
.. automodule:: torch.utils.collect_env
```
```{eval-rst}
.. currentmodule:: torch.utils.collect_env
```
```{eval-rst}
.. autosummary::
:toctree: generated
:nosignatures:
check_release_file
is_xnnpack_available
pretty_str
```
# torch.utils.flop_counter
```{eval-rst}
.. automodule:: torch.utils.flop_counter
```
```{eval-rst}
.. currentmodule:: torch.utils.flop_counter
```
```{eval-rst}
.. autosummary::
:toctree: generated
:nosignatures:
baddbmm_flop
bmm_flop
conv_backward_flop
conv_flop
conv_flop_count
register_flop_formula
sdpa_backward_flop
sdpa_backward_flop_count
sdpa_flop
sdpa_flop_count
shape_wrapper
```
# torch.utils.hipify.hipify_python
```{eval-rst}
.. automodule:: torch.utils.hipify.hipify_python
```
```{eval-rst}
.. currentmodule:: torch.utils.hipify.hipify_python
```
```{eval-rst}
.. autosummary::
:toctree: generated
:nosignatures:
compute_stats
extract_arguments
file_add_header
file_specific_replacement
find_bracket_group
find_closure_group
find_parentheses_group
fix_static_global_kernels
hip_header_magic
hipify
is_caffe2_gpu_file
is_cusparse_file
is_out_of_place
is_pytorch_file
is_special_file
openf
preprocess_file_and_save_result
preprocessor
processKernelLaunches
replace_extern_shared
replace_math_functions
str2bool
```
<!-- This module needs to be documented. Adding here in the meantime
for tracking purposes -->
```{eval-rst}
@ -128,6 +43,7 @@ for tracking purposes -->
.. py:module:: torch.utils.benchmark.utils.valgrind_wrapper.timer_interface
.. py:module:: torch.utils.bundled_inputs
.. py:module:: torch.utils.checkpoint
.. py:module:: torch.utils.collect_env
.. py:module:: torch.utils.cpp_backtrace
.. py:module:: torch.utils.cpp_extension
.. py:module:: torch.utils.data.backward_compatibility
@ -164,8 +80,10 @@ for tracking purposes -->
.. py:module:: torch.utils.data.sampler
.. py:module:: torch.utils.dlpack
.. py:module:: torch.utils.file_baton
.. py:module:: torch.utils.flop_counter
.. py:module:: torch.utils.hipify.constants
.. py:module:: torch.utils.hipify.cuda_to_hip_mappings
.. py:module:: torch.utils.hipify.hipify_python
.. py:module:: torch.utils.hipify.version
.. py:module:: torch.utils.hooks
.. py:module:: torch.utils.jit.log_extract

View File

@ -204,16 +204,14 @@ class DistConvolutionOpsTest(DTensorTestBase):
self.assertTrue(b_dt.grad is not None)
self.assertTrue(x_dt.grad is None)
def _run_single_arg_fwd(
self, model, arg, placements=None
) -> tuple[torch.Tensor, torch.Tensor]:
def _run_single_arg_fwd(self, model, arg) -> tuple[torch.Tensor, torch.Tensor]:
"""Given model and arg, runs fwd model local and distbuted given device_mesh"""
device_mesh = self.build_device_mesh()
model_copy = copy.deepcopy(model).to(device=self.device_type)
dist_model = distribute_module(model, device_mesh, _conv_fn)
arg_dt = DTensor.from_local(arg, device_mesh, placements)
arg_dt = DTensor.from_local(arg, device_mesh, [Replicate()])
out_dt = dist_model(arg_dt.to(device=self.device_type))
out = model_copy(arg_dt.full_tensor())
out = model_copy(arg)
return (out_dt.full_tensor(), out)
@with_comms
@ -221,20 +219,22 @@ class DistConvolutionOpsTest(DTensorTestBase):
model = nn.Conv1d(64, 64, 3, padding=1)
x = torch.randn(1, 64, 8, device=self.device_type)
out_dt, out = self._run_single_arg_fwd(model, x)
self.assertEqual(out_dt, out)
self.assertEqual(out_dt.shape, out.shape)
@with_comms
def test_conv3d(self):
model = nn.Conv3d(64, 64, 3, padding=1)
x = torch.randn(1, 64, 8, 8, 8, device=self.device_type)
out_dt, out = self._run_single_arg_fwd(model, x, [Shard(0)])
self.assertEqual(out_dt, out)
out_dt, out = self._run_single_arg_fwd(model, x)
self.assertEqual(out_dt.shape, out.shape)
DistConvolutionOpsTestWithLocalTensor = create_local_tensor_test_class(
DistConvolutionOpsTest,
# Send / recv ops are not supported
skipped_tests=[
"test_conv1d",
"test_conv3d",
"test_conv_backward_none_grad_inp",
"test_depthwise_convolution",
"test_downsampling_convolution",

View File

@ -2,6 +2,7 @@
# Owner(s): ["oncall: distributed"]
import contextlib
import copy
import itertools
import unittest
@ -21,8 +22,9 @@ from torch.distributed.tensor import (
)
from torch.distributed.tensor._collective_utils import shard_dim_alltoall
from torch.distributed.tensor._dtensor_spec import ShardOrderEntry
from torch.distributed.tensor._redistribute import redistribute_local_tensor
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.placement_types import _StridedShard, MaskPartial
from torch.distributed.tensor.placement_types import _StridedShard
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
@ -33,11 +35,7 @@ from torch.testing._internal.common_utils import (
from torch.testing._internal.distributed._tensor.common_dtensor import (
create_local_tensor_test_class,
DTensorTestBase,
generate_shard_orders,
make_full_tensor,
map_local_tensor_for_rank,
patched_distribute_tensor as _distribute_tensor,
redistribute,
with_comms,
)
from torch.utils._debug_mode import DebugMode
@ -787,6 +785,88 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
else:
return ""
# TODO(zpcore): remove once the native redistribute supports shard_order arg
def redistribute(
self,
dtensor_input,
device_mesh,
placements,
shard_order,
use_graph_based_transform=True,
):
"""
wrapper function to support shard_order for redistribution
This is a simpler version of Redistribute, only considers the forward.
"""
if placements is None:
placements = self._shard_order_to_placement(shard_order, device_mesh)
placements = tuple(placements)
old_spec = dtensor_input._spec
new_spec = copy.deepcopy(old_spec)
new_spec.placements = placements
if shard_order is not None:
new_spec.shard_order = shard_order
else:
new_spec.shard_order = ()
if old_spec == new_spec:
return dtensor_input
dtensor_input = DTensor.from_local(
redistribute_local_tensor(
dtensor_input.to_local(),
old_spec,
new_spec,
use_graph_based_transform=use_graph_based_transform,
),
device_mesh,
)
dtensor_input._spec = copy.deepcopy(new_spec)
return dtensor_input # returns DTensor
# TODO(zpcore): remove once the native distribute_tensor supports
# shard_order arg
def distribute_tensor(
self,
input_tensor,
device_mesh,
placements,
shard_order,
use_graph_based_transform=True,
):
"""wrapper function to support shard_order for tensor distribution"""
if placements is None:
placements = self._shard_order_to_placement(shard_order, device_mesh)
placements = tuple(placements)
tensor_dt = distribute_tensor(input_tensor, device_mesh, placements)
# fix the shard order
return self.redistribute(
tensor_dt, device_mesh, placements, shard_order, use_graph_based_transform
)
# TODO(zpcore): remove once the native redistribute supports shard_order arg
def full_tensor(self, dtensor_input):
"""wrapper function to support DTensor.full_tensor"""
return self.redistribute(
dtensor_input, dtensor_input.device_mesh, placements=None, shard_order=()
).to_local()
def _shard_order_to_placement(self, shard_order, mesh):
"""convert shard_order to placement with only Replicate() and Shard()"""
placements = [Replicate() for _ in range(mesh.ndim)]
if shard_order is not None:
for entry in shard_order:
tensor_dim = entry.tensor_dim
mesh_dims = entry.mesh_dims
for mesh_dim in mesh_dims:
placements[mesh_dim] = Shard(tensor_dim)
return tuple(placements)
def _convert_shard_order_dict_to_ShardOrder(self, shard_order):
"""Convert shard_order dict to ShardOrder"""
return tuple(
ShardOrderEntry(tensor_dim=tensor_dim, mesh_dims=tuple(mesh_dims))
for tensor_dim, mesh_dims in shard_order.items()
)
@with_comms
def test_ordered_redistribute(self):
"""Test ordered redistribution with various sharding syntaxes"""
@ -847,11 +927,13 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
for idx, ((src_placement, src_order), (dst_placement, dst_order)) in enumerate(
sharding_src_dst_pairs_with_expected_trace
):
sharded_dt = _distribute_tensor(
sharded_dt = self.distribute_tensor(
input_data.clone(), mesh, src_placement, shard_order=src_order
)
with DebugMode(record_torchfunction=False) as debug_mode:
sharded_dt = redistribute(sharded_dt, mesh, dst_placement, dst_order)
sharded_dt = self.redistribute(
sharded_dt, mesh, dst_placement, dst_order
)
trace_str = self._extract_redistribute_trace_from_debug_mode(
debug_mode.debug_string()
)
@ -875,11 +957,49 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
trace_str,
"""S(0)[0]S(0)[1]R->S(0)S(1)R->RS(1)R->RS(1)S(0)""",
)
expected_dt = _distribute_tensor(
expected_dt = self.distribute_tensor(
input_data.clone(), mesh, dst_placement, shard_order=dst_order
)
self.assertEqual(sharded_dt.to_local(), expected_dt.to_local())
def generate_shard_orders(self, mesh, tensor_rank):
# Generate all possible sharding placement of tensor with rank
# `tensor_rank` over mesh.
def _split_list(lst: list, N: int):
def compositions(n, k):
if k == 1:
yield [n]
else:
for i in range(1, n - k + 2):
for tail in compositions(n - i, k - 1):
yield [i] + tail
length = len(lst)
for comp in compositions(length, N):
result = []
start = 0
for size in comp:
result.append(lst[start : start + size])
start += size
yield result
all_mesh = list(range(mesh.ndim))
all_device_order = list(itertools.permutations(all_mesh))
for device_order in all_device_order:
# split on device orders, and assign each device order segment to a tensor dim
for num_split in range(1, mesh.ndim + 1):
for splitted_list in _split_list(list(range(mesh.ndim)), num_split):
for tensor_dims in itertools.combinations(
range(tensor_rank), len(splitted_list)
):
shard_order = {}
assert len(tensor_dims) == len(splitted_list)
for tensor_dim, mesh_dims in zip(tensor_dims, splitted_list):
shard_order[tensor_dim] = device_order[
mesh_dims[0] : mesh_dims[-1] + 1
]
yield self._convert_shard_order_dict_to_ShardOrder(shard_order)
@with_comms
def test_generate_shard_orders(self):
"""Check if `generate_shard_orders` generates unique sharding combinations"""
@ -892,7 +1012,7 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
]
for test_input in test_inputs:
all_combinations = []
for shard_order in generate_shard_orders(
for shard_order in self.generate_shard_orders(
test_input["mesh"], test_input["tensor_rank"]
):
all_combinations.append(shard_order) # noqa: PERF402
@ -942,12 +1062,12 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
input_data = torch.randn(tensor_shape, device=self.device_type)
tensor_rank = input_data.ndim
with maybe_disable_local_tensor_mode():
shard_orders = generate_shard_orders(mesh, tensor_rank)
shard_orders = self.generate_shard_orders(mesh, tensor_rank)
for shard_order in shard_orders:
sharded_dt = _distribute_tensor(
sharded_dt = self.distribute_tensor(
input_data.clone(), mesh, placements=None, shard_order=shard_order
)
self.assertEqual(make_full_tensor(sharded_dt), input_data)
self.assertEqual(self.full_tensor(sharded_dt), input_data)
# 2. Verify the correctness of redistribution from DTensor to DTensor.
# This test repeatedly redistributes a DTensor to various ordered
@ -958,20 +1078,20 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
tensor_rank = input_data.ndim
prev_sharded_dt = None
with maybe_disable_local_tensor_mode():
shard_orders = generate_shard_orders(mesh, tensor_rank)
shard_orders = self.generate_shard_orders(mesh, tensor_rank)
for shard_order in shard_orders:
if prev_sharded_dt is None:
prev_sharded_dt = _distribute_tensor(
prev_sharded_dt = self.distribute_tensor(
input_data.clone(),
mesh,
placements=None,
shard_order=shard_order,
)
else:
sharded_dt = redistribute(
sharded_dt = self.redistribute(
prev_sharded_dt, mesh, placements=None, shard_order=shard_order
)
self.assertEqual(make_full_tensor(sharded_dt), input_data)
self.assertEqual(self.full_tensor(sharded_dt), input_data)
prev_sharded_dt = sharded_dt
@with_comms
@ -1016,13 +1136,13 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
local_tensor = torch.randn(shape, device=self.device_type)
full_tensor = DTensor.from_local(local_tensor, mesh, placements)
with maybe_disable_local_tensor_mode():
shard_orders = generate_shard_orders(mesh, len(shape))
shard_orders = self.generate_shard_orders(mesh, len(shape))
for shard_order in shard_orders:
sharded_dt = redistribute(
sharded_dt = self.redistribute(
full_tensor, mesh, placements=None, shard_order=shard_order
)
self.assertEqual(
make_full_tensor(sharded_dt), make_full_tensor(full_tensor)
self.full_tensor(sharded_dt), self.full_tensor(full_tensor)
)
@unittest.skip(
@ -1032,20 +1152,24 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
@with_comms
def test_ordered_redistribute_for_special_placement(self):
"""Test ordered redistribution with special placement"""
from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
torch.manual_seed(21)
mesh = init_device_mesh(self.device_type, (8,))
input_data = torch.randn((8, 8), device=self.device_type)
src_placement = [Shard(1)]
tgt_placement = [
(MaskPartial(offset_shape=torch.Size([10, 20]), offset_dim=0),)
(_MaskPartial(offset_shape=torch.Size([10, 20]), offset_dim=0),)
]
sharded_dt = _distribute_tensor(
sharded_dt = self.distribute_tensor(
input_data.clone(),
mesh,
src_placement,
shard_order=(ShardOrderEntry(tensor_dim=1, mesh_dims=(0,)),),
)
sharded_dt = redistribute(sharded_dt, mesh, tgt_placement, shard_order=None)
sharded_dt = self.redistribute(
sharded_dt, mesh, tgt_placement, shard_order=None
)
@with_comms
def test_shard_order_same_data_as_strided_shard(self):
@ -1055,7 +1179,7 @@ class DistributeWithDeviceOrderTest(DTensorTestBase):
strided_placement = [_StridedShard(-2, split_factor=2), Shard(-2)]
x_strided_dt = distribute_tensor(x, device_mesh, strided_placement)
# specify right-to-left order use ordered shard
x_ordered_dt = _distribute_tensor(
x_ordered_dt = self.distribute_tensor(
x,
device_mesh,
placements=[Shard(0), Shard(0)],

View File

@ -34,10 +34,6 @@ from torch.distributed.tensor.placement_types import (
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
generate_shard_orders,
LocalDTensorTestBase,
patched_distribute_tensor as _distribute_tensor,
shard_order_to_placement,
with_comms,
)
@ -778,63 +774,6 @@ class TestStridedSharding(DTensorTestBase):
self.assertEqual(dtensor.full_tensor(), tensor)
class Test_StridedShard_with_shard_order(LocalDTensorTestBase):
@property
def world_size(self) -> int:
return 32
@with_comms
def test_StridedShard_to_shard_order(self):
with LocalTensorMode(ranks=self.world_size):
mesh = DeviceMesh("cpu", torch.arange(self.world_size).view(2, 2, 2, 2, 2))
shard_iter = generate_shard_orders(mesh, 3)
# It takes ~4.8h to complete total 2520 shard order combinations here
# using LocalTensor. So we only randomly pick 25 shard orders to test.
all_shard_order = list(shard_iter)
import random
random.seed(42)
shard_order_choices = random.sample(
all_shard_order, min(25, len(all_shard_order))
)
x = torch.randn(32, 32, 32)
for shard_order in shard_order_choices:
a = _distribute_tensor(x, mesh, None, shard_order)
placement_without_stridedshard = shard_order_to_placement(
shard_order, mesh
)
placements_with_stridedshard = (
DTensorSpec._convert_shard_order_to_StridedShard(
shard_order, placement_without_stridedshard, mesh
)
)
b = distribute_tensor(x, mesh, placements_with_stridedshard)
shard_order_from_stridedshard = (
DTensorSpec._maybe_convert_StridedShard_to_shard_order(
placements_with_stridedshard, mesh
)
)
self.assertEqual(shard_order, shard_order_from_stridedshard)
self.assertEqual(a.to_local(), b.to_local())
@with_comms
def test_StridedShard_not_convertible_to_shard_order(self):
with LocalTensorMode(ranks=self.world_size):
mesh = DeviceMesh("cpu", torch.arange(self.world_size).view(4, 8))
unconvertible_placements_list = [
[_StridedShard(0, split_factor=2), _StridedShard(1, split_factor=2)],
[_StridedShard(0, split_factor=2), Shard(1)],
[_StridedShard(1, split_factor=16), Shard(1)],
]
for placements in unconvertible_placements_list:
shard_order = DTensorSpec._maybe_convert_StridedShard_to_shard_order(
tuple(placements), mesh
)
self.assertIsNone(shard_order)
class Test2DStridedLocalShard(DTensorTestBase):
@property
def world_size(self):

View File

@ -1,236 +1,245 @@
{
"EndToEndLSTM (__main__.RNNTest)": 190.48799641927084,
"MultiheadAttention (__main__.ModulesTest)": 141.2663370768229,
"test__adaptive_avg_pool2d (__main__.CPUReproTests)": 82.87333234151204,
"test_after_aot_cpu_runtime_error (__main__.MinifierIsolateTests)": 70.6538565499442,
"test_aot_autograd_disable_functionalization_symbolic_exhaustive_nn_functional_max_pool1d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 123.34033711751302,
"test_aot_autograd_disable_functionalization_symbolic_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 171.25450134277344,
"test_aot_autograd_disable_functionalization_symbolic_exhaustive_nn_functional_max_pool3d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 119.71899922688802,
"test_aot_autograd_disable_functionalization_symbolic_exhaustive_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 69.35733322870163,
"test_aot_autograd_symbolic_exhaustive_linalg_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 63.64533233642578,
"test_aot_autograd_symbolic_exhaustive_masked_norm_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 63.672952016194664,
"test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool1d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 138.04000091552734,
"test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 172.1344985961914,
"test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool3d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 114.02050018310547,
"test_aot_autograd_symbolic_exhaustive_ormqr_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 67.25642830984933,
"test_aot_autograd_symbolic_exhaustive_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 65.3350003560384,
"test_aot_autograd_symbolic_module_exhaustive_nn_TransformerDecoderLayer_cpu_float32 (__main__.TestEagerFusionModuleInfoCPU)": 120.95249938964844,
"test_associative_scan_partial_grad_combine_mode_generic_compile_mode_compile_dynamic_shape_reverse_False_cpu (__main__.AssociativeScanTests)": 86.97774887084961,
"test_associative_scan_partial_grad_combine_mode_generic_compile_mode_compile_dynamic_shape_reverse_True_cpu (__main__.AssociativeScanTests)": 100.90774917602539,
"test_avg_pool3d_backward2_cpu (__main__.CpuTests)": 1144.3935089111328,
"test_avg_pool3d_backward2_cuda (__main__.GPUTests)": 222.58500061035156,
"test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 501.10033162434894,
"test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 517.1875050862631,
"test_avg_pool3d_backward2_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 113.88125228881836,
"test_avg_pool3d_backward2_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 235.77350616455078,
"test_backward_nn_functional_multi_head_attention_forward_cpu_float32 (__main__.TestCompositeComplianceCPU)": 74.6155014038086,
"test_backward_nn_functional_multi_head_attention_forward_cuda_float32 (__main__.TestCompositeComplianceCUDA)": 66.63325119018555,
"test_basic_cpu (__main__.EfficientConvBNEvalCpuTests)": 216.2968317667643,
"test_basic_cuda (__main__.EfficientConvBNEvalGpuTests)": 153.0915012359619,
"test_cat_2k_args (__main__.TestTEFuserDynamic)": 108.80471753561869,
"test_cat_2k_args (__main__.TestTEFuserStatic)": 102.20949847949669,
"test_checkpointing_without_reentrant_input_requires_grad_False (__main__.TestAutogradWithCompiledAutograd)": 311.7026621500651,
"test_checkpointing_without_reentrant_input_requires_grad_True (__main__.TestAutogradWithCompiledAutograd)": 395.0001729329427,
"test_collect_callgrind (__main__.TestBenchmarkUtils)": 348.6218566894531,
"test_comprehensive_diff_cuda_complex128 (__main__.TestDecompCUDA)": 98.71574974060059,
"test_comprehensive_diff_cuda_complex64 (__main__.TestDecompCUDA)": 97.68499946594238,
"test_comprehensive_diff_cuda_float32 (__main__.TestDecompCUDA)": 65.0557508468628,
"test_comprehensive_diff_cuda_float64 (__main__.TestDecompCUDA)": 65.86899948120117,
"test_comprehensive_gradient_cuda_complex64 (__main__.TestDecompCUDA)": 97.15880012512207,
"test_comprehensive_grid_sampler_2d_cpu_bfloat16 (__main__.TestDecompCPU)": 103.20700073242188,
"test_comprehensive_grid_sampler_2d_cpu_float16 (__main__.TestDecompCPU)": 102.74033610026042,
"test_comprehensive_grid_sampler_2d_cpu_float32 (__main__.TestDecompCPU)": 460.4286702473958,
"test_comprehensive_grid_sampler_2d_cpu_float64 (__main__.TestDecompCPU)": 435.62066650390625,
"test_comprehensive_grid_sampler_2d_cuda_bfloat16 (__main__.TestDecompCUDA)": 287.3090057373047,
"test_comprehensive_grid_sampler_2d_cuda_float16 (__main__.TestDecompCUDA)": 265.1860008239746,
"test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestDecompCUDA)": 1235.7365112304688,
"test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 68.20825004577637,
"test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestDecompCUDA)": 1281.2615051269531,
"test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 71.90750026702881,
"test_comprehensive_linalg_householder_product_cuda_complex64 (__main__.TestDecompCUDA)": 79.04633331298828,
"test_comprehensive_linalg_lu_factor_ex_cuda_complex128 (__main__.TestDecompCUDA)": 68.10879821777344,
"test_comprehensive_linalg_lu_solve_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 71.43025207519531,
"test_comprehensive_linalg_lu_solve_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 68.94575023651123,
"test_comprehensive_linalg_solve_triangular_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 72.93649864196777,
"test_comprehensive_linalg_solve_triangular_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 72.46275043487549,
"test_comprehensive_linalg_svd_cuda_complex128 (__main__.TestDecompCUDA)": 64.10650062561035,
"test_comprehensive_linalg_svd_cuda_complex64 (__main__.TestDecompCUDA)": 67.03124904632568,
"test_comprehensive_linalg_svd_cuda_float64 (__main__.TestDecompCUDA)": 64.32800025939942,
"test_comprehensive_linalg_vector_norm_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 96.41353665865384,
"test_comprehensive_linalg_vector_norm_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 100.17661388103778,
"test_comprehensive_masked_norm_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 110.95025062561035,
"test_comprehensive_masked_norm_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 108.06550025939941,
"test_comprehensive_masked_norm_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 104.24150085449219,
"test_comprehensive_nn_functional_conv_transpose3d_cuda_complex128 (__main__.TestDecompCUDA)": 63.453749656677246,
"test_comprehensive_nn_functional_conv_transpose3d_cuda_complex64 (__main__.TestDecompCUDA)": 61.739999771118164,
"test_comprehensive_nn_functional_gaussian_nll_loss_cpu_float32 (__main__.TestDecompCPU)": 69.96549987792969,
"test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestDecompCUDA)": 113.65749931335449,
"test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float64 (__main__.TestDecompCUDA)": 106.57500076293945,
"test_comprehensive_nn_functional_grid_sample_cpu_float32 (__main__.TestDecompCPU)": 117.54049682617188,
"test_comprehensive_nn_functional_grid_sample_cpu_float64 (__main__.TestDecompCPU)": 116.19766489664714,
"test_comprehensive_nn_functional_grid_sample_cuda_float32 (__main__.TestDecompCUDA)": 272.48475646972656,
"test_comprehensive_nn_functional_grid_sample_cuda_float64 (__main__.TestDecompCUDA)": 248.12175369262695,
"test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 (__main__.TestDecompCUDA)": 79.66900062561035,
"test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 81.52649879455566,
"test_comprehensive_nn_functional_interpolate_bicubic_cuda_float64 (__main__.TestDecompCUDA)": 79.29400062561035,
"test_comprehensive_nn_functional_interpolate_bicubic_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 82.40349960327148,
"test_comprehensive_nn_functional_interpolate_trilinear_cuda_float32 (__main__.TestDecompCUDA)": 128.42924880981445,
"test_comprehensive_nn_functional_interpolate_trilinear_cuda_float64 (__main__.TestDecompCUDA)": 125.03675079345703,
"test_comprehensive_nn_functional_max_pool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 1264.9732360839844,
"test_comprehensive_nn_functional_max_pool2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 1250.7332458496094,
"test_comprehensive_nn_functional_max_pool2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 1255.0684814453125,
"test_comprehensive_nn_functional_max_pool3d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 574.4627532958984,
"test_comprehensive_nn_functional_max_pool3d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 581.7282485961914,
"test_comprehensive_nn_functional_max_unpool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 65.052001953125,
"test_comprehensive_nn_functional_max_unpool2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 61.19200134277344,
"test_comprehensive_nn_functional_max_unpool2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 63.16874885559082,
"test_comprehensive_ormqr_cpu_complex64 (__main__.TestDecompCPU)": 62.39250183105469,
"test_comprehensive_ormqr_cuda_complex128 (__main__.TestDecompCUDA)": 113.32574844360352,
"test_comprehensive_ormqr_cuda_complex64 (__main__.TestDecompCUDA)": 113.91499900817871,
"test_comprehensive_ormqr_cuda_float32 (__main__.TestDecompCUDA)": 74.42549800872803,
"test_comprehensive_ormqr_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 76.1560001373291,
"test_comprehensive_ormqr_cuda_float64 (__main__.TestDecompCUDA)": 66.76750087738037,
"test_comprehensive_svd_cuda_complex128 (__main__.TestDecompCUDA)": 70.69724941253662,
"test_comprehensive_svd_cuda_complex64 (__main__.TestDecompCUDA)": 69.87625026702881,
"test_constructor_autograd_SparseBSC_cuda (__main__.TestSparseAnyCUDA)": 80.2542495727539,
"test_constructor_autograd_SparseBSR_cuda (__main__.TestSparseAnyCUDA)": 69.0419979095459,
"test_conv1d_basic (__main__.TestXNNPACKConv1dTransformPass)": 117.03342655726841,
"test_conv1d_with_relu_fc (__main__.TestXNNPACKConv1dTransformPass)": 289.50213841029574,
"test_conv2d_binary_broadcast_shapes_cpu (__main__.TestPatternMatcherGenericCPU)": 67.38800048828125,
"test_conv3d_binary_broadcast_shapes_cpu (__main__.TestPatternMatcherGenericCPU)": 145.27399444580078,
"test_conv3d_binary_dynamic_shapes_cpu (__main__.TestDynamicPatternMatcherGenericCPU)": 66.9245999654134,
"test_conv3d_cuda (__main__.AOTInductorTestABICompatibleGpu)": 151.91099548339844,
"test_conv_bn_fuse_cpu (__main__.CpuTests)": 92.79549789428711,
"test_conv_bn_fuse_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 64.60149955749512,
"test_conv_transpose_with_output_size_and_no_batch_dim_ConvTranspose3d_cuda (__main__.TestConvolutionNNDeviceTypeCUDA)": 69.27724676392972,
"test_conv_unary_fusion_nnc (__main__.TestMkldnnFusion)": 76.24971498761859,
"test_correctness_AdamW_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 81.93449974060059,
"test_correctness_Adam_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 78.87700080871582,
"test_count_nonzero_all (__main__.TestBool)": 631.2585144042969,
"test_diff_hyperparams_sharding_strategy_str_full_shard (__main__.TestFSDPUseOrigParamsMultipleParamGroups)": 61.042999267578125,
"test_dispatch_symbolic_meta_outplace_all_strides_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestMetaCUDA)": 84.49850082397461,
"test_dtensor_op_db_nn_functional_poisson_nll_loss_cpu_float32 (__main__.TestLocalDTensorOpsCPU)": 93.03299713134766,
"test_eager_sequence_nr_dynamic_shapes (__main__.DynamicShapesAotAutogradFallbackTests)": 228.46711820714614,
"test_eig_check_magma_cuda_float32 (__main__.TestLinalgCUDA)": 286.29998779296875,
"test_fail_arithmetic_ops.py (__main__.TestTyping)": 68.43842806134906,
"test_fail_random.py (__main__.TestTyping)": 74.83523060725285,
"test_fn_fwgrad_bwgrad_cumprod_cuda_complex128 (__main__.TestFwdGradientsCUDA)": 72.84900093078613,
"test_fn_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 75.86675071716309,
"test_fuse_large_params_cpu (__main__.CpuTests)": 151.4199981689453,
"test_fuse_large_params_cuda (__main__.GPUTests)": 60.351999282836914,
"test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 158.3622828892299,
"test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 149.6796646118164,
"test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 139.97800064086914,
"test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 114.8385009765625,
"test_grad_nn_Transformer_cpu_float64 (__main__.TestModuleCPU)": 84.69736822027909,
"test_grad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 84.62700080871582,
"test_gradgrad_nn_LSTM_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 89.197998046875,
"test_gradgrad_nn_LSTM_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 96.46900177001953,
"test_gradgrad_nn_TransformerDecoderLayer_cuda_float64 (__main__.TestModuleCUDA)": 187.83824920654297,
"test_gradgrad_nn_TransformerEncoder_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 110.49449920654297,
"test_gradgrad_nn_TransformerEncoder_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 124.90424919128418,
"test_gradgrad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 518.4157485961914,
"test_indirect_device_assert (__main__.TritonCodeGenTests)": 304.6440022786458,
"test_inductor_dynamic_shapes_broadcasting_dynamic_shapes (__main__.DynamicShapesReproTests)": 143.82052836698645,
"test_inductor_no_recursionerror_on_for_loops_dynamic_shapes (__main__.DynamicShapesReproTests)": 77.4985705784389,
"test_inplace_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 76.06225109100342,
"test_inputs_overlapping_with_mutation_stress_dynamic_shapes (__main__.DynamicShapesAotAutogradFallbackTests)": 138.9222858973912,
"test_jit_cuda_archflags (__main__.TestCppExtensionJIT)": 120.62233225504558,
"test_linalg_solve_triangular_large_cuda_complex128 (__main__.TestLinalgCUDA)": 148.1219940185547,
"test_linalg_solve_triangular_large_cuda_complex64 (__main__.TestLinalgCUDA)": 109.34200286865234,
"test_linear_binary_cpp_wrapper (__main__.TestCppWrapper)": 119.36233266194661,
"test_linear_binary_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 127.95700073242188,
"test_list_clearing_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 61.64850175380707,
"test_longformer_chunk_dynamic_shapes (__main__.DynamicShapesReproTests)": 105.3174296787807,
"test_low_memory_max_pool_dilation_1_dim_3_cpu_halide (__main__.HalideCpuTests)": 585.9210001627604,
"test_low_memory_max_pool_dilation_2_dim_3_cpu_halide (__main__.HalideCpuTests)": 504.3250020345052,
"test_lstm_cpu (__main__.TestMkldnnCPU)": 86.21566645304362,
"test_many_overlapping_inputs_does_not_explode_guards_dynamic_shapes (__main__.DynamicShapesReproTests)": 129.277715410505,
"test_max_autotune_addmm_max_autotune_gemm_backends_CK_x_shape2 (__main__.TestCKBackend)": 64.24800109863281,
"test_max_autotune_precompile_matmul_max_autotune_gemm_backends_CKTILE_autotune_in_subproc_False_use_aoti_False (__main__.TestCKBackend)": 77.23899841308594,
"test_max_autotune_precompile_matmul_max_autotune_gemm_backends_CKTILE_autotune_in_subproc_False_use_aoti_True (__main__.TestCKBackend)": 65.15649795532227,
"test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 62.579833984375,
"test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 64.6555004119873,
"test_pattern_matcher_multi_user_cpu (__main__.CpuTritonTests)": 142.21566772460938,
"test_proper_exit (__main__.TestDataLoader)": 267.74214717320035,
"test_proper_exit (__main__.TestDataLoaderPersistentWorkers)": 266.6539971487863,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 101.97100067138672,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 97.3346659342448,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True (__main__.TestPatternMatcher)": 81.50300216674805,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 104.61333465576172,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 99.41133371988933,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False (__main__.TestPatternMatcher)": 73.37100219726562,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 95.30900065104167,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 96.61750030517578,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True (__main__.TestPatternMatcher)": 79.33600234985352,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 101.2393315633138,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 103.18400192260742,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False (__main__.TestPatternMatcher)": 75.4114990234375,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 96.52833302815755,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 99.72700119018555,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 100.61966705322266,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 102.2750015258789,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 95.17449951171875,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 97.96749877929688,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 106.44049835205078,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 101.7173334757487,
"test_quick_core_backward__unsafe_masked_index_cpu_float64 (__main__.TestDecompCPU)": 531.5236612955729,
"test_quick_core_backward__unsafe_masked_index_cuda_float64 (__main__.TestDecompCUDA)": 1077.4210205078125,
"test_quick_core_backward__unsafe_masked_index_put_accumulate_cpu_float64 (__main__.TestDecompCPU)": 812.0880126953125,
"test_quick_core_backward__unsafe_masked_index_put_accumulate_cuda_float64 (__main__.TestDecompCUDA)": 1347.9365234375,
"test_quick_core_backward_nn_functional_max_unpool3d_grad_cpu_float64 (__main__.TestDecompCPU)": 88.93533070882161,
"test_quick_core_backward_nn_functional_max_unpool3d_grad_cuda_float64 (__main__.TestDecompCUDA)": 269.01949310302734,
"test_quick_core_backward_roll_cpu_float64 (__main__.TestDecompCPU)": 131.99799601236978,
"test_quick_core_backward_roll_cuda_float64 (__main__.TestDecompCUDA)": 232.36275100708008,
"test_quick_core_backward_select_scatter_cpu_float64 (__main__.TestDecompCPU)": 69.80400085449219,
"test_quick_core_backward_select_scatter_cuda_float64 (__main__.TestDecompCUDA)": 134.3415012359619,
"test_quick_core_backward_split_cuda_float64 (__main__.TestDecompCUDA)": 67.51749992370605,
"test_quick_core_backward_split_with_sizes_copy_cpu_float64 (__main__.TestDecompCPU)": 91.21066792805989,
"test_quick_core_backward_split_with_sizes_copy_cuda_float64 (__main__.TestDecompCUDA)": 170.97775268554688,
"test_quick_core_backward_std_cpu_float64 (__main__.TestDecompCPU)": 61.608266321818036,
"test_quick_core_backward_std_cuda_float64 (__main__.TestDecompCUDA)": 110.62575149536133,
"test_register_spills_cuda (__main__.BenchmarkFusionGpuTest)": 63.59499969482422,
"test_replicatepad_64bit_indexing_cuda_float16 (__main__.TestNNDeviceTypeCUDA)": 88.68299865722656,
"test_rnn_decomp_module_nn_LSTM_train_mode_cuda_float32 (__main__.TestDecompCUDA)": 91.50320053100586,
"test_runtime_checks_large_cpu (__main__.AOTInductorTestABICompatibleCpu)": 66.10774898529053,
"test_runtime_checks_large_cpu_with_stack_allocation (__main__.AOTInductorTestABICompatibleCpuWithStackAllocation)": 66.20533180236816,
"test_runtime_checks_large_cuda (__main__.AOTInductorTestABICompatibleGpu)": 243.1092529296875,
"test_save_load_large_string_attribute (__main__.TestSaveLoad)": 105.01200103759766,
"test_sdpa_kernel_ctx_manager2_dynamic_shapes (__main__.DynamicShapesCtxManagerTests)": 107.93685695103237,
"test_shuffler_iterdatapipe (__main__.IntegrationTestDataLoaderDataPipe)": 142.38899993896484,
"test_slow_tasks (__main__.TestFunctionalAutogradBenchmark)": 119.90166600545247,
"test_sort_bool_cpu (__main__.CpuTritonTests)": 346.2856750488281,
"test_sort_dynamic_shape_with_check_cuda (__main__.TestInductorDynamicCUDA)": 423.09974098205566,
"test_sort_stable_cuda (__main__.GPUTests)": 117.61659927368164,
"test_sort_transpose_cpu (__main__.CpuTritonTests)": 378.31200154622394,
"test_svd_lowrank_cuda_complex128 (__main__.TestLinalgCUDA)": 222.822007894516,
"test_terminate_handler_on_crash (__main__.TestTorch)": 143.31728431156702,
"test_terminate_signal (__main__.ForkTest)": 168.20485967184817,
"test_terminate_signal (__main__.ParallelForkServerShouldWorkTest)": 168.19242484867573,
"test_terminate_signal (__main__.SpawnTest)": 172.16428443363733,
"test_thnn_conv_strided_padded_dilated (__main__.TestConvolutionNN)": 93.30639710426331,
"test_train_parity_multi_group (__main__.TestFullyShard1DTrainingCore)": 163.89743041992188,
"test_train_parity_with_activation_checkpointing (__main__.TestFullyShard1DTrainingCompose)": 60.47671399797712,
"test_triton_bsr_scatter_mm_blocksize_64_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 63.39550018310547,
"test_triton_bsr_softmax_cuda_bfloat16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 173.53924942016602,
"test_triton_bsr_softmax_cuda_float16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 175.3212537765503,
"test_triton_bsr_softmax_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 122.20649909973145,
"test_variant_consistency_jit_nn_functional_max_pool2d_cpu_float32 (__main__.TestJitCPU)": 99.9885025024414,
"test_variant_consistency_jit_nn_functional_max_pool2d_cuda_float32 (__main__.TestJitCUDA)": 71.64024829864502,
"test_view_ops (__main__.TestViewOpsWithLocalTensor)": 73.45887422561646,
"test_vmapjvpvjp_linalg_lstsq_grad_oriented_cpu_float32 (__main__.TestOperatorsCPU)": 95.75249862670898,
"test_vmapjvpvjp_linalg_lstsq_grad_oriented_cuda_float32 (__main__.TestOperatorsCUDA)": 61.858001708984375,
"test_vmapjvpvjp_linalg_lu_solve_cpu_float32 (__main__.TestOperatorsCPU)": 65.11023766653878,
"test_vmapjvpvjp_linalg_lu_solve_cuda_float32 (__main__.TestOperatorsCUDA)": 66.35274982452393,
"test_vmapjvpvjp_linalg_svd_cuda_float32 (__main__.TestOperatorsCUDA)": 61.196499824523926,
"test_vmapjvpvjp_max_pool2d_with_indices_backward_cpu_float32 (__main__.TestOperatorsCPU)": 73.75380906604585,
"test_vmapjvpvjp_max_pool2d_with_indices_backward_cuda_float32 (__main__.TestOperatorsCUDA)": 73.64649868011475,
"test_vmapjvpvjp_nn_functional_max_pool2d_cpu_float32 (__main__.TestOperatorsCPU)": 75.09799966358003,
"test_vmapjvpvjp_nn_functional_max_pool2d_cuda_float32 (__main__.TestOperatorsCUDA)": 70.51450157165527,
"test_vmapjvpvjp_unbind_cpu_float32 (__main__.TestOperatorsCPU)": 66.21433276221866,
"test_vmapjvpvjp_unbind_cuda_float32 (__main__.TestOperatorsCUDA)": 73.20024871826172,
"test_vmapvjpvjp_linalg_lstsq_cuda_float32 (__main__.TestOperatorsCUDA)": 88.1349983215332,
"test_vmapvjpvjp_meshgrid_list_of_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 76.89924907684326,
"test_vmapvjpvjp_meshgrid_variadic_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 77.32975196838379,
"test_vmapvjpvjp_nn_functional_bilinear_cuda_float32 (__main__.TestOperatorsCUDA)": 120.09600067138672
"EndToEndLSTM (__main__.RNNTest)": 207.89400227864584,
"MultiheadAttention (__main__.ModulesTest)": 141.1396687825521,
"test_AllenaiLongformerBase_repro_cpu_halide (__main__.HalideCpuTests)": 214.02366638183594,
"test__adaptive_avg_pool2d (__main__.CPUReproTests)": 77.26125049591064,
"test_adaptive_max_pool2d1_cpu_halide (__main__.HalideCpuTests)": 116.37000020345052,
"test_after_aot_cpu_runtime_error (__main__.MinifierIsolateTests)": 69.25722334120009,
"test_after_aot_gpu_runtime_error (__main__.MinifierIsolateTests)": 65.84466807047527,
"test_alexnet_prefix_cpu_halide (__main__.HalideCpuTests)": 178.41399637858072,
"test_aot_autograd_disable_functionalization_symbolic_exhaustive_linalg_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 63.55014337812151,
"test_aot_autograd_disable_functionalization_symbolic_exhaustive_nn_functional_max_pool1d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 122.18047623407273,
"test_aot_autograd_disable_functionalization_symbolic_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 192.6405719575428,
"test_aot_autograd_disable_functionalization_symbolic_exhaustive_nn_functional_max_pool3d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 111.27904801141648,
"test_aot_autograd_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 60.906999588012695,
"test_aot_autograd_symbolic_exhaustive_linalg_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 62.244998931884766,
"test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool1d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 150.04100036621094,
"test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 191.85050201416016,
"test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool3d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 111.9276631673177,
"test_aot_autograd_symbolic_exhaustive_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 67.31450271606445,
"test_aot_autograd_symbolic_module_exhaustive_nn_TransformerDecoderLayer_cpu_float32 (__main__.TestEagerFusionModuleInfoCPU)": 125.24066416422527,
"test_associative_scan_partial_grad_combine_mode_generic_compile_mode_compile_dynamic_shape_reverse_False_cpu (__main__.AssociativeScanTests)": 86.47783279418945,
"test_associative_scan_partial_grad_combine_mode_generic_compile_mode_compile_dynamic_shape_reverse_True_cpu (__main__.AssociativeScanTests)": 100.46250025431316,
"test_avg_pool3d_backward2_cpu (__main__.CpuTests)": 1031.0534973144531,
"test_avg_pool3d_backward2_cuda (__main__.GPUTests)": 239.67400105794272,
"test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 495.0447726779514,
"test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 490.18524169921875,
"test_avg_pool3d_backward2_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 144.06477737426758,
"test_avg_pool3d_backward2_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 342.20416259765625,
"test_avg_pool3d_backward_cpu_halide (__main__.HalideCpuTests)": 62.01366678873698,
"test_backward_nn_functional_multi_head_attention_forward_cpu_float32 (__main__.TestCompositeComplianceCPU)": 71.07200050354004,
"test_backward_nn_functional_multi_head_attention_forward_cuda_float32 (__main__.TestCompositeComplianceCUDA)": 73.9221674601237,
"test_basic_cpu (__main__.EfficientConvBNEvalCpuTests)": 226.0122528076172,
"test_basic_cuda (__main__.EfficientConvBNEvalGpuTests)": 144.97249857584634,
"test_checkpointing_without_reentrant_input_requires_grad_False (__main__.TestAutogradWithCompiledAutograd)": 303.20537185668945,
"test_checkpointing_without_reentrant_input_requires_grad_True (__main__.TestAutogradWithCompiledAutograd)": 386.0518798828125,
"test_collect_callgrind (__main__.TestBenchmarkUtils)": 291.2442270914714,
"test_comprehensive_diff_cuda_complex128 (__main__.TestDecompCUDA)": 95.87866719563802,
"test_comprehensive_diff_cuda_complex64 (__main__.TestDecompCUDA)": 98.38716634114583,
"test_comprehensive_diff_cuda_float32 (__main__.TestDecompCUDA)": 69.08016649881999,
"test_comprehensive_diff_cuda_float64 (__main__.TestDecompCUDA)": 69.88233311971028,
"test_comprehensive_grid_sampler_2d_cpu_bfloat16 (__main__.TestDecompCPU)": 104.17599995930989,
"test_comprehensive_grid_sampler_2d_cpu_float16 (__main__.TestDecompCPU)": 97.41800308227539,
"test_comprehensive_grid_sampler_2d_cpu_float32 (__main__.TestDecompCPU)": 474.6719970703125,
"test_comprehensive_grid_sampler_2d_cpu_float64 (__main__.TestDecompCPU)": 440.4375,
"test_comprehensive_grid_sampler_2d_cuda_bfloat16 (__main__.TestDecompCUDA)": 293.3983332316081,
"test_comprehensive_grid_sampler_2d_cuda_float16 (__main__.TestDecompCUDA)": 238.7328338623047,
"test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestDecompCUDA)": 1218.4906717936199,
"test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 68.73516782124837,
"test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestDecompCUDA)": 1156.0123494466145,
"test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 72.13916714986165,
"test_comprehensive_linalg_lu_solve_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 74.90450032552083,
"test_comprehensive_linalg_lu_solve_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 70.42100016276042,
"test_comprehensive_linalg_solve_triangular_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 72.98883310953777,
"test_comprehensive_linalg_solve_triangular_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 73.34433364868164,
"test_comprehensive_linalg_svd_cuda_complex128 (__main__.TestDecompCUDA)": 61.38016573588053,
"test_comprehensive_linalg_svd_cuda_complex64 (__main__.TestDecompCUDA)": 67.52783330281575,
"test_comprehensive_masked_norm_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 111.06333287556966,
"test_comprehensive_masked_norm_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 110.19833374023438,
"test_comprehensive_masked_norm_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 113.10083134969075,
"test_comprehensive_nn_functional_conv_transpose3d_cuda_complex128 (__main__.TestDecompCUDA)": 63.23766644795736,
"test_comprehensive_nn_functional_conv_transpose3d_cuda_complex64 (__main__.TestDecompCUDA)": 70.18666712443034,
"test_comprehensive_nn_functional_gaussian_nll_loss_cpu_float32 (__main__.TestDecompCPU)": 62.61399841308594,
"test_comprehensive_nn_functional_gaussian_nll_loss_cpu_float64 (__main__.TestDecompCPU)": 67.7816670735677,
"test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestDecompCUDA)": 121.6183344523112,
"test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float64 (__main__.TestDecompCUDA)": 107.30266698201497,
"test_comprehensive_nn_functional_grid_sample_cpu_float32 (__main__.TestDecompCPU)": 130.8143310546875,
"test_comprehensive_nn_functional_grid_sample_cpu_float64 (__main__.TestDecompCPU)": 127.27633412679036,
"test_comprehensive_nn_functional_grid_sample_cuda_float32 (__main__.TestDecompCUDA)": 303.55183664957684,
"test_comprehensive_nn_functional_grid_sample_cuda_float64 (__main__.TestDecompCUDA)": 234.41216532389322,
"test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 (__main__.TestDecompCUDA)": 85.3436673482259,
"test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 80.9688326517741,
"test_comprehensive_nn_functional_interpolate_bicubic_cuda_float64 (__main__.TestDecompCUDA)": 82.55149968465169,
"test_comprehensive_nn_functional_interpolate_bicubic_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 82.37966791788737,
"test_comprehensive_nn_functional_interpolate_trilinear_cuda_float32 (__main__.TestDecompCUDA)": 129.88233184814453,
"test_comprehensive_nn_functional_interpolate_trilinear_cuda_float64 (__main__.TestDecompCUDA)": 129.4015007019043,
"test_comprehensive_nn_functional_max_pool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 1282.3826497395833,
"test_comprehensive_nn_functional_max_pool2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 1270.64599609375,
"test_comprehensive_nn_functional_max_pool2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 1297.9046630859375,
"test_comprehensive_nn_functional_max_pool3d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 545.2034962972006,
"test_comprehensive_nn_functional_max_pool3d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 572.5616760253906,
"test_comprehensive_nn_functional_max_unpool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 64.40316645304362,
"test_comprehensive_nn_functional_max_unpool2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 64.68383344014485,
"test_comprehensive_nn_functional_max_unpool2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 61.48333422342936,
"test_comprehensive_ormqr_cpu_complex64 (__main__.TestDecompCPU)": 61.959999084472656,
"test_comprehensive_ormqr_cuda_complex128 (__main__.TestDecompCUDA)": 105.79100036621094,
"test_comprehensive_ormqr_cuda_complex64 (__main__.TestDecompCUDA)": 122.34666570027669,
"test_comprehensive_ormqr_cuda_float32 (__main__.TestDecompCUDA)": 68.7205015818278,
"test_comprehensive_ormqr_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 74.2183329264323,
"test_comprehensive_ormqr_cuda_float64 (__main__.TestDecompCUDA)": 66.86883227030437,
"test_comprehensive_svd_cuda_complex128 (__main__.TestDecompCUDA)": 77.48183314005534,
"test_comprehensive_svd_cuda_complex64 (__main__.TestDecompCUDA)": 79.1564998626709,
"test_constructor_autograd_SparseBSC_cuda (__main__.TestSparseAnyCUDA)": 160.41250228881836,
"test_constructor_autograd_SparseBSR_cuda (__main__.TestSparseAnyCUDA)": 79.10633341471355,
"test_constructor_autograd_SparseCSC_cuda (__main__.TestSparseAnyCUDA)": 60.106833140055336,
"test_conv1d_basic (__main__.TestXNNPACKConv1dTransformPass)": 221.3586196899414,
"test_conv1d_with_relu_fc (__main__.TestXNNPACKConv1dTransformPass)": 504.3203754425049,
"test_conv2d_binary_broadcast_shapes_cpu (__main__.TestPatternMatcherGenericCPU)": 78.03233337402344,
"test_conv3d_binary_broadcast_shapes_cpu (__main__.TestPatternMatcherGenericCPU)": 152.302001953125,
"test_conv3d_cuda (__main__.AOTInductorTestABICompatibleGpu)": 152.99433390299478,
"test_conv_bn_fuse_cpu (__main__.CpuTests)": 96.25399971008301,
"test_conv_bn_fuse_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 75.70275068283081,
"test_conv_transpose_with_output_size_and_no_batch_dim_ConvTranspose3d_cuda (__main__.TestConvolutionNNDeviceTypeCUDA)": 139.14399747674665,
"test_conv_unary_fusion_nnc (__main__.TestMkldnnFusion)": 72.7847490310669,
"test_correctness_AdamW_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 91.59966786702473,
"test_correctness_Adam_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 87.57833353678386,
"test_count_nonzero_all (__main__.TestBool)": 664.9986343383789,
"test_cp_flex_attention_document_mask (__main__.CPFlexAttentionTest)": 78.31500244140625,
"test_ddp_uneven_inputs (__main__.TestDistBackendWithSpawn)": 385.24249792099,
"test_dispatch_symbolic_meta_outplace_all_strides_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestMetaCUDA)": 84.70466740926106,
"test_dtensor_op_db_nn_functional_gaussian_nll_loss_cpu_float32 (__main__.TestLocalDTensorOpsCPU)": 685.0679931640625,
"test_dtensor_op_db_nn_functional_gaussian_nll_loss_cpu_float32 (__main__.TestMultiThreadedDTensorOpsCPU)": 86.26266733805339,
"test_eig_check_magma_cuda_float32 (__main__.TestLinalgCUDA)": 292.93699645996094,
"test_error_detection_and_propagation (__main__.NcclErrorHandlingTest)": 66.84199905395508,
"test_fail_arithmetic_ops.py (__main__.TestTyping)": 69.56212568283081,
"test_fail_creation_ops.py (__main__.TestTyping)": 69.80560022989908,
"test_fn_fwgrad_bwgrad_cumprod_cuda_complex128 (__main__.TestFwdGradientsCUDA)": 73.36666552225749,
"test_fn_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 90.40366744995117,
"test_fuse_large_params_cpu (__main__.CpuTests)": 132.73199844360352,
"test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 150.16662406921387,
"test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 159.28499794006348,
"test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 165.19283294677734,
"test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 151.12366739908853,
"test_grad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 84.61699930826823,
"test_gradgrad_nn_LSTM_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 110.00600179036458,
"test_gradgrad_nn_LSTM_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 122.3759994506836,
"test_gradgrad_nn_TransformerDecoderLayer_cuda_float64 (__main__.TestModuleCUDA)": 190.89249674479166,
"test_gradgrad_nn_TransformerEncoder_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 149.6598358154297,
"test_gradgrad_nn_TransformerEncoder_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 146.07766723632812,
"test_gradgrad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 532.8139902750651,
"test_graph_partition_refcount_cuda (__main__.GPUTests)": 69.78400001525878,
"test_graph_partition_refcount_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 267.04988850487604,
"test_graph_partition_refcount_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 273.54955800374347,
"test_grid_sampler_2d_cpu_halide (__main__.HalideCpuTests)": 195.84733072916666,
"test_indirect_device_assert (__main__.TritonCodeGenTests)": 326.0143330891927,
"test_inductor_no_recursionerror_on_for_loops_dynamic_shapes (__main__.DynamicShapesReproTests)": 66.96037435531616,
"test_inplace_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 77.44933319091797,
"test_inputs_overlapping_with_mutation_stress_dynamic_shapes (__main__.DynamicShapesAotAutogradFallbackTests)": 126.81488884819879,
"test_jit_cuda_archflags (__main__.TestCppExtensionJIT)": 118.70199839274089,
"test_linalg_solve_triangular_large_cuda_complex128 (__main__.TestLinalgCUDA)": 129.20266723632812,
"test_linalg_solve_triangular_large_cuda_complex64 (__main__.TestLinalgCUDA)": 97.18800099690755,
"test_linear_binary_cpp_wrapper (__main__.TestCppWrapper)": 130.3183339436849,
"test_linear_binary_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 140.43233235677084,
"test_list_clearing_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 293.122774971856,
"test_lobpcg_ortho_cuda_float64 (__main__.TestLinalgCUDA)": 63.835832277933754,
"test_longformer_chunk_dynamic_shapes (__main__.DynamicShapesReproTests)": 106.77049922943115,
"test_lstm_cpu (__main__.TestMkldnnCPU)": 100.89649963378906,
"test_many_overlapping_inputs_does_not_explode_guards_dynamic_shapes (__main__.DynamicShapesReproTests)": 140.07424926757812,
"test_max_autotune_addmm_max_autotune_gemm_backends_CK_x_shape2 (__main__.TestCKBackend)": 72.90299733479817,
"test_max_autotune_addmm_search_space_EXHAUSTIVE_dynamic_True (__main__.TestMaxAutotuneSubproc)": 82.62433369954427,
"test_max_autotune_precompile_matmul_max_autotune_gemm_backends_CKTILE_autotune_in_subproc_False_use_aoti_False (__main__.TestCKBackend)": 87.51499938964844,
"test_max_autotune_precompile_matmul_max_autotune_gemm_backends_CKTILE_autotune_in_subproc_True_use_aoti_True (__main__.TestCKBackend)": 71.22416591644287,
"test_max_pool2d2_cpu_halide (__main__.HalideCpuTests)": 424.50966389973956,
"test_max_pool2d3_cpu_halide (__main__.HalideCpuTests)": 134.14600626627603,
"test_max_pool2d5_cpu_halide (__main__.HalideCpuTests)": 358.88099161783856,
"test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 63.58866712782118,
"test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 62.68674945831299,
"test_memory_format_operators_cuda (__main__.TestTorchDeviceTypeCUDA)": 65.85794713936355,
"test_ordered_distribute_all_combination (__main__.DistributeWithDeviceOrderTest)": 103.6923344930013,
"test_ordered_redistribute_with_partial (__main__.DistributeWithDeviceOrderTest)": 187.6953328450521,
"test_ordered_redistribute_with_partial (__main__.DistributeWithDeviceOrderTestWithLocalTensor)": 370.27442932128906,
"test_proper_exit (__main__.TestDataLoader)": 227.83111148410373,
"test_proper_exit (__main__.TestDataLoaderPersistentWorkers)": 227.1901126437717,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 105.52099990844727,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 106.50249862670898,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True (__main__.TestPatternMatcher)": 92.52400207519531,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 111.75499725341797,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 107.40500259399414,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False (__main__.TestPatternMatcher)": 83.80450057983398,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 107.46599833170573,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 96.65650177001953,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True (__main__.TestPatternMatcher)": 83.4114990234375,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 107.47100067138672,
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 108.55533345540364,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False (__main__.TestPatternMatcher)": 89.23666381835938,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 105.13900375366211,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 100.14550018310547,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 107.33649826049805,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 102.08150100708008,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 97.59600067138672,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 104.82933553059895,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 114.43099721272786,
"test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 110.40333302815755,
"test_quick_core_backward__unsafe_masked_index_cpu_float64 (__main__.TestDecompCPU)": 567.2765197753906,
"test_quick_core_backward__unsafe_masked_index_cuda_float64 (__main__.TestDecompCUDA)": 1032.5083312988281,
"test_quick_core_backward__unsafe_masked_index_put_accumulate_cpu_float64 (__main__.TestDecompCPU)": 852.7170003255209,
"test_quick_core_backward__unsafe_masked_index_put_accumulate_cuda_float64 (__main__.TestDecompCUDA)": 1361.954854329427,
"test_quick_core_backward_nn_functional_max_unpool3d_grad_cpu_float64 (__main__.TestDecompCPU)": 77.385498046875,
"test_quick_core_backward_nn_functional_max_unpool3d_grad_cuda_float64 (__main__.TestDecompCUDA)": 265.0193354288737,
"test_quick_core_backward_roll_cpu_float64 (__main__.TestDecompCPU)": 115.31749725341797,
"test_quick_core_backward_roll_cuda_float64 (__main__.TestDecompCUDA)": 245.27666727701822,
"test_quick_core_backward_select_scatter_cpu_float64 (__main__.TestDecompCPU)": 71.75300216674805,
"test_quick_core_backward_select_scatter_cuda_float64 (__main__.TestDecompCUDA)": 141.8895009358724,
"test_quick_core_backward_split_cuda_float64 (__main__.TestDecompCUDA)": 71.15749994913737,
"test_quick_core_backward_split_with_sizes_copy_cpu_float64 (__main__.TestDecompCPU)": 90.59066772460938,
"test_quick_core_backward_split_with_sizes_copy_cuda_float64 (__main__.TestDecompCUDA)": 173.73916625976562,
"test_quick_core_backward_std_cuda_float64 (__main__.TestDecompCUDA)": 110.65066655476888,
"test_register_spills_cuda (__main__.BenchmarkFusionCudaTest)": 99.21799850463867,
"test_replicatepad_64bit_indexing_cuda_float16 (__main__.TestNNDeviceTypeCUDA)": 90.86299896240234,
"test_rosenbrock_sparse_with_lrsched_False_SGD_cuda_float64 (__main__.TestOptimRenewedCUDA)": 66.57050196329753,
"test_rosenbrock_sparse_with_lrsched_True_SGD_cuda_float64 (__main__.TestOptimRenewedCUDA)": 69.65149958928426,
"test_runtime_checks_large_cpu (__main__.AOTInductorTestABICompatibleCpu)": 78.13350168863933,
"test_runtime_checks_large_cpu_with_stack_allocation (__main__.AOTInductorTestABICompatibleCpuWithStackAllocation)": 76.85255601671007,
"test_runtime_checks_large_cuda (__main__.AOTInductorTestABICompatibleGpu)": 333.04866282145184,
"test_save_load_large_string_attribute (__main__.TestSaveLoad)": 146.96599833170572,
"test_sdpa_kernel_ctx_manager2_dynamic_shapes (__main__.DynamicShapesCtxManagerTests)": 160.4881100124783,
"test_shuffler_iterdatapipe (__main__.IntegrationTestDataLoaderDataPipe)": 124.10055626763238,
"test_slow_tasks (__main__.TestFunctionalAutogradBenchmark)": 117.38410907321506,
"test_sort_dynamic_shape_with_check_cuda (__main__.TestInductorDynamicCUDA)": 710.2327779134115,
"test_sort_stable_cpu (__main__.CpuTritonTests)": 1324.4399820963542,
"test_sort_stable_cuda (__main__.GPUTests)": 76.83109970092774,
"test_split_cumsum_cpu (__main__.CpuTritonTests)": 88.58433532714844,
"test_svd_lowrank_cuda_complex128 (__main__.TestLinalgCUDA)": 160.1271684964498,
"test_tensor_split (__main__.TestVmapOperators)": 79.18955569393519,
"test_terminate_handler_on_crash (__main__.TestTorch)": 111.30388899644215,
"test_terminate_signal (__main__.ForkTest)": 132.3458870516883,
"test_terminate_signal (__main__.ParallelForkServerShouldWorkTest)": 132.2043343567186,
"test_terminate_signal (__main__.SpawnTest)": 136.1005539894104,
"test_torchvision_smoke (__main__.TestTensorBoardPytorchGraph)": 76.20899939537048,
"test_train_parity_multi_group_unshard_async_op (__main__.TestFullyShard1DTrainingCore)": 63.82099969046457,
"test_triton_bsr_scatter_mm_blocksize_64_cuda_bfloat16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 61.925000508626304,
"test_triton_bsr_scatter_mm_blocksize_64_cuda_float16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 60.89849980672201,
"test_triton_bsr_scatter_mm_blocksize_64_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 66.88233375549316,
"test_triton_bsr_softmax_cuda_bfloat16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 144.9854990641276,
"test_triton_bsr_softmax_cuda_float16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 144.4044977823893,
"test_triton_bsr_softmax_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 108.19166437784831,
"test_unary_ops (__main__.TestTEFuserDynamic)": 96.32655514611139,
"test_unary_ops (__main__.TestTEFuserStatic)": 105.33362591266632,
"test_upsample_bicubic2d_cpu_halide (__main__.HalideCpuTests)": 97.8336664835612,
"test_variant_consistency_jit_nn_functional_max_pool2d_cpu_float32 (__main__.TestJitCPU)": 82.86566925048828,
"test_variant_consistency_jit_nn_functional_max_pool2d_cuda_float32 (__main__.TestJitCUDA)": 68.26500002543132,
"test_views1_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 97.1120007832845,
"test_vmapjvpvjp_linalg_lstsq_grad_oriented_cpu_float32 (__main__.TestOperatorsCPU)": 88.24766794840495,
"test_vmapjvpvjp_linalg_lstsq_grad_oriented_cuda_float32 (__main__.TestOperatorsCUDA)": 65.41266759236653,
"test_vmapjvpvjp_linalg_lu_solve_cuda_float32 (__main__.TestOperatorsCUDA)": 74.75533294677734,
"test_vmapjvpvjp_linalg_multi_dot_cuda_float32 (__main__.TestOperatorsCUDA)": 73.52500089009602,
"test_vmapjvpvjp_linalg_svd_cuda_float32 (__main__.TestOperatorsCUDA)": 73.85466639200847,
"test_vmapjvpvjp_max_pool2d_with_indices_backward_cuda_float32 (__main__.TestOperatorsCUDA)": 98.39650090535481,
"test_vmapjvpvjp_nn_functional_conv2d_cpu_float32 (__main__.TestOperatorsCPU)": 61.39695285615467,
"test_vmapjvpvjp_nn_functional_max_pool2d_cuda_float32 (__main__.TestOperatorsCUDA)": 77.88249842325847,
"test_vmapjvpvjp_svd_cuda_float32 (__main__.TestOperatorsCUDA)": 73.0695006052653,
"test_vmapjvpvjp_unbind_cuda_float32 (__main__.TestOperatorsCUDA)": 81.86250114440918,
"test_vmapvjpvjp_meshgrid_list_of_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 98.63116455078125,
"test_vmapvjpvjp_meshgrid_variadic_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 94.85683314005534,
"test_vmapvjpvjp_nn_functional_bilinear_cuda_float32 (__main__.TestOperatorsCUDA)": 173.00183614095053
}

View File

@ -239,12 +239,6 @@ class TestAccelerator(TestCase):
self.assertEqual(torch.accelerator.max_memory_allocated(), prev_max_allocated)
self.assertEqual(torch.accelerator.max_memory_reserved(), prev_max_reserved)
@unittest.skipIf(TEST_MPS, "MPS doesn't support torch.accelerator memory API!")
def test_get_memory_info(self):
free_bytes, total_bytes = torch.accelerator.get_memory_info()
self.assertGreaterEqual(free_bytes, 0)
self.assertGreaterEqual(total_bytes, 0)
if __name__ == "__main__":
run_tests()

View File

@ -3728,6 +3728,7 @@ class TestSparse(TestSparseBase):
@coalescedonoff
@dtypes(*floating_and_complex_types())
@dtypesIfMPS(*all_mps_types())
@expectedFailureMPS
@dtypesIfCUDA(*floating_types_and(*[torch.half] if SM53OrLater and not TEST_WITH_ROCM else [],
*[torch.bfloat16] if SM80OrLater and not TEST_WITH_ROCM else [],
torch.complex64,
@ -3824,9 +3825,9 @@ class TestSparse(TestSparseBase):
def different_dtypes():
a, i_a, v_a = self._gen_sparse(2, 10, [2, 2], dtype, device, coalesced)
b, i_b, v_b = self._gen_sparse(2, 10, [2, 2], dtype, device, coalesced)
r2 = torch.sparse.mm(a.to(torch.float32), a.to(torch.float16))
r2 = torch.sparse.mm(a.to(torch.float64), a.to(torch.float32))
self.assertRaisesRegex(RuntimeError, 'mat1 dtype Float does not match mat2 dtype Half', different_dtypes)
self.assertRaisesRegex(RuntimeError, 'mat1 dtype Double does not match mat2 dtype Float', different_dtypes)
def test_backward_noncontiguous():
# Sparse.mm backward used to wrong with non-contiguous grads,

View File

@ -195,12 +195,10 @@ def get_new_attr_name_with_prefix(prefix: str) -> Callable:
def collect_producer_nodes(node: Node) -> Optional[list[Node]]:
r"""Starting from a target node, trace back until we hit input or
getattr node. This is used to extract the chain of operators
starting from getattr to the target node, for example::
def forward(self, x):
observed = self.observer(self.weight)
return F.linear(x, observed)
starting from getattr to the target node, for example
def forward(self, x):
observed = self.observer(self.weight)
return F.linear(x, observed)
collect_producer_nodes(observed) will either return a list of nodes that
produces the observed node or None if we can't extract a self contained
graph without free variables(inputs of the forward function).

View File

@ -386,8 +386,23 @@ static void bindGetDeviceProperties(PyObject* module) {
static void initXpuMethodBindings(PyObject* module) {
auto m = py::handle(module).cast<py::module>();
m.def("_xpu_getMemoryInfo", [](c10::DeviceIndex device_index) {
py::gil_scoped_release no_gil;
return at::getDeviceAllocator(at::kXPU)->getMemoryInfo(device_index);
#if SYCL_COMPILER_VERSION >= 20250000
auto total = at::xpu::getDeviceProperties(device_index)->global_mem_size;
auto& device = c10::xpu::get_raw_device(device_index);
TORCH_CHECK(
device.has(sycl::aspect::ext_intel_free_memory),
"The device (",
at::xpu::getDeviceProperties(device_index)->name,
") doesn't support querying the available free memory. ",
"You can file an issue at https://github.com/pytorch/pytorch/issues ",
"to help us prioritize its implementation.");
auto free = device.get_info<sycl::ext::intel::info::device::free_memory>();
return std::make_tuple(free, total);
#else
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"torch.xpu.mem_get_info requires PyTorch to be built with SYCL compiler version 2025.0.0 or newer.");
#endif
});
m.def(
"_xpu_getStreamFromExternal",

View File

@ -1,5 +1,4 @@
import itertools
import math
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, cast, NamedTuple, Optional
@ -8,7 +7,6 @@ import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.placement_types import (
_StridedShard,
MaskPartial,
Partial,
Placement,
Replicate,
@ -129,185 +127,6 @@ class DTensorSpec:
)
return default_shard_order
@staticmethod
def _convert_shard_order_to_StridedShard(
shard_order: ShardOrder, placements: tuple[Placement, ...], mesh: DeviceMesh
) -> tuple[Placement, ...]:
"""
Convert ShardOrder to placements with _StridedShard.
This function converts a ShardOrder specification into a tuple of Placement objects,
using _StridedShard when a tensor dimension is sharded across multiple mesh dimensions
in a non-default order. The split_factor of each _StridedShard is determined by the
product of mesh dimension sizes that appear earlier in the shard order but later in
the placement tuple.
Args:
shard_order: ShardOrder specification indicating which tensor dimensions are
sharded on which mesh dimensions and in what execution order.
placements: Tuple of Placement objects that does not contain _StridedShard.
mesh: DeviceMesh containing the size information for each mesh dimension.
Returns:
Updated tuple of Placement objects with Shard or _StridedShard placements.
Algorithm:
For each ShardOrderEntry in shard_order:
- For each mesh dimension in the entry's mesh_dims (in order):
- Calculate split_factor as the product of mesh sizes for all mesh dimensions
that appear:
1. Earlier in the shard order (lower index in mesh_dims), and
2. Later in the placement tuple (higher mesh dimension index)
- If split_factor == 1: use normal Shard
- Otherwise: use _StridedShard with the calculated split_factor
Example:
>>> # xdoctest: +SKIP("Requires DeviceMesh")
>>> # Tensor dimension 0 sharded on mesh dims [2, 0, 1] in that order
>>> # mesh = DeviceMesh([4, 3, 2]) # sizes: mesh[0]=4, mesh[1]=3, mesh[2]=2
>>> shard_order = (ShardOrderEntry(tensor_dim=0, mesh_dims=(2, 0, 1)),)
>>> placements = (Shard(0), Shard(0), Shard(0))
>>> # For mesh_dim=2 (index 0 in mesh_dims): no earlier dims, split_factor=1
>>> # -> placements[2] = Shard(0)
>>> # For mesh_dim=0 (index 1 in mesh_dims): mesh_dim=2 is earlier and has index 2>0
>>> # -> split_factor = mesh.size(2) = 2
>>> # -> placements[0] = _StridedShard(0, split_factor=2)
>>> # For mesh_dim=1 (index 2 in mesh_dims): mesh_dim=2 is earlier and has index 2>1
>>> # -> split_factor = mesh.size(2) = 2
>>> # -> placements[1] = _StridedShard(0, split_factor=2)
>>> # Result: (_StridedShard(0, sf=2), _StridedShard(0, sf=2), Shard(0))
"""
placements_list = list(placements)
for entry in shard_order:
tensor_dim = entry.tensor_dim
mesh_dims = entry.mesh_dims
for idx in range(len(mesh_dims)):
# TODO(zpcore): split_factor from `view` and `shard order`
# should be able to be multiplied into one. Need to loosen the
# condition here.
mesh_dim = mesh_dims[idx]
if type(placements[mesh_dim]) is not Shard:
raise ValueError(
f"Only Shard placement can be converted to _StridedShard, "
f"found {placements[mesh_dim]} in {placements=}."
)
split_factor = math.prod(
mesh.size(i) for i in mesh_dims[:idx] if i > mesh_dim
)
if split_factor == 1:
# use normal Shard
placements_list[mesh_dim] = Shard(tensor_dim)
else:
placements_list[mesh_dim] = _StridedShard(
tensor_dim, split_factor=split_factor
)
return tuple(placements_list)
@staticmethod
def _maybe_convert_StridedShard_to_shard_order(
placements: tuple[Placement, ...], mesh: DeviceMesh
) -> Optional[ShardOrder]:
"""
Try to convert _StridedShard placements to ShardOrder.
This is the inverse of `_convert_shard_order_to_StridedShard`. It reconstructs the shard
order by examining the split_factor of each _StridedShard and determining its position
in the execution order. If the _StridedShard configuration cannot be represented as a
valid ShardOrder (i.e., there's no shard order that produces the observed split_factors),
this function returns None.
Args:
placements: Tuple of Placement objects that may contain _StridedShard.
mesh: DeviceMesh containing the size information for each mesh dimension.
Returns:
ShardOrder if conversion is possible, None otherwise. For placements without
_StridedShard, returns the default shard order.
Algorithm:
1. If no _StridedShard in placements, return default shard order
2. Create an empty list for each tensor dimension to represent mesh dim ordering
3. Iterate through placements in reverse order (right to left):
- For each Shard/_StridedShard on a tensor dimension:
- Extract its split_factor (1 for Shard, split_factor for _StridedShard)
- Find the position in mesh_dims_order where accumulated_sf equals split_factor
- accumulated_sf is the product of mesh sizes of mesh dimensions that appear
earlier in mesh_dims_order (lower indices)
- Insert mesh_dim at the found position
4. If no valid position found for any split_factor, return None (unable to convert)
5. Construct ShardOrderEntry for each tensor dimension from mesh_dims_order
Example:
>>> # xdoctest: +SKIP("Requires DeviceMesh")
>>> # mesh = DeviceMesh([4, 3, 2]) # sizes: mesh[0]=4, mesh[1]=3, mesh[2]=2
>>> # placements = (_StridedShard(0, sf=2), _StridedShard(0, sf=2), Shard(0))
>>> # Process tensor_dim=0 from right to left:
>>> # - mesh_dim=2: Shard(0) with sf=1
>>> # Try position 0: accumulated_sf=1, matches! Insert at position 0
>>> # Current mesh_dims_order order: [2]
>>> # - mesh_dim=1: _StridedShard(0, sf=2) with sf=2
>>> # Try position 0: accumulated_sf=1, no match
>>> # Try position 1: accumulated_sf=1*mesh.size(2)=2, matches! Insert at position 1
>>> # Current mesh_dims_order order: [2, 1]
>>> # - mesh_dim=0: _StridedShard(0, sf=2) with sf=2
>>> # Try position 0: accumulated_sf=1, no match
>>> # Try position 1: accumulated_sf=1*mesh.size(2)=2, matches! Insert at position 1
>>> # Final mesh_dims_order order: [2, 0, 1]
>>> # Result: ShardOrder((ShardOrderEntry(tensor_dim=0, mesh_dims=(2, 0, 1)),))
>>> # This means: first shard on mesh_dim=2, then mesh_dim=0, then mesh_dim=1
Note:
This function validates that _StridedShard can be represented as a ShardOrder.
Not all _StridedShard configurations are valid - the split_factor must match
the product of mesh sizes in some execution order.
"""
if not any(isinstance(p, _StridedShard) for p in placements):
return DTensorSpec.compute_default_shard_order(placements)
max_tensor_dim = (
max([i.dim for i in placements if isinstance(i, Shard | _StridedShard)]) + 1
)
shard_order = []
tensor_dim_to_mesh_dims_order: list[list[int]] = [
[] for i in range(max_tensor_dim)
]
for mesh_dim in reversed(range(len(placements))):
cur_placement = placements[mesh_dim]
# _StridedShard may not be a subclass of Shard in the future, so write in this way:
if isinstance(cur_placement, Shard | _StridedShard):
tensor_dim = cur_placement.dim
mesh_dims_order = tensor_dim_to_mesh_dims_order[tensor_dim]
cur_sf = 1
if isinstance(cur_placement, _StridedShard):
cur_sf = cur_placement.split_factor
accumulated_sf = 1
find_order = False
for i in range(len(mesh_dims_order) + 1):
if accumulated_sf == cur_sf:
mesh_dims_order.insert(i, mesh_dim)
find_order = True
break
if i < len(mesh_dims_order):
accumulated_sf *= mesh.size(mesh_dims_order[i])
if not find_order:
# _StridedShard is not convertible to ShardOrder
return None
else:
if not isinstance(cur_placement, Replicate | Partial | MaskPartial):
raise ValueError(
f"Unsupported placement type {type(cur_placement)} encountered in "
f"{placements}; expected Replicate, Partial, or MaskPartial."
)
for tensor_dim in range(max_tensor_dim):
if len(tensor_dim_to_mesh_dims_order[tensor_dim]) > 0:
shard_order.append(
ShardOrderEntry(
tensor_dim=tensor_dim,
mesh_dims=tuple(tensor_dim_to_mesh_dims_order[tensor_dim]),
)
)
return tuple(shard_order)
def _verify_shard_order(self, shard_order: ShardOrder) -> None:
"""Verify that the shard_order is valid and matches the placements."""
total_shard = 0

View File

@ -11,10 +11,7 @@ import torch.distributed.tensor._api as dtensor
aten = torch.ops.aten
def _requires_data_exchange(padding, dim_map) -> bool:
# Data exchange is not need if only sharded across batch dim
if all(x == -1 for x in dim_map[1:]):
return False
def _requires_data_exchange(padding):
# TODO: whether there requires data exchange is currently determined by padding
return padding[-1] != 0
@ -110,7 +107,6 @@ def tp_convolution(
op_call: torch._ops.OpOverload,
local_tensor_args: tuple[object, ...],
local_tensor_kwargs: dict[str, object],
dim_map: list[int],
) -> object:
assert op_call == aten.convolution.default
assert len(local_tensor_args) == 9
@ -124,7 +120,7 @@ def tp_convolution(
assert _is_supported(in_tensor.shape, weight.shape, stride, padding, dilation)
assert isinstance(padding, list)
if not _requires_data_exchange(padding, dim_map):
if not _requires_data_exchange(padding):
local_results = op_call(*local_tensor_args, **local_tensor_kwargs)
return local_results
else:
@ -164,7 +160,6 @@ def tp_convolution_backward(
op_call: torch._ops.OpOverload,
local_tensor_args: tuple[object, ...],
local_tensor_kwargs: dict[str, object],
dim_map: list[int],
) -> object:
assert op_call == aten.convolution_backward.default
assert len(local_tensor_args) == 11
@ -179,7 +174,7 @@ def tp_convolution_backward(
assert _is_supported(in_tensor.shape, weight.shape, stride, padding, dilation)
assert isinstance(padding, list)
if not _requires_data_exchange(padding, dim_map):
if not _requires_data_exchange(padding):
local_results = op_call(*local_tensor_args, **local_tensor_kwargs)
return local_results
else:
@ -244,18 +239,15 @@ def convolution_handler(
dtensor.DTensor._op_dispatcher.sharding_propagator.propagate(op_info)
output_sharding = op_info.output_sharding
assert output_sharding is not None, "output sharding should not be None"
output_spec = output_sharding.output_spec
assert isinstance(output_spec, dtensor.DTensorSpec)
# local propagation
local_results = tp_convolution(
op_call,
tuple(op_info.local_args),
op_info.local_kwargs,
output_spec.dim_map,
op_call, tuple(op_info.local_args), op_info.local_kwargs
)
return dtensor.DTensor._op_dispatcher.wrap(local_results, output_spec)
return dtensor.DTensor._op_dispatcher.wrap(
local_results, output_sharding.output_spec
)
def convolution_backward_handler(
@ -278,14 +270,10 @@ def convolution_backward_handler(
dtensor.DTensor._op_dispatcher.sharding_propagator.propagate(op_info)
output_sharding = op_info.output_sharding
assert output_sharding is not None, "output sharding should not be None"
assert isinstance(op_info.flat_args_schema[0], dtensor.DTensorSpec)
# local propagation
local_results = tp_convolution_backward(
op_call,
tuple(op_info.local_args),
op_info.local_kwargs,
op_info.flat_args_schema[0].dim_map,
op_call, tuple(op_info.local_args), op_info.local_kwargs
)
return dtensor.DTensor._op_dispatcher.wrap(

View File

@ -20320,7 +20320,6 @@ op_db: list[OpInfo] = [
torch.float32: 1e-4}),),
dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
supports_sparse=True,
supports_sparse_csr=True,
supports_sparse_csc=True,
supports_sparse_bsr=True,

View File

@ -3,7 +3,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import contextlib
import copy
import functools
import itertools
import sys
@ -33,8 +32,6 @@ from torch.distributed.tensor import (
Replicate,
Shard,
)
from torch.distributed.tensor._dtensor_spec import ShardOrderEntry
from torch.distributed.tensor._redistribute import redistribute_local_tensor
from torch.distributed.tensor.parallel import (
ColwiseParallel,
parallelize_module,
@ -821,125 +818,3 @@ def map_local_for_rank(rank, func):
def reduce_local_int(val, func):
return func(val.node._local_ints)
def _convert_shard_order_dict_to_ShardOrder(shard_order):
"""Convert shard_order dict to ShardOrder"""
return tuple(
ShardOrderEntry(tensor_dim=tensor_dim, mesh_dims=tuple(mesh_dims))
for tensor_dim, mesh_dims in shard_order.items()
)
# TODO(zpcore): remove once the native redistribute supports shard_order arg
def redistribute(
dtensor_input,
device_mesh,
placements,
shard_order,
use_graph_based_transform=True,
):
"""
wrapper function to support shard_order for redistribution
This is a simpler version of Redistribute, only considers the forward.
"""
if placements is None:
placements = shard_order_to_placement(shard_order, device_mesh)
placements = tuple(placements)
old_spec = dtensor_input._spec
new_spec = copy.deepcopy(old_spec)
new_spec.placements = placements
if shard_order is not None:
new_spec.shard_order = shard_order
else:
new_spec.shard_order = ()
if old_spec == new_spec:
return dtensor_input
dtensor_input = DTensor.from_local(
redistribute_local_tensor(
dtensor_input.to_local(),
old_spec,
new_spec,
use_graph_based_transform=use_graph_based_transform,
),
device_mesh,
)
dtensor_input._spec = copy.deepcopy(new_spec)
return dtensor_input # returns DTensor
# TODO(zpcore): remove once the native distribute_tensor supports
# shard_order arg
def patched_distribute_tensor(
input_tensor,
device_mesh,
placements,
shard_order,
use_graph_based_transform=True,
):
"""wrapper function to support shard_order for tensor distribution"""
if placements is None:
placements = shard_order_to_placement(shard_order, device_mesh)
placements = tuple(placements)
tensor_dt = distribute_tensor(input_tensor, device_mesh, placements)
# fix the shard order
return redistribute(
tensor_dt, device_mesh, placements, shard_order, use_graph_based_transform
)
# TODO(zpcore): remove once the native redistribute supports shard_order arg
def make_full_tensor(dtensor_input):
"""wrapper function to support DTensor.full_tensor"""
return redistribute(
dtensor_input, dtensor_input.device_mesh, placements=None, shard_order=()
).to_local()
def shard_order_to_placement(shard_order, mesh):
"""convert shard_order to placement with only Replicate() and Shard()"""
placements: list[Any] = [Replicate() for _ in range(mesh.ndim)]
if shard_order is not None:
for entry in shard_order:
tensor_dim = entry.tensor_dim
mesh_dims = entry.mesh_dims
for mesh_dim in mesh_dims:
placements[mesh_dim] = Shard(tensor_dim)
return tuple(placements)
def generate_shard_orders(mesh, tensor_rank):
# Generate all possible sharding placement of tensor with rank
# `tensor_rank` over mesh.
def _split_list(lst: list, N: int):
def compositions(n: int, k: int):
# yields lists of length k, positive ints summing to n
for cuts in itertools.combinations(range(1, n), k - 1):
# add 0 and n as sentinels, then take consecutive differences
yield [b - a for a, b in itertools.pairwise((0, *cuts, n))]
length = len(lst)
for comp in compositions(length, N):
result = []
start = 0
for size in comp:
result.append(lst[start : start + size])
start += size
yield result
all_mesh = list(range(mesh.ndim))
all_device_order = list(itertools.permutations(all_mesh))
for device_order in all_device_order:
# split on device orders, and assign each device order segment to a tensor dim
for num_split in range(1, mesh.ndim + 1):
for splitted_list in _split_list(list(range(mesh.ndim)), num_split):
for tensor_dims in itertools.combinations(
range(tensor_rank), len(splitted_list)
):
shard_order = {}
assert len(tensor_dims) == len(splitted_list)
for tensor_dim, mesh_dims in zip(tensor_dims, splitted_list):
shard_order[tensor_dim] = device_order[
mesh_dims[0] : mesh_dims[-1] + 1
]
yield _convert_shard_order_dict_to_ShardOrder(shard_order)

View File

@ -529,14 +529,11 @@ RE_EXTERN_SHARED = re.compile(r"extern\s+([\w\(\)]+)?\s*__shared__\s+([\w:<>\s]+
def replace_extern_shared(input_string):
"""
Match 'extern __shared__ type foo[];' syntax and use HIP_DYNAMIC_SHARED() MACRO instead.
See: https://github.com/ROCm/hip/blob/master/docs/markdown/hip_kernel_language.md#__shared__
Examples:
"extern __shared__ char smemChar[];"
=> "HIP_DYNAMIC_SHARED( char, smemChar)"
"extern __shared__ unsigned char smem[];"
=> "HIP_DYNAMIC_SHARED( unsigned char, my_smem)"
"""Match extern __shared__ type foo[]; syntax and use HIP_DYNAMIC_SHARED() MACRO instead.
https://github.com/ROCm/hip/blob/master/docs/markdown/hip_kernel_language.md#__shared__
Example:
"extern __shared__ char smemChar[];" => "HIP_DYNAMIC_SHARED( char, smemChar)"
"extern __shared__ unsigned char smem[];" => "HIP_DYNAMIC_SHARED( unsigned char, my_smem)"
"""
output_string = input_string
output_string = RE_EXTERN_SHARED.sub(
@ -1046,17 +1043,14 @@ RE_INCLUDE = re.compile(r"#include .*\n")
def extract_arguments(start, string):
"""
Return the list of arguments in the upcoming function parameter closure.
Example:
string (input): '(blocks, threads, 0, THCState_getCurrentStream(state))'
arguments (output):
[
{'start': 1, 'end': 7},
{'start': 8, 'end': 16},
{'start': 17, 'end': 19},
{'start': 20, 'end': 53}
]
""" Return the list of arguments in the upcoming function parameter closure.
Example:
string (input): '(blocks, threads, 0, THCState_getCurrentStream(state))'
arguments (output):
'[{'start': 1, 'end': 7},
{'start': 8, 'end': 16},
{'start': 17, 'end': 19},
{'start': 20, 'end': 53}]'
"""
arguments = []

View File

@ -190,7 +190,6 @@ def mem_get_info(device: _device_t = None) -> tuple[int, int]:
int: the memory available on the device in units of bytes.
int: the total memory on the device in units of bytes
"""
_lazy_init()
device = _get_device_index(device, optional=True)
return torch._C._xpu_getMemoryInfo(device)