mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Revert "Adds support for accelerated sorting with x86-simd-sort (#127936)"
This reverts commit 239a9ad65eebf93dcf9bb108a5129d4160b12c86.
Reverted https://github.com/pytorch/pytorch/pull/127936 on behalf of https://github.com/atalman due to test/test_sort_and_select.py::TestSortAndSelectCPU::test_sort_discontiguous_slow_cpu_float32 [GH job link](https://github.com/pytorch/pytorch/actions/runs/10994904767/job/30525578456) [HUD commit link](239a9ad65e) ([comment](https://github.com/pytorch/pytorch/pull/127936#issuecomment-2368522316))
			
			
This commit is contained in:
		
							
								
								
									
										3
									
								
								.gitmodules
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.gitmodules
									
									
									
									
										vendored
									
									
								
							| @ -127,6 +127,3 @@ | |||||||
| [submodule "third_party/NVTX"] | [submodule "third_party/NVTX"] | ||||||
| 	path = third_party/NVTX | 	path = third_party/NVTX | ||||||
| 	url = https://github.com/NVIDIA/NVTX.git | 	url = https://github.com/NVIDIA/NVTX.git | ||||||
| [submodule "third_party/x86-simd-sort"] |  | ||||||
| 	path = third_party/x86-simd-sort |  | ||||||
| 	url = https://github.com/intel/x86-simd-sort.git |  | ||||||
|  | |||||||
| @ -262,7 +262,6 @@ else() | |||||||
|   cmake_dependent_option(USE_CUFILE "Use cuFile" OFF "USE_CUDA AND NOT WIN32" OFF) |   cmake_dependent_option(USE_CUFILE "Use cuFile" OFF "USE_CUDA AND NOT WIN32" OFF) | ||||||
| endif() | endif() | ||||||
| option(USE_FBGEMM "Use FBGEMM (quantized 8-bit server operators)" ON) | option(USE_FBGEMM "Use FBGEMM (quantized 8-bit server operators)" ON) | ||||||
| option(USE_X86_SIMD_SORT "Use x86-simd-sort to accelerate sorting and topk for AVX2/AVX512" ON) |  | ||||||
| option(USE_KINETO "Use Kineto profiling library" ON) | option(USE_KINETO "Use Kineto profiling library" ON) | ||||||
| option(USE_CUPTI_SO "Use CUPTI as a shared library" ON) | option(USE_CUPTI_SO "Use CUPTI as a shared library" ON) | ||||||
| option(USE_FAKELOWP "Use FakeLowp operators" OFF) | option(USE_FAKELOWP "Use FakeLowp operators" OFF) | ||||||
| @ -908,13 +907,6 @@ if(USE_FBGEMM) | |||||||
|   string(APPEND CMAKE_CXX_FLAGS " -DUSE_FBGEMM") |   string(APPEND CMAKE_CXX_FLAGS " -DUSE_FBGEMM") | ||||||
| endif() | endif() | ||||||
|  |  | ||||||
| if(USE_X86_SIMD_SORT) |  | ||||||
|   string(APPEND CMAKE_CXX_FLAGS " -DUSE_X86_SIMD_SORT") |  | ||||||
|   if(USE_XSS_OPENMP) |  | ||||||
|     string(APPEND CMAKE_CXX_FLAGS " -DXSS_USE_OPENMP") |  | ||||||
|   endif() |  | ||||||
| endif() |  | ||||||
|  |  | ||||||
| if(USE_PYTORCH_QNNPACK) | if(USE_PYTORCH_QNNPACK) | ||||||
|   string(APPEND CMAKE_CXX_FLAGS " -DUSE_PYTORCH_QNNPACK") |   string(APPEND CMAKE_CXX_FLAGS " -DUSE_PYTORCH_QNNPACK") | ||||||
| endif() | endif() | ||||||
|  | |||||||
							
								
								
									
										34
									
								
								NOTICE
									
									
									
									
									
								
							
							
						
						
									
										34
									
								
								NOTICE
									
									
									
									
									
								
							| @ -454,37 +454,3 @@ and reference the following license: | |||||||
