mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Looks like the original PR caused: https://github.com/pytorch/pytorch/issues/140590 Please see comment: https://github.com/pytorch/pytorch/issues/140590#issuecomment-2508704480 Pull Request resolved: https://github.com/pytorch/pytorch/pull/141901 Approved by: https://github.com/andrewor14, https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
e41a0b33ec
commit
c17ba69ba5
3
.gitmodules
vendored
3
.gitmodules
vendored
@ -131,6 +131,3 @@
|
|||||||
path = third_party/composable_kernel
|
path = third_party/composable_kernel
|
||||||
url = https://github.com/ROCm/composable_kernel.git
|
url = https://github.com/ROCm/composable_kernel.git
|
||||||
branch = develop
|
branch = develop
|
||||||
[submodule "third_party/x86-simd-sort"]
|
|
||||||
path = third_party/x86-simd-sort
|
|
||||||
url = https://github.com/intel/x86-simd-sort.git
|
|
||||||
|
@ -262,7 +262,6 @@ else()
|
|||||||
cmake_dependent_option(USE_CUFILE "Use cuFile" OFF "USE_CUDA AND NOT WIN32" OFF)
|
cmake_dependent_option(USE_CUFILE "Use cuFile" OFF "USE_CUDA AND NOT WIN32" OFF)
|
||||||
endif()
|
endif()
|
||||||
option(USE_FBGEMM "Use FBGEMM (quantized 8-bit server operators)" ON)
|
option(USE_FBGEMM "Use FBGEMM (quantized 8-bit server operators)" ON)
|
||||||
option(USE_X86_SIMD_SORT "Use x86-simd-sort to accelerate sorting and topk for AVX2/AVX512" ON)
|
|
||||||
option(USE_KINETO "Use Kineto profiling library" ON)
|
option(USE_KINETO "Use Kineto profiling library" ON)
|
||||||
option(USE_CUPTI_SO "Use CUPTI as a shared library" ON)
|
option(USE_CUPTI_SO "Use CUPTI as a shared library" ON)
|
||||||
option(USE_FAKELOWP "Use FakeLowp operators" OFF)
|
option(USE_FAKELOWP "Use FakeLowp operators" OFF)
|
||||||
@ -904,13 +903,6 @@ if(USE_FBGEMM)
|
|||||||
string(APPEND CMAKE_CXX_FLAGS " -DUSE_FBGEMM")
|
string(APPEND CMAKE_CXX_FLAGS " -DUSE_FBGEMM")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if(USE_X86_SIMD_SORT)
|
|
||||||
string(APPEND CMAKE_CXX_FLAGS " -DUSE_X86_SIMD_SORT")
|
|
||||||
if(USE_XSS_OPENMP)
|
|
||||||
string(APPEND CMAKE_CXX_FLAGS " -DXSS_USE_OPENMP")
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if(USE_PYTORCH_QNNPACK)
|
if(USE_PYTORCH_QNNPACK)
|
||||||
string(APPEND CMAKE_CXX_FLAGS " -DUSE_PYTORCH_QNNPACK")
|
string(APPEND CMAKE_CXX_FLAGS " -DUSE_PYTORCH_QNNPACK")
|
||||||
endif()
|
endif()
|
||||||
|
34
NOTICE
34
NOTICE
@ -454,37 +454,3 @@ and reference the following license:
|
|||||||
LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE
|
LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE
|
||||||
OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
|
OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
|
||||||
PERFORMANCE OF THIS SOFTWARE.
|
PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
|
||||||
=======================================================================
|
|
||||||
x86-simd-sort BSD 3-Clause License
|
|
||||||
=======================================================================
|
|
||||||
|
|
||||||
Code derived from implementations in x86-simd-sort should mention its
|
|
||||||
derivation and reference the following license:
|
|
||||||
|
|
||||||
Copyright (c) 2022, Intel. All rights reserved.
|
|
||||||
|
|
||||||
Redistribution and use in source and binary forms, with or without
|
|
||||||
modification, are permitted provided that the following conditions are met:
|
|
||||||
|
|
||||||
1. Redistributions of source code must retain the above copyright notice, this
|
|
||||||
list of conditions and the following disclaimer.
|
|
||||||
|
|
||||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
|
||||||
this list of conditions and the following disclaimer in the documentation
|
|
||||||
and/or other materials provided with the distribution.
|
|
||||||
|
|
||||||
3. Neither the name of the copyright holder nor the names of its
|
|
||||||
contributors may be used to endorse or promote products derived from
|
|
||||||
this software without specific prior written permission.
|
|
||||||
|
|
||||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
||||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
||||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
||||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
||||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
||||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
||||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
||||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
||||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
||||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
@ -15,18 +15,11 @@
|
|||||||
#include <ATen/native/CompositeRandomAccessor.h>
|
#include <ATen/native/CompositeRandomAccessor.h>
|
||||||
#include <ATen/native/TopKImpl.h>
|
#include <ATen/native/TopKImpl.h>
|
||||||
#include <c10/core/WrapDimMinimal.h>
|
#include <c10/core/WrapDimMinimal.h>
|
||||||
#include <c10/util/SmallBuffer.h>
|
|
||||||
#include <c10/util/irange.h>
|
#include <c10/util/irange.h>
|
||||||
|
|
||||||
#ifdef USE_FBGEMM
|
#ifdef USE_FBGEMM
|
||||||
#include <fbgemm/Utils.h>
|
#include <fbgemm/Utils.h>
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if USE_X86_SIMD_SORT && (defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2))
|
|
||||||
#define XSS_COMPILE_TIME_SUPPORTED
|
|
||||||
#include <src/x86simdsort-static-incl.h>
|
|
||||||
#endif
|
|
||||||
|
|
||||||
namespace at::native {
|
namespace at::native {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
@ -124,7 +117,6 @@ static void parallel_sort1d_kernel(
|
|||||||
std::vector<int64_t> tmp_vals(elements);
|
std::vector<int64_t> tmp_vals(elements);
|
||||||
const scalar_t* sorted_keys = nullptr;
|
const scalar_t* sorted_keys = nullptr;
|
||||||
const int64_t* sorted_vals = nullptr;
|
const int64_t* sorted_vals = nullptr;
|
||||||
|
|
||||||
std::tie(sorted_keys, sorted_vals) = fbgemm::radix_sort_parallel(
|
std::tie(sorted_keys, sorted_vals) = fbgemm::radix_sort_parallel(
|
||||||
keys,
|
keys,
|
||||||
vals,
|
vals,
|
||||||
@ -173,116 +165,6 @@ static inline void sort_kernel_impl(const value_accessor_t& value_accessor,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#if defined(XSS_COMPILE_TIME_SUPPORTED)
|
|
||||||
|
|
||||||
#define AT_DISPATCH_CASE_XSS_TYPES(...) \
|
|
||||||
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \
|
|
||||||
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
|
|
||||||
AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
|
|
||||||
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)
|
|
||||||
|
|
||||||
#define AT_DISPATCH_XSS_TYPES(TYPE, NAME, ...) \
|
|
||||||
AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_XSS_TYPES(__VA_ARGS__))
|
|
||||||
|
|
||||||
static bool can_use_xss_sort(const TensorBase& values, const TensorBase& indices, int64_t dim, const bool stable) {
|
|
||||||
// xss_sort is not a stable sort
|
|
||||||
if (stable) return false;
|
|
||||||
|
|
||||||
auto type = values.scalar_type();
|
|
||||||
if (! (type == ScalarType::Long || type == ScalarType::Int || type == ScalarType::Double || type == ScalarType::Float)) return false;
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool xss_sort_preferred(const TensorBase& values, const bool descending) {
|
|
||||||
#if defined(XSS_USE_OPENMP) || !defined(USE_FBGEMM)
|
|
||||||
return true;
|
|
||||||
#else
|
|
||||||
// Without OpenMP support for x86-simd-sort, fbgemm radix sort is faster when it can be used
|
|
||||||
return !can_use_radix_sort(values, descending);
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
static void xss_sort_kernel(
|
|
||||||
const TensorBase& values,
|
|
||||||
const TensorBase& indices,
|
|
||||||
int64_t dim,
|
|
||||||
bool descending) {
|
|
||||||
auto iter = TensorIteratorConfig()
|
|
||||||
.check_all_same_dtype(false)
|
|
||||||
.resize_outputs(false)
|
|
||||||
.declare_static_shape(values.sizes(), /*squash_dims=*/dim)
|
|
||||||
.add_output(values)
|
|
||||||
.add_output(indices)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
using index_t = int64_t;
|
|
||||||
|
|
||||||
AT_DISPATCH_XSS_TYPES(values.scalar_type(), "xss_sort_kernel", [&] {
|
|
||||||
|
|
||||||
auto values_dim_stride = values.stride(dim);
|
|
||||||
auto indices_dim_stride = indices.stride(dim);
|
|
||||||
auto dim_size = values.size(dim);
|
|
||||||
|
|
||||||
auto loop = [&](char** data, const int64_t* strides, int64_t n) {
|
|
||||||
auto* values_data_bytes = data[0];
|
|
||||||
auto* indices_data_bytes = data[1];
|
|
||||||
|
|
||||||
if(values_data_bytes==nullptr || indices_data_bytes==nullptr){
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (values_dim_stride == 1 && indices_dim_stride == 1){
|
|
||||||
for (const auto i [[maybe_unused]] : c10::irange(n)) {
|
|
||||||
x86simdsortStatic::keyvalue_qsort<scalar_t, index_t>(
|
|
||||||
reinterpret_cast<scalar_t*>(values_data_bytes),
|
|
||||||
reinterpret_cast<index_t*>(indices_data_bytes),
|
|
||||||
dim_size,
|
|
||||||
true,
|
|
||||||
descending);
|
|
||||||
|
|
||||||
values_data_bytes += strides[0];
|
|
||||||
indices_data_bytes += strides[1];
|
|
||||||
}
|
|
||||||
}else{
|
|
||||||
c10::SmallBuffer<scalar_t, 0> tmp_values(dim_size);
|
|
||||||
c10::SmallBuffer<index_t, 0> tmp_indices(dim_size);
|
|
||||||
|
|
||||||
for (const auto i : c10::irange(n)) {
|
|
||||||
TensorAccessor<scalar_t, 1> mode_values_acc(
|
|
||||||
reinterpret_cast<scalar_t*>(data[0] + i * strides[0]),
|
|
||||||
&dim_size, &values_dim_stride);
|
|
||||||
TensorAccessor<index_t, 1> mode_indices_acc(
|
|
||||||
reinterpret_cast<index_t*>(data[1] + i * strides[1]),
|
|
||||||
&dim_size, &indices_dim_stride);
|
|
||||||
|
|
||||||
for (const auto j : c10::irange(dim_size)) {
|
|
||||||
tmp_values[j] = mode_values_acc[j];
|
|
||||||
tmp_indices[j] = j;
|
|
||||||
}
|
|
||||||
|
|
||||||
x86simdsortStatic::keyvalue_qsort<scalar_t, index_t>(
|
|
||||||
tmp_values.data(),
|
|
||||||
tmp_indices.data(),
|
|
||||||
dim_size,
|
|
||||||
true,
|
|
||||||
descending);
|
|
||||||
|
|
||||||
for (const auto j : c10::irange(dim_size)) {
|
|
||||||
mode_values_acc[j] = tmp_values[j];
|
|
||||||
mode_indices_acc[j] = tmp_indices[j];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
int64_t grain_size = internal::GRAIN_SIZE / std::max(int64_t{1}, dim_size);
|
|
||||||
iter.for_each(loop, /*grain_size=*/grain_size);
|
|
||||||
|
|
||||||
});
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
static void sort_kernel(
|
static void sort_kernel(
|
||||||
const TensorBase& self,
|
const TensorBase& self,
|
||||||
const TensorBase& values,
|
const TensorBase& values,
|
||||||
@ -297,14 +179,6 @@ static void sort_kernel(
|
|||||||
// https://github.com/pytorch/pytorch/issues/91420
|
// https://github.com/pytorch/pytorch/issues/91420
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
#if defined(XSS_COMPILE_TIME_SUPPORTED)
|
|
||||||
if (can_use_xss_sort(values, indices, dim, stable) && xss_sort_preferred(values, descending)){
|
|
||||||
xss_sort_kernel(values, indices, dim, descending);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#ifdef USE_FBGEMM
|
#ifdef USE_FBGEMM
|
||||||
if (can_use_radix_sort(values, descending)) {
|
if (can_use_radix_sort(values, descending)) {
|
||||||
parallel_sort1d_kernel(values, indices);
|
parallel_sort1d_kernel(values, indices);
|
||||||
@ -356,7 +230,6 @@ static void topk_kernel(
|
|||||||
int64_t dim,
|
int64_t dim,
|
||||||
bool largest,
|
bool largest,
|
||||||
bool sorted) {
|
bool sorted) {
|
||||||
|
|
||||||
auto sizes = self.sizes();
|
auto sizes = self.sizes();
|
||||||
auto iter = TensorIteratorConfig()
|
auto iter = TensorIteratorConfig()
|
||||||
.check_all_same_dtype(false)
|
.check_all_same_dtype(false)
|
||||||
@ -391,7 +264,7 @@ static void topk_kernel(
|
|||||||
|
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
|
|
||||||
ALSO_REGISTER_AVX512_DISPATCH(sort_stub, &sort_kernel)
|
REGISTER_DISPATCH(sort_stub, &sort_kernel)
|
||||||
ALSO_REGISTER_AVX512_DISPATCH(topk_stub, &topk_kernel)
|
REGISTER_DISPATCH(topk_stub, &topk_kernel)
|
||||||
|
|
||||||
} //at::native
|
} //at::native
|
||||||
|
@ -1310,28 +1310,6 @@ if(CAFFE2_CMAKE_BUILDING_WITH_MAIN_REPO AND NOT INTERN_DISABLE_ONNX)
|
|||||||
set(BUILD_SHARED_LIBS ${TEMP_BUILD_SHARED_LIBS})
|
set(BUILD_SHARED_LIBS ${TEMP_BUILD_SHARED_LIBS})
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# --[ x86-simd-sort integration
|
|
||||||
if(USE_X86_SIMD_SORT)
|
|
||||||
if(NOT CMAKE_SIZEOF_VOID_P EQUAL 8)
|
|
||||||
message(WARNING
|
|
||||||
"x64 operating system is required for x86-simd-sort. "
|
|
||||||
"Not compiling with x86-simd-sort. "
|
|
||||||
"Turn this warning off by USE_X86_SIMD_SORT=OFF.")
|
|
||||||
set(USE_X86_SIMD_SORT OFF)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if(USE_X86_SIMD_SORT)
|
|
||||||
if(USE_OPENMP AND NOT MSVC)
|
|
||||||
set(USE_XSS_OPENMP ON)
|
|
||||||
else()
|
|
||||||
set(USE_XSS_OPENMP OFF)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
set(XSS_SIMD_SORT_INCLUDE_DIR ${CMAKE_CURRENT_LIST_DIR}/../third_party/x86-simd-sort)
|
|
||||||
include_directories(SYSTEM ${XSS_SIMD_SORT_INCLUDE_DIR})
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
|
|
||||||
# --[ ATen checks
|
# --[ ATen checks
|
||||||
set(USE_LAPACK 0)
|
set(USE_LAPACK 0)
|
||||||
|
|
||||||
|
@ -134,7 +134,6 @@ function(caffe2_print_configuration_summary)
|
|||||||
endif()
|
endif()
|
||||||
message(STATUS " BUILD_NVFUSER : ${BUILD_NVFUSER}")
|
message(STATUS " BUILD_NVFUSER : ${BUILD_NVFUSER}")
|
||||||
message(STATUS " USE_EIGEN_FOR_BLAS : ${CAFFE2_USE_EIGEN_FOR_BLAS}")
|
message(STATUS " USE_EIGEN_FOR_BLAS : ${CAFFE2_USE_EIGEN_FOR_BLAS}")
|
||||||
message(STATUS " USE_X86_SIMD_SORT : ${USE_X86_SIMD_SORT}")
|
|
||||||
message(STATUS " USE_FBGEMM : ${USE_FBGEMM}")
|
message(STATUS " USE_FBGEMM : ${USE_FBGEMM}")
|
||||||
message(STATUS " USE_FAKELOWP : ${USE_FAKELOWP}")
|
message(STATUS " USE_FAKELOWP : ${USE_FAKELOWP}")
|
||||||
message(STATUS " USE_KINETO : ${USE_KINETO}")
|
message(STATUS " USE_KINETO : ${USE_KINETO}")
|
||||||
|
@ -463,9 +463,6 @@ inductor_override_kwargs["cpu"] = {
|
|||||||
("nn.functional.interpolate.bicubic", u8): {"atol": 1, "rtol": 0},
|
("nn.functional.interpolate.bicubic", u8): {"atol": 1, "rtol": 0},
|
||||||
# High atol due to precision loss
|
# High atol due to precision loss
|
||||||
("nn.functional.interpolate.bicubic", f32): {"atol": 5e-3, "rtol": 0},
|
("nn.functional.interpolate.bicubic", f32): {"atol": 5e-3, "rtol": 0},
|
||||||
# reference_in_float can cause erroneous failures in sorting tests
|
|
||||||
"argsort": {"reference_in_float": False},
|
|
||||||
"sort": {"reference_in_float": False},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
inductor_override_kwargs["cuda"] = {
|
inductor_override_kwargs["cuda"] = {
|
||||||
@ -536,9 +533,6 @@ inductor_override_kwargs["cuda"] = {
|
|||||||
("index_reduce.amax", f32): {"check_gradient": False},
|
("index_reduce.amax", f32): {"check_gradient": False},
|
||||||
("index_reduce.amax", f16): {"check_gradient": False},
|
("index_reduce.amax", f16): {"check_gradient": False},
|
||||||
("tanh", f16): {"atol": 1e-4, "rtol": 1e-2},
|
("tanh", f16): {"atol": 1e-4, "rtol": 1e-2},
|
||||||
# reference_in_float can cause erroneous failures in sorting tests
|
|
||||||
"argsort": {"reference_in_float": False},
|
|
||||||
"sort": {"reference_in_float": False},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
inductor_override_kwargs["xpu"] = {
|
inductor_override_kwargs["xpu"] = {
|
||||||
@ -663,9 +657,6 @@ inductor_override_kwargs["xpu"] = {
|
|||||||
("nn.functional.embedding_bag", f64): {"check_gradient": False},
|
("nn.functional.embedding_bag", f64): {"check_gradient": False},
|
||||||
("_unsafe_masked_index", f16): {"atol": 1e-5, "rtol": 2e-3},
|
("_unsafe_masked_index", f16): {"atol": 1e-5, "rtol": 2e-3},
|
||||||
("_unsafe_masked_index_put_accumulate", f16): {"atol": 1e-5, "rtol": 5e-3},
|
("_unsafe_masked_index_put_accumulate", f16): {"atol": 1e-5, "rtol": 5e-3},
|
||||||
# reference_in_float can cause erroneous failures in sorting tests
|
|
||||||
"argsort": {"reference_in_float": False},
|
|
||||||
"sort": {"reference_in_float": False},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Test with one sample only for following ops
|
# Test with one sample only for following ops
|
||||||
|
1
third_party/x86-simd-sort
vendored
1
third_party/x86-simd-sort
vendored
Submodule third_party/x86-simd-sort deleted from f99c392904
Reference in New Issue
Block a user