Various fixes of torch/csrc files (#127252)

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127252
Approved by: https://github.com/r-barnes
This commit is contained in:
cyy
2024-06-14 17:31:21 +00:00
committed by PyTorch MergeBot
parent 089e76cca3
commit d4807da802
2 changed files with 8 additions and 10 deletions

View File

@ -1,5 +1,6 @@
#include <torch/csrc/python_headers.h>
#include <system_error>
#include <vector>
#include <ATen/ops/from_blob.h>
#include <c10/core/CPUAllocator.h>
@ -268,32 +269,30 @@ void THPStorage_writeFileRaw(
doWrite(fd, data, size_bytes);
} else {
size_t buffer_size = std::min(numel, (size_t)5000);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
std::unique_ptr<uint8_t[]> le_buffer(
new uint8_t[buffer_size * element_size]);
std::vector<uint8_t> le_buffer;
le_buffer.resize(buffer_size * element_size);
for (size_t i = 0; i < numel; i += buffer_size) {
size_t to_convert = std::min(numel - i, buffer_size);
// NOLINTNEXTLINE(bugprone-branch-clone)
if (element_size == 2) {
torch::utils::THP_encodeInt16Buffer(
(uint8_t*)le_buffer.get(),
le_buffer.data(),
(const int16_t*)data + i,
torch::utils::THPByteOrder::THP_LITTLE_ENDIAN,
to_convert);
} else if (element_size == 4) {
torch::utils::THP_encodeInt32Buffer(
(uint8_t*)le_buffer.get(),
le_buffer.data(),
(const int32_t*)data + i,
torch::utils::THPByteOrder::THP_LITTLE_ENDIAN,
to_convert);
} else if (element_size == 8) {
torch::utils::THP_encodeInt64Buffer(
(uint8_t*)le_buffer.get(),
le_buffer.data(),
(const int64_t*)data + i,
torch::utils::THPByteOrder::THP_LITTLE_ENDIAN,
to_convert);
}
doWrite(fd, le_buffer.get(), to_convert * element_size);
doWrite(fd, le_buffer.data(), to_convert * element_size);
}
}
}

View File

@ -99,8 +99,7 @@ void initializeDtypes() {
#define DEFINE_SCALAR_TYPE(_1, n) at::ScalarType::n,
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
at::ScalarType all_scalar_types[] = {
auto all_scalar_types = {
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_SCALAR_TYPE)};
for (at::ScalarType scalarType : all_scalar_types) {