|     LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE |     LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE | ||||||
|     OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR |     OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR | ||||||
|     PERFORMANCE OF THIS SOFTWARE. |     PERFORMANCE OF THIS SOFTWARE. | ||||||
|  |  | ||||||
| ======================================================================= |  | ||||||
| x86-simd-sort BSD 3-Clause License |  | ||||||
| ======================================================================= |  | ||||||
|  |  | ||||||
| Code derived from implementations in x86-simd-sort should mention its |  | ||||||
| derivation and reference the following license: |  | ||||||
|  |  | ||||||
|    Copyright (c) 2022, Intel. All rights reserved. |  | ||||||
|  |  | ||||||
|    Redistribution and use in source and binary forms, with or without |  | ||||||
|    modification, are permitted provided that the following conditions are met: |  | ||||||
|  |  | ||||||
|    1. Redistributions of source code must retain the above copyright notice, this |  | ||||||
|       list of conditions and the following disclaimer. |  | ||||||
|  |  | ||||||
|    2. Redistributions in binary form must reproduce the above copyright notice, |  | ||||||
|       this list of conditions and the following disclaimer in the documentation |  | ||||||
|       and/or other materials provided with the distribution. |  | ||||||
|  |  | ||||||
|    3. Neither the name of the copyright holder nor the names of its |  | ||||||
|       contributors may be used to endorse or promote products derived from |  | ||||||
|       this software without specific prior written permission. |  | ||||||
|  |  | ||||||
|    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" |  | ||||||
|    AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE |  | ||||||
|    IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE |  | ||||||
|    DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE |  | ||||||
|    FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL |  | ||||||
|    DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR |  | ||||||
|    SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER |  | ||||||
|    CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, |  | ||||||
|    OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE |  | ||||||
|    OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |  | ||||||
| @ -15,18 +15,11 @@ | |||||||
| #include <ATen/native/CompositeRandomAccessor.h> | #include <ATen/native/CompositeRandomAccessor.h> | ||||||
| #include <ATen/native/TopKImpl.h> | #include <ATen/native/TopKImpl.h> | ||||||
| #include <c10/core/WrapDimMinimal.h> | #include <c10/core/WrapDimMinimal.h> | ||||||
| #include <c10/util/SmallBuffer.h> |  | ||||||
| #include <c10/util/irange.h> | #include <c10/util/irange.h> | ||||||
|  |  | ||||||
| #ifdef USE_FBGEMM | #ifdef USE_FBGEMM | ||||||
| #include <fbgemm/Utils.h> | #include <fbgemm/Utils.h> | ||||||
| #endif | #endif | ||||||
|  |  | ||||||
| #if USE_X86_SIMD_SORT && (defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2)) |  | ||||||
| #define XSS_COMPILE_TIME_SUPPORTED |  | ||||||
| #include <src/x86simdsort-static-incl.h> |  | ||||||
| #endif |  | ||||||
|  |  | ||||||
| namespace at::native { | namespace at::native { | ||||||
|  |  | ||||||
| namespace { | namespace { | ||||||
| @ -126,7 +119,6 @@ static void parallel_sort1d_kernel( | |||||||
|     std::vector<int64_t> tmp_vals(elements); |     std::vector<int64_t> tmp_vals(elements); | ||||||
|     const scalar_t* sorted_keys = nullptr; |     const scalar_t* sorted_keys = nullptr; | ||||||
|     const int64_t* sorted_vals = nullptr; |     const int64_t* sorted_vals = nullptr; | ||||||
|  |  | ||||||
|     std::tie(sorted_keys, sorted_vals) = fbgemm::radix_sort_parallel( |     std::tie(sorted_keys, sorted_vals) = fbgemm::radix_sort_parallel( | ||||||
|         keys, |         keys, | ||||||
|         vals, |         vals, | ||||||
| @ -175,116 +167,6 @@ static inline void sort_kernel_impl(const value_accessor_t& value_accessor, | |||||||
|   } |   } | ||||||
| } | } | ||||||
|  |  | ||||||
| #if defined(XSS_COMPILE_TIME_SUPPORTED) |  | ||||||
|  |  | ||||||
| #define AT_DISPATCH_CASE_XSS_TYPES(...)          \ |  | ||||||
|   AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \ |  | ||||||
|   AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ |  | ||||||
|   AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \ |  | ||||||
|   AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) |  | ||||||
|  |  | ||||||
| #define AT_DISPATCH_XSS_TYPES(TYPE, NAME, ...) \ |  | ||||||
|   AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_XSS_TYPES(__VA_ARGS__)) |  | ||||||
|  |  | ||||||
| static bool can_use_xss_sort(const TensorBase& values, const TensorBase& indices, int64_t dim, const bool stable) { |  | ||||||
|   // xss_sort is not a stable sort |  | ||||||
|   if (stable) return false; |  | ||||||
|  |  | ||||||
|   auto type = values.scalar_type(); |  | ||||||
|   if (! (type == ScalarType::Long || type == ScalarType::Int || type == ScalarType::Double || type == ScalarType::Float)) return false; |  | ||||||
|  |  | ||||||
|   return true; |  | ||||||
| } |  | ||||||
|  |  | ||||||
| static bool xss_sort_preferred(const TensorBase& values, const bool descending) { |  | ||||||
| #if defined(XSS_USE_OPENMP) || !defined(USE_FBGEMM) |  | ||||||
|     return true; |  | ||||||
| #else |  | ||||||
|     // Without OpenMP support for x86-simd-sort, fbgemm radix sort is faster when it can be used |  | ||||||
|     return !can_use_radix_sort(values, descending); |  | ||||||
| #endif |  | ||||||
| } |  | ||||||
|  |  | ||||||
| static void xss_sort_kernel( |  | ||||||
|     const TensorBase& values, |  | ||||||
|     const TensorBase& indices, |  | ||||||
|     int64_t dim, |  | ||||||
|     bool descending) { |  | ||||||
|   auto iter = TensorIteratorConfig() |  | ||||||
|     .check_all_same_dtype(false) |  | ||||||
|     .resize_outputs(false) |  | ||||||
|     .declare_static_shape(values.sizes(), /*squash_dims=*/dim) |  | ||||||
|     .add_output(values) |  | ||||||
|     .add_output(indices) |  | ||||||
|     .build(); |  | ||||||
|  |  | ||||||
|   using index_t = int64_t; |  | ||||||
|  |  | ||||||
|   AT_DISPATCH_XSS_TYPES(values.scalar_type(), "xss_sort_kernel", [&] { |  | ||||||
|  |  | ||||||
|     auto values_dim_stride = values.stride(dim); |  | ||||||
|     auto indices_dim_stride = indices.stride(dim); |  | ||||||
|     auto dim_size = values.size(dim); |  | ||||||
|  |  | ||||||
|     auto loop = [&](char** data, const int64_t* strides, int64_t n) { |  | ||||||
|       auto* values_data_bytes = data[0]; |  | ||||||
|       auto* indices_data_bytes = data[1]; |  | ||||||
|  |  | ||||||
|       if(values_data_bytes==nullptr || indices_data_bytes==nullptr){ |  | ||||||
|         return; |  | ||||||
|       } |  | ||||||
|  |  | ||||||
|       if (values_dim_stride == 1 && indices_dim_stride == 1){ |  | ||||||
|         for (const auto i C10_UNUSED : c10::irange(n)) { |  | ||||||
|           x86simdsortStatic::keyvalue_qsort<scalar_t, index_t>( |  | ||||||
|               reinterpret_cast<scalar_t*>(values_data_bytes), |  | ||||||
|               reinterpret_cast<index_t*>(indices_data_bytes), |  | ||||||
|               dim_size, |  | ||||||
|               true, |  | ||||||
|               descending); |  | ||||||
|  |  | ||||||
|           values_data_bytes += strides[0]; |  | ||||||
|           indices_data_bytes += strides[1]; |  | ||||||
|         } |  | ||||||
|       }else{ |  | ||||||
|         c10::SmallBuffer<scalar_t, 0> tmp_values(dim_size); |  | ||||||
|         c10::SmallBuffer<index_t, 0> tmp_indices(dim_size); |  | ||||||
|  |  | ||||||
|         for (const auto i : c10::irange(n)) { |  | ||||||
|           TensorAccessor<scalar_t, 1> mode_values_acc( |  | ||||||
|               reinterpret_cast<scalar_t*>(data[0] + i * strides[0]), |  | ||||||
|               &dim_size, &values_dim_stride); |  | ||||||
|           TensorAccessor<index_t, 1> mode_indices_acc( |  | ||||||
|               reinterpret_cast<index_t*>(data[1] + i * strides[1]), |  | ||||||
|               &dim_size, &indices_dim_stride); |  | ||||||
|  |  | ||||||
|           for (const auto j : c10::irange(dim_size)) { |  | ||||||
|             tmp_values[j] = mode_values_acc[j]; |  | ||||||
|             tmp_indices[j] = j; |  | ||||||
|           } |  | ||||||
|  |  | ||||||
|           x86simdsortStatic::keyvalue_qsort<scalar_t, index_t>( |  | ||||||
|               tmp_values.data(), |  | ||||||
|               tmp_indices.data(), |  | ||||||
|               dim_size, |  | ||||||
|               true, |  | ||||||
|               descending); |  | ||||||
|  |  | ||||||
|           for (const auto j : c10::irange(dim_size)) { |  | ||||||
|             mode_values_acc[j] = tmp_values[j]; |  | ||||||
|             mode_indices_acc[j] = tmp_indices[j]; |  | ||||||
|           } |  | ||||||
|         } |  | ||||||
|       } |  | ||||||
|     }; |  | ||||||
|  |  | ||||||
|     int64_t grain_size = internal::GRAIN_SIZE / std::max(int64_t{1}, dim_size); |  | ||||||
|     iter.for_each(loop, /*grain_size=*/grain_size); |  | ||||||
|  |  | ||||||
|   }); |  | ||||||
| } |  | ||||||
| #endif |  | ||||||
|  |  | ||||||
| static void sort_kernel( | static void sort_kernel( | ||||||
|     const TensorBase& self, |     const TensorBase& self, | ||||||
|     const TensorBase& values, |     const TensorBase& values, | ||||||
| @ -299,14 +181,6 @@ static void sort_kernel( | |||||||
|     // https://github.com/pytorch/pytorch/issues/91420 |     // https://github.com/pytorch/pytorch/issues/91420 | ||||||
|     return; |     return; | ||||||
|   } |   } | ||||||
|  |  | ||||||
| #if defined(XSS_COMPILE_TIME_SUPPORTED) |  | ||||||
|   if (can_use_xss_sort(values, indices, dim, stable) && xss_sort_preferred(values, descending)){ |  | ||||||
|     xss_sort_kernel(values, indices, dim, descending); |  | ||||||
|     return; |  | ||||||
|   } |  | ||||||
| #endif |  | ||||||
|  |  | ||||||
| #ifdef USE_FBGEMM | #ifdef USE_FBGEMM | ||||||
|   if (can_use_radix_sort(values, descending)) { |   if (can_use_radix_sort(values, descending)) { | ||||||
|     parallel_sort1d_kernel(values, indices); |     parallel_sort1d_kernel(values, indices); | ||||||
| @ -358,7 +232,6 @@ static void topk_kernel( | |||||||
|     int64_t dim, |     int64_t dim, | ||||||
|     bool largest, |     bool largest, | ||||||
|     bool sorted) { |     bool sorted) { | ||||||
|  |  | ||||||
|   auto sizes = self.sizes(); |   auto sizes = self.sizes(); | ||||||
|   auto iter = TensorIteratorConfig() |   auto iter = TensorIteratorConfig() | ||||||
|     .check_all_same_dtype(false) |     .check_all_same_dtype(false) | ||||||
| @ -393,7 +266,7 @@ static void topk_kernel( | |||||||
|  |  | ||||||
| } // anonymous namespace | } // anonymous namespace | ||||||
|  |  | ||||||
| ALSO_REGISTER_AVX512_DISPATCH(sort_stub, &sort_kernel); | REGISTER_DISPATCH(sort_stub, &sort_kernel); | ||||||
| ALSO_REGISTER_AVX512_DISPATCH(topk_stub, &topk_kernel); | REGISTER_DISPATCH(topk_stub, &topk_kernel); | ||||||
|  |  | ||||||
| } //at::native | } //at::native | ||||||
|  | |||||||
| @ -1328,28 +1328,6 @@ if(CAFFE2_CMAKE_BUILDING_WITH_MAIN_REPO AND NOT INTERN_DISABLE_ONNX) | |||||||
|   set(BUILD_SHARED_LIBS ${TEMP_BUILD_SHARED_LIBS}) |   set(BUILD_SHARED_LIBS ${TEMP_BUILD_SHARED_LIBS}) | ||||||
| endif() | endif() | ||||||
|  |  | ||||||
| # --[ x86-simd-sort integration |  | ||||||
| if(USE_X86_SIMD_SORT) |  | ||||||
|   if(NOT CMAKE_SIZEOF_VOID_P EQUAL 8) |  | ||||||
|     message(WARNING |  | ||||||
|       "x64 operating system is required for x86-simd-sort. " |  | ||||||
|       "Not compiling with x86-simd-sort. " |  | ||||||
|       "Turn this warning off by USE_X86_SIMD_SORT=OFF.") |  | ||||||
|     set(USE_X86_SIMD_SORT OFF) |  | ||||||
|   endif() |  | ||||||
|  |  | ||||||
|   if(USE_X86_SIMD_SORT) |  | ||||||
|     if(USE_OPENMP AND NOT MSVC) |  | ||||||
|       set(USE_XSS_OPENMP ON) |  | ||||||
|     else() |  | ||||||
|       set(USE_XSS_OPENMP OFF) |  | ||||||
|     endif() |  | ||||||
|  |  | ||||||
|     set(XSS_SIMD_SORT_INCLUDE_DIR ${CMAKE_CURRENT_LIST_DIR}/../third_party/x86-simd-sort) |  | ||||||
|     include_directories(SYSTEM ${XSS_SIMD_SORT_INCLUDE_DIR}) |  | ||||||
|   endif() |  | ||||||
| endif() |  | ||||||
|  |  | ||||||
| # --[ ATen checks | # --[ ATen checks | ||||||
| set(USE_LAPACK 0) | set(USE_LAPACK 0) | ||||||
|  |  | ||||||
|  | |||||||
| @ -133,7 +133,6 @@ function(caffe2_print_configuration_summary) | |||||||
|   endif() |   endif() | ||||||
|   message(STATUS "  BUILD_NVFUSER         : ${BUILD_NVFUSER}") |   message(STATUS "  BUILD_NVFUSER         : ${BUILD_NVFUSER}") | ||||||
|   message(STATUS "  USE_EIGEN_FOR_BLAS    : ${CAFFE2_USE_EIGEN_FOR_BLAS}") |   message(STATUS "  USE_EIGEN_FOR_BLAS    : ${CAFFE2_USE_EIGEN_FOR_BLAS}") | ||||||
|   message(STATUS "  USE_X86_SIMD_SORT     : ${USE_X86_SIMD_SORT}") |  | ||||||
|   message(STATUS "  USE_FBGEMM            : ${USE_FBGEMM}") |   message(STATUS "  USE_FBGEMM            : ${USE_FBGEMM}") | ||||||
|   message(STATUS "    USE_FAKELOWP          : ${USE_FAKELOWP}") |   message(STATUS "    USE_FAKELOWP          : ${USE_FAKELOWP}") | ||||||
|   message(STATUS "  USE_KINETO            : ${USE_KINETO}") |   message(STATUS "  USE_KINETO            : ${USE_KINETO}") | ||||||
|  | |||||||
| @ -466,9 +466,6 @@ inductor_override_kwargs["cpu"] = { | |||||||
|     ("nn.functional.interpolate.bicubic", u8): {"atol": 1, "rtol": 0}, |     ("nn.functional.interpolate.bicubic", u8): {"atol": 1, "rtol": 0}, | ||||||
|     # High atol due to precision loss |     # High atol due to precision loss | ||||||
|     ("nn.functional.interpolate.bicubic", f32): {"atol": 5e-3, "rtol": 0}, |     ("nn.functional.interpolate.bicubic", f32): {"atol": 5e-3, "rtol": 0}, | ||||||
|     # reference_in_float can cause erroneous failures in sorting tests |  | ||||||
|     "argsort": {"reference_in_float": False}, |  | ||||||
|     "sort": {"reference_in_float": False}, |  | ||||||
| } | } | ||||||
|  |  | ||||||
| inductor_override_kwargs["cuda"] = { | inductor_override_kwargs["cuda"] = { | ||||||
| @ -539,9 +536,6 @@ inductor_override_kwargs["cuda"] = { | |||||||
|     ("index_reduce.amax", f32): {"check_gradient": False}, |     ("index_reduce.amax", f32): {"check_gradient": False}, | ||||||
|     ("index_reduce.amax", f16): {"check_gradient": False}, |     ("index_reduce.amax", f16): {"check_gradient": False}, | ||||||
|     ("tanh", f16): {"atol": 1e-4, "rtol": 1e-2}, |     ("tanh", f16): {"atol": 1e-4, "rtol": 1e-2}, | ||||||
|     # reference_in_float can cause erroneous failures in sorting tests |  | ||||||
|     "argsort": {"reference_in_float": False}, |  | ||||||
|     "sort": {"reference_in_float": False}, |  | ||||||
| } | } | ||||||
|  |  | ||||||
| inductor_override_kwargs["xpu"] = { | inductor_override_kwargs["xpu"] = { | ||||||
| @ -661,9 +655,6 @@ inductor_override_kwargs["xpu"] = { | |||||||
|     ("nn.functional.embedding_bag", f64): {"check_gradient": False}, |     ("nn.functional.embedding_bag", f64): {"check_gradient": False}, | ||||||
|     ("_unsafe_masked_index", f16): {"atol": 1e-5, "rtol": 2e-3}, |     ("_unsafe_masked_index", f16): {"atol": 1e-5, "rtol": 2e-3}, | ||||||
|     ("_unsafe_masked_index_put_accumulate", f16): {"atol": 1e-5, "rtol": 5e-3}, |     ("_unsafe_masked_index_put_accumulate", f16): {"atol": 1e-5, "rtol": 5e-3}, | ||||||
|     # reference_in_float can cause erroneous failures in sorting tests |  | ||||||
|     "argsort": {"reference_in_float": False}, |  | ||||||
|     "sort": {"reference_in_float": False}, |  | ||||||
| } | } | ||||||
|  |  | ||||||
| # Test with one sample only for following ops | # Test with one sample only for following ops | ||||||
|  | |||||||
							
								
								
									
										1
									
								
								third_party/x86-simd-sort
									
									
									
									
										vendored
									
									
								
							
							
								
								
								
								
								
							
						
						
									
										1
									
								
								third_party/x86-simd-sort
									
									
									
									
										vendored
									
									
								
							 Submodule third_party/x86-simd-sort deleted from 9a1b616d5c
									
								
							
		Reference in New Issue
	
	Block a user