mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[2/N] Enable UBSAN tests (#141740)
Apply c10::load in more places. The function was introduced to cast a byte to valid boolean values, thus fixing the UBSAN errors. Pull Request resolved: https://github.com/pytorch/pytorch/pull/141740 Approved by: https://github.com/ezyang
This commit is contained in:
@ -1047,7 +1047,7 @@ TORCH_IMPL_FUNC(index_add_cpu_out)
|
||||
auto self_i = index_data[i];
|
||||
TORCH_CHECK_INDEX((self_i >= 0) && (self_i < result.numel()), "index out of range in self");
|
||||
scalar_t *self_ip = result_ptr + self_i * result_stride;
|
||||
*self_ip += *(source_ptr + i * source_stride) * alpha_value;
|
||||
*self_ip += c10::load(source_ptr + i * source_stride) * alpha_value;
|
||||
}
|
||||
});
|
||||
});
|
||||
|
@ -61,7 +61,7 @@ void apply_triu_tril_single(
|
||||
}
|
||||
if (!inplace) { // copy the rest of the self if not inplace
|
||||
for (int64_t j = std::max(zero, i + k); j < m; j++) {
|
||||
result[i * res_row_stride + j * res_col_stride] = self[i * self_row_stride + j * self_col_stride];
|
||||
result[i * res_row_stride + j * res_col_stride] = c10::load(&self[i * self_row_stride + j * self_col_stride]);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -74,7 +74,7 @@ void apply_triu_tril_single(
|
||||
}
|
||||
if (!inplace) { // copy the rest of the self if not inplace
|
||||
for (int64_t j = zero; j < std::min(m, i + k + 1); j++) {
|
||||
result[i * res_row_stride + j * res_col_stride] = self[i * self_row_stride + j * self_col_stride];
|
||||
result[i * res_row_stride + j * res_col_stride] = c10::load(&self[i * self_row_stride + j * self_col_stride]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -48,7 +48,7 @@ void cpu_channel_shuffle(
|
||||
data_vec.store(output_ptr + d);
|
||||
}
|
||||
for (; d < image_size; d++) {
|
||||
output_ptr[d] = input_ptr[d];
|
||||
output_ptr[d] = c10::load(&(input_ptr[d]));
|
||||
}
|
||||
|
||||
// move on to next output index
|
||||
|
@ -25,7 +25,7 @@ void index_kernel(TensorIteratorBase& iter, IntArrayRef index_size, IntArrayRef
|
||||
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kHalf, kBool, kBFloat16,
|
||||
iter.dtype(), "index_cpu", [&] {
|
||||
cpu_index_kernel<scalar_t>(iter, index_size, index_stride, [](char* dst, char* src, int64_t offset) {
|
||||
*(scalar_t*)dst = *(scalar_t*)(src + offset);
|
||||
*(scalar_t*)dst = c10::load((scalar_t*)(src + offset));
|
||||
});
|
||||
});
|
||||
}
|
||||
@ -128,14 +128,14 @@ void put_kernel(
|
||||
// Unlike the non-accumulate case, this needs to be thread-safe.
|
||||
cpu_take_put_kernel<scalar_t>(iter, self, true,
|
||||
[](scalar_t& iterated, scalar_t* indexed, const int64_t idx) {
|
||||
indexed[idx] += iterated;
|
||||
indexed[idx] += c10::load(&iterated);
|
||||
},
|
||||
/*serial_execution=*/true);
|
||||
}
|
||||
} else {
|
||||
cpu_take_put_kernel<scalar_t>(iter, self, true,
|
||||
[](scalar_t& iterated, scalar_t* indexed, const int64_t idx) {
|
||||
indexed[idx] = iterated;
|
||||
indexed[idx] = c10::load(&iterated);
|
||||
});
|
||||
}
|
||||
});
|
||||
@ -148,7 +148,7 @@ void take_kernel(
|
||||
iter.dtype(), "take_cpu", [&] {
|
||||
cpu_take_put_kernel<scalar_t>(iter, input, false,
|
||||
[](scalar_t& iterated, const scalar_t* indexed, const int64_t idx) {
|
||||
iterated = indexed[idx];
|
||||
iterated = c10::load(&(indexed[idx]));
|
||||
});
|
||||
});
|
||||
}
|
||||
@ -174,12 +174,12 @@ void index_put_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef
|
||||
// TODO: investigate parallelization of the accumulate kernel. Unlike the non-accumulate case,
|
||||
// this needs to be thread-safe.
|
||||
cpu_index_kernel<scalar_t>(iter, index_size, index_stride, [](char* dst, char* src, int64_t offset) {
|
||||
*(scalar_t*)(dst + offset) += *(scalar_t*)src;
|
||||
*(scalar_t*)(dst + offset) += c10::load(reinterpret_cast<scalar_t*>(src));
|
||||
}, /*serial_execution=*/true);
|
||||
}
|
||||
} else {
|
||||
cpu_index_kernel<scalar_t>(iter, index_size, index_stride, [](char* dst, char* src, int64_t offset) {
|
||||
*(scalar_t*)(dst + offset) = *(scalar_t*)src;
|
||||
*(scalar_t*)(dst + offset) = c10::load(reinterpret_cast<scalar_t*>(src));
|
||||
}, /*serial_execution=*/is_deterministic);
|
||||
}
|
||||
}),
|
||||
@ -270,7 +270,7 @@ void index_copy_kernel(
|
||||
"index_copy_(): index ", idx, " is out of bounds for dimension ",
|
||||
dim, " with size ", self_dim_size);
|
||||
|
||||
self_data[idx * self_dim_stride] = *source_data;
|
||||
self_data[idx * self_dim_stride] = c10::load(source_data);
|
||||
|
||||
self_data_bytes += strides[0];
|
||||
index_data_bytes += strides[1];
|
||||
@ -289,7 +289,7 @@ void index_copy_kernel(
|
||||
auto* self_data = reinterpret_cast<scalar_t*>(self_data_bytes);
|
||||
auto* source_data = reinterpret_cast<scalar_t*>(source_data_bytes);
|
||||
|
||||
self_data[idx * self_dim_stride] = *source_data;
|
||||
self_data[idx * self_dim_stride] = c10::load(source_data);
|
||||
|
||||
self_data_bytes += strides[0];
|
||||
source_data_bytes += strides[2];
|
||||
@ -320,7 +320,7 @@ void cpu_masked_fill_kernel(TensorIterator& iter, scalar_t value) {
|
||||
char* dst = data[0];
|
||||
char* mask = data[1];
|
||||
for (const auto i : c10::irange(n)) {
|
||||
bool mask_value = *reinterpret_cast<bool*>(mask + strides[1] * i);
|
||||
bool mask_value = c10::load(reinterpret_cast<bool*>(mask + strides[1] * i));
|
||||
|
||||
if (mask_value) {
|
||||
*(scalar_t*)(dst + strides[0] * i) = value;
|
||||
@ -353,10 +353,11 @@ void cpu_masked_scatter_kernel(TensorIterator& iter, const TensorBase& source) {
|
||||
char* mask = data[1];
|
||||
const int64_t mask_stride = strides[1];
|
||||
for (const auto i : c10::irange(n)) {
|
||||
auto mask_value = *reinterpret_cast<bool*>(mask + mask_stride * i);
|
||||
auto mask_value = c10::load(reinterpret_cast<bool*>(mask + mask_stride * i));
|
||||
|
||||
if (mask_value) {
|
||||
TORCH_CHECK(source_cntr < numel, "Number of elements of source < number of ones in mask");
|
||||
*(scalar_t*)(dst + dst_stride * i) = *(source_ptr);
|
||||
*(scalar_t*)(dst + dst_stride * i) = c10::load(source_ptr);
|
||||
source_ptr++;
|
||||
source_cntr++;
|
||||
}
|
||||
@ -387,7 +388,7 @@ void cpu_masked_select_serial_kernel(TensorIterator& iter, const func_t& f) {
|
||||
char* src = data[1];
|
||||
char* mask = data[2];
|
||||
for (const auto i : c10::irange(n)) {
|
||||
mask_t mask_value = *(mask_t*)(mask + strides[2] * i);
|
||||
mask_t mask_value = c10::load((mask_t*)(mask + strides[2] * i));
|
||||
if constexpr (!std::is_same_v<mask_t, bool>) {
|
||||
TORCH_CHECK(mask_value == 0 || mask_value == 1, "Mask tensor can take 0 and 1 values only");
|
||||
}
|
||||
@ -406,11 +407,11 @@ void masked_select_serial_kernel(TensorIterator& iter, int64_t result_stride) {
|
||||
auto mask_dtype = iter.input_dtype(1);
|
||||
if (mask_dtype == ScalarType::Bool) {
|
||||
cpu_masked_select_serial_kernel<scalar_t, bool>(iter, [result_stride](char* dst, char* src, int64_t offset) {
|
||||
*(scalar_t*)(dst + offset*result_stride) = *(scalar_t*)src;
|
||||
*(scalar_t*)(dst + offset*result_stride) = c10::load((scalar_t*)src);
|
||||
});
|
||||
} else {
|
||||
cpu_masked_select_serial_kernel<scalar_t, unsigned char>(iter, [result_stride](char* dst, char* src, int64_t offset) {
|
||||
*(scalar_t*)(dst + offset*result_stride) = *(scalar_t*)src;
|
||||
*(scalar_t*)(dst + offset*result_stride) = c10::load((scalar_t*)src);
|
||||
});
|
||||
}
|
||||
}),
|
||||
@ -430,7 +431,7 @@ void cpu_masked_select_kernel(TensorIterator& iter, const func_t& f) {
|
||||
char* mask = data[2];
|
||||
char* mask_prefix_sum = data[3];
|
||||
for (const auto i : c10::irange(n)) {
|
||||
mask_t mask_value = *(mask_t*)(mask + strides[2] * i);
|
||||
mask_t mask_value = c10::load((mask_t*)(mask + strides[2] * i));
|
||||
if constexpr (!std::is_same_v<mask_t, bool>) {
|
||||
TORCH_CHECK(mask_value == 0 || mask_value == 1, "Mask tensor can take 0 and 1 values only");
|
||||
}
|
||||
@ -449,7 +450,7 @@ void masked_select_kernel(TensorIterator& iter, int64_t result_stride) {
|
||||
auto mask_dtype = iter.input_dtype(1);
|
||||
if (mask_dtype == ScalarType::Bool) {
|
||||
cpu_masked_select_kernel<scalar_t, bool>(iter, [result_stride](char* dst, char* src, int64_t offset) {
|
||||
*(scalar_t*)(dst + offset*result_stride) = *(scalar_t*)src;
|
||||
*(scalar_t*)(dst + offset*result_stride) = c10::load((scalar_t*)src);
|
||||
});
|
||||
} else {
|
||||
cpu_masked_select_kernel<scalar_t, unsigned char>(iter, [result_stride](char* dst, char* src, int64_t offset) {
|
||||
@ -501,7 +502,7 @@ void cpu_hflip_vec(at::TensorIterator& iter) {
|
||||
offset = (offset >= n) ? n : offset;
|
||||
for (; i < offset; i++) {
|
||||
scalar_t* out_ptr = (scalar_t*)(data[0] - i * stride);
|
||||
*out_ptr = *(scalar_t *)(data[1] + i * stride);
|
||||
*out_ptr = c10::load((scalar_t *)(data[1] + i * stride));
|
||||
}
|
||||
// Empirically found that it is faster to process 3 data items together vs 2 or 4
|
||||
for (; i <= n - 3 * Vec::size(); i += 3 * Vec::size()) {
|
||||
@ -519,7 +520,7 @@ void cpu_hflip_vec(at::TensorIterator& iter) {
|
||||
if (i < n) {
|
||||
for (; i < n; i++) {
|
||||
scalar_t* out_ptr = (scalar_t*)(data[0] - i * stride);
|
||||
*out_ptr = *(scalar_t *)(data[1] + i * stride);
|
||||
*out_ptr = c10::load((scalar_t *)(data[1] + i * stride));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -208,7 +208,7 @@ vectorized_loop(char** C10_RESTRICT data_, int64_t n, int64_t S, func_t&& op, ve
|
||||
data[arg] = data_[arg];
|
||||
}
|
||||
|
||||
Vec opt_scalar = Vec(S > 0 ? *(scalar_t*)data[S] : scalar_t(0));
|
||||
Vec opt_scalar = Vec(S > 0 ? c10::load((scalar_t*)data[S]) : scalar_t(0));
|
||||
int64_t i = 0;
|
||||
for (; i <= n - 2 * Vec::size(); i += 2 * Vec::size()) {
|
||||
auto args1 = dereference_vec<traits>(&data[1], opt_scalar, S, i);
|
||||
|
@ -45,7 +45,7 @@ void cpu_pixel_shuffle(
|
||||
for (const auto i : c10::irange(begin, end)) {
|
||||
int64_t input_offset = n * stride_n + c * stride_c + s1 * stride_s1 +
|
||||
s2 * stride_s2 + h * stride_h + w;
|
||||
output_data[i] = input_data[input_offset];
|
||||
output_data[i] = c10::load(&input_data[input_offset]);
|
||||
|
||||
data_index_step(n, nbatch, c, sub_channels, h, height, s1, S, w, width, s2, S);
|
||||
}
|
||||
@ -144,7 +144,7 @@ void cpu_pixel_unshuffle(
|
||||
for (const auto i : c10::irange(begin, end)) {
|
||||
int64_t input_offset = n * stride_n + c * stride_c + h * stride_h +
|
||||
s1 * stride_s1 + w * stride_w + s2 * stride_s2;
|
||||
output_data[i] = input_data[input_offset];
|
||||
output_data[i] = c10::load(&input_data[input_offset]);
|
||||
|
||||
data_index_step(n, nbatch, c, sub_channels, s1, S, s2, S, h, height, w, width);
|
||||
}
|
||||
@ -186,7 +186,7 @@ void cpu_pixel_unshuffle_channels_last(
|
||||
for (const auto i : c10::irange(begin, end)) {
|
||||
int64_t input_offset = n * stride_n + h * stride_h + s1 * stride_s1 +
|
||||
w * stride_w + s2 * stride_s2 + c * stride_c;
|
||||
output_data[i] = input_data[input_offset];
|
||||
output_data[i] = c10::load(&input_data[input_offset]);
|
||||
|
||||
data_index_step(n, nbatch, h, height, w, width, c, sub_channels, s1, S, s2, S);
|
||||
}
|
||||
|
@ -34,11 +34,11 @@ public:
|
||||
template <typename scalar_t>
|
||||
constexpr void operator() (at::opmath_type<scalar_t> * self_data, scalar_t * src_data) const {
|
||||
using opmath_t = at::opmath_type<scalar_t>;
|
||||
*self_data *= opmath_t(*src_data);
|
||||
*self_data *= opmath_t(c10::load(src_data));
|
||||
}
|
||||
|
||||
constexpr void operator() (bool * self_data, bool * src_data) const {
|
||||
*self_data = *self_data && *src_data;
|
||||
*self_data = c10::load(self_data) && c10::load(src_data);
|
||||
}
|
||||
};
|
||||
static ReduceMultiply reduce_multiply;
|
||||
@ -48,7 +48,7 @@ public:
|
||||
template <typename scalar_t>
|
||||
constexpr void operator() (at::opmath_type<scalar_t> * self_data, scalar_t * src_data) const {
|
||||
using opmath_t = at::opmath_type<scalar_t>;
|
||||
*self_data += opmath_t(*src_data);
|
||||
*self_data += opmath_t(c10::load(src_data));
|
||||
}
|
||||
};
|
||||
static ReduceAdd reduce_add;
|
||||
@ -58,7 +58,7 @@ public:
|
||||
template <typename scalar_t>
|
||||
constexpr void operator() (at::opmath_type<scalar_t> * self_data, scalar_t * src_data) const {
|
||||
using opmath_t = at::opmath_type<scalar_t>;
|
||||
*self_data += opmath_t(*src_data);
|
||||
*self_data += opmath_t(c10::load(src_data));
|
||||
}
|
||||
};
|
||||
static ReduceMean reduce_mean;
|
||||
@ -68,7 +68,9 @@ public:
|
||||
template <typename scalar_t>
|
||||
constexpr void operator() (at::opmath_type<scalar_t> * self_data, scalar_t * src_data) const {
|
||||
using opmath_t = at::opmath_type<scalar_t>;
|
||||
*self_data = at::_isnan<scalar_t>(*src_data) ? opmath_t(*src_data) : std::max(*self_data, opmath_t(*src_data));
|
||||
auto self_value = c10::load(self_data);
|
||||
auto src_value = c10::load(src_data);
|
||||
*self_data = at::_isnan<scalar_t>(src_value) ? opmath_t(src_value) : std::max(self_value, opmath_t(src_value));
|
||||
}
|
||||
};
|
||||
static ReduceMaximum reduce_maximum;
|
||||
@ -78,7 +80,9 @@ public:
|
||||
template <typename scalar_t>
|
||||
constexpr void operator() (at::opmath_type<scalar_t> * self_data, scalar_t * src_data) const {
|
||||
using opmath_t = at::opmath_type<scalar_t>;
|
||||
*self_data = at::_isnan<scalar_t>(*src_data) ? opmath_t(*src_data) : std::min(*self_data, opmath_t(*src_data));
|
||||
auto self_value = c10::load(self_data);
|
||||
auto src_value = c10::load(src_data);
|
||||
*self_data = at::_isnan<scalar_t>(src_value) ? opmath_t(src_value) : std::min(self_value, opmath_t(src_value));
|
||||
}
|
||||
};
|
||||
static ReduceMinimum reduce_minimum;
|
||||
@ -88,7 +92,7 @@ public:
|
||||
template <typename scalar_t>
|
||||
constexpr void operator() (at::opmath_type<scalar_t> * self_data, scalar_t * src_data) const {
|
||||
using opmath_t = at::opmath_type<scalar_t>;
|
||||
*self_data = opmath_t(*src_data);
|
||||
*self_data = opmath_t(c10::load(src_data));
|
||||
}
|
||||
};
|
||||
static TensorAssign tensor_assign;
|
||||
|
@ -115,7 +115,7 @@ static void min_kernel_impl(
|
||||
scalar_t min_number = c10::load(self_data);
|
||||
int64_t index = 0;
|
||||
for (const auto i : c10::irange(self_dim_size)) {
|
||||
scalar_t value = self_data[i * self_dim_stride];
|
||||
scalar_t value = c10::load(&self_data[i * self_dim_stride]);
|
||||
if (!(zabs_(value) >= zabs_(min_number))) {
|
||||
min_number = value;
|
||||
index = i;
|
||||
|
@ -148,7 +148,7 @@ template <typename T>
|
||||
inline void transpose(int64_t M, int64_t N, const T* src, int64_t ld_src, T* dst, int64_t ld_dst) {
|
||||
for (int64_t j = 0; j < N; j++) {
|
||||
for (int64_t i = 0; i < M; i++) {
|
||||
dst[j * ld_dst + i] = src[i * ld_src + j];
|
||||
dst[j * ld_dst + i] = c10::load(&(src[i * ld_src + j]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -71,7 +71,7 @@ static void im2col(
|
||||
int64_t w_im = w_col * stride_w - pad_w + w_offset * dilation_w;
|
||||
data_col[(c_col * height_col + h_col) * width_col + w_col] =
|
||||
(h_im >= 0 && w_im >= 0 && h_im < height && w_im < width)
|
||||
? data_im[(c_im * height + h_im) * width + w_im]
|
||||
? c10::load(&(data_im[(c_im * height + h_im) * width + w_im]))
|
||||
: static_cast<T>(0);
|
||||
}
|
||||
}
|
||||
|
@ -26,12 +26,12 @@ struct LoadImpl<bool> {
|
||||
} // namespace detail
|
||||
|
||||
template <typename T>
|
||||
C10_HOST_DEVICE T load(const void* src) {
|
||||
C10_HOST_DEVICE constexpr T load(const void* src) {
|
||||
return c10::detail::LoadImpl<T>::apply(src);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
C10_HOST_DEVICE scalar_t load(const scalar_t* src) {
|
||||
C10_HOST_DEVICE constexpr scalar_t load(const scalar_t* src) {
|
||||
return c10::detail::LoadImpl<scalar_t>::apply(src);
|
||||
}
|
||||
|
||||
|
@ -71,7 +71,6 @@ from torch.testing._internal.common_utils import (
|
||||
TEST_WITH_ROCM,
|
||||
TEST_WITH_TORCHDYNAMO,
|
||||
TEST_WITH_TORCHINDUCTOR,
|
||||
TEST_WITH_UBSAN,
|
||||
TestCase,
|
||||
unMarkDynamoStrictTest,
|
||||
)
|
||||
@ -653,16 +652,6 @@ class TestCommon(TestCase):
|
||||
and dtype == torch.float16
|
||||
):
|
||||
self.skipTest("Skipped on ROCm")
|
||||
# skip zero-dim tensors for some composites of reduction operations and view
|
||||
skip_zero_dim_ops = [
|
||||
"_refs.logsumexp",
|
||||
"_refs.log_softmax",
|
||||
"_refs.native_group_norm",
|
||||
"_refs.softmax",
|
||||
"_refs.sum_to_size",
|
||||
"ops.nvprims.view",
|
||||
]
|
||||
|
||||
from copy import copy
|
||||
|
||||
from torch._prims.executor import make_traced
|
||||
@ -1050,7 +1039,7 @@ class TestCommon(TestCase):
|
||||
try:
|
||||
info = torch.iinfo(t.dtype)
|
||||
return torch.full_like(t, info.max)
|
||||
except TypeError as te:
|
||||
except TypeError:
|
||||
# for non-integer types fills with NaN
|
||||
return torch.full_like(t, float("nan"))
|
||||
|
||||
@ -1445,7 +1434,6 @@ class TestCommon(TestCase):
|
||||
self.assertEqual(actual, expected, exact_dtype=False)
|
||||
|
||||
@ops(op_db, allowed_dtypes=(torch.bool,))
|
||||
@unittest.skipIf(TEST_WITH_UBSAN, "Test uses undefined behavior")
|
||||
def test_non_standard_bool_values(self, device, dtype, op):
|
||||
# Test boolean values other than 0x00 and 0x01 (gh-54789)
|
||||
def convert_boolean_tensors(x):
|
||||
@ -2754,7 +2742,7 @@ class TestFakeTensor(TestCase):
|
||||
|
||||
try:
|
||||
op(input, *args, **kwargs)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
with TestPointwiseMode():
|
||||
|
Reference in New Issue
Block a user