mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Various fixes of torch/csrc files (#127252)
Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/127252 Approved by: https://github.com/r-barnes
This commit is contained in:
@ -1,5 +1,6 @@
|
||||
#include <torch/csrc/python_headers.h>
|
||||
#include <system_error>
|
||||
#include <vector>
|
||||
|
||||
#include <ATen/ops/from_blob.h>
|
||||
#include <c10/core/CPUAllocator.h>
|
||||
@ -268,32 +269,30 @@ void THPStorage_writeFileRaw(
|
||||
doWrite(fd, data, size_bytes);
|
||||
} else {
|
||||
size_t buffer_size = std::min(numel, (size_t)5000);
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
|
||||
std::unique_ptr<uint8_t[]> le_buffer(
|
||||
new uint8_t[buffer_size * element_size]);
|
||||
std::vector<uint8_t> le_buffer;
|
||||
le_buffer.resize(buffer_size * element_size);
|
||||
for (size_t i = 0; i < numel; i += buffer_size) {
|
||||
size_t to_convert = std::min(numel - i, buffer_size);
|
||||
// NOLINTNEXTLINE(bugprone-branch-clone)
|
||||
if (element_size == 2) {
|
||||
torch::utils::THP_encodeInt16Buffer(
|
||||
(uint8_t*)le_buffer.get(),
|
||||
le_buffer.data(),
|
||||
(const int16_t*)data + i,
|
||||
torch::utils::THPByteOrder::THP_LITTLE_ENDIAN,
|
||||
to_convert);
|
||||
} else if (element_size == 4) {
|
||||
torch::utils::THP_encodeInt32Buffer(
|
||||
(uint8_t*)le_buffer.get(),
|
||||
le_buffer.data(),
|
||||
(const int32_t*)data + i,
|
||||
torch::utils::THPByteOrder::THP_LITTLE_ENDIAN,
|
||||
to_convert);
|
||||
} else if (element_size == 8) {
|
||||
torch::utils::THP_encodeInt64Buffer(
|
||||
(uint8_t*)le_buffer.get(),
|
||||
le_buffer.data(),
|
||||
(const int64_t*)data + i,
|
||||
torch::utils::THPByteOrder::THP_LITTLE_ENDIAN,
|
||||
to_convert);
|
||||
}
|
||||
doWrite(fd, le_buffer.get(), to_convert * element_size);
|
||||
doWrite(fd, le_buffer.data(), to_convert * element_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -99,8 +99,7 @@ void initializeDtypes() {
|
||||
|
||||
#define DEFINE_SCALAR_TYPE(_1, n) at::ScalarType::n,
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
|
||||
at::ScalarType all_scalar_types[] = {
|
||||
auto all_scalar_types = {
|
||||
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_SCALAR_TYPE)};
|
||||
|
||||
for (at::ScalarType scalarType : all_scalar_types) {
|
||||
|
Reference in New Issue
Block a user