Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
[2/N] Fix extra warnings brought by clang-tidy-17 (#137459)
Follows #137407.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137459
Approved by: https://github.com/Skylion007
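The changes below are mechanical cleanups for checks that clang-tidy 17 flags more aggressively: locals are initialized at their declaration (cppcoreguidelines-init-variables), members are initialized in constructor initializer lists or with in-class initializers instead of assignments in the constructor body, nested `namespace a { namespace b {` blocks are collapsed into the C++17 `namespace a::b` form, and NOLINT suppressions that are no longer needed are dropped. A minimal sketch of those patterns, assuming nothing about the surrounding PyTorch code (the `Widget` class and `read_value` function below are hypothetical, not part of this commit):

// Hypothetical illustration of the clang-tidy 17 cleanups applied throughout this PR.
#include <cstdint>

namespace demo::nested { // C++17 form instead of: namespace demo { namespace nested {

class Widget {
 public:
  // Initialize members via the constructor initializer list rather than
  // assignments in the body (cppcoreguidelines-prefer-member-initializer).
  explicit Widget(bool enabled) : enabled_(enabled) {}

 private:
  bool enabled_;
  // In-class default initialization instead of a
  // NOLINT(cppcoreguidelines-pro-type-member-init) suppression.
  std::int64_t count_{0};
};

inline std::int64_t read_value(bool ready) {
  // Initialize the local instead of suppressing cppcoreguidelines-init-variables.
  std::int64_t value = 0;
  if (ready) {
    value = 42;
  }
  return value;
}

} // namespace demo::nested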
@@ -144,8 +144,8 @@ class CheckSparseTensorInvariants {
bool old_state;
public:
CheckSparseTensorInvariants(bool state) {
old_state = at::globalContext().checkSparseTensorInvariants();
CheckSparseTensorInvariants(bool state)
: old_state(at::globalContext().checkSparseTensorInvariants()) {
at::globalContext().setCheckSparseTensorInvariants(state);
}

@@ -82,7 +82,7 @@ class TORCH_API ThreadLocalState {
!defined(BUILD_LITE_INTERPRETER)
// TLS for autocast dtypes
std::array<at::ScalarType, at::COMPILE_TIME_MAX_DEVICE_TYPES>
autocast_dtypes_;
autocast_dtypes_{};
#endif
friend class ThreadLocalStateGuard;

@@ -125,7 +125,7 @@ void CUDAGraph::capture_begin(MempoolId_t pool/*=0*/, cudaStreamCaptureMode capt
// due to the capture status being updated _after_ a capture had already started.
c10::cuda::CUDACachingAllocator::beginAllocateToPool(capture_dev_, mempool_id_, [this](cudaStream_t stream) {
cudaStreamCaptureStatus status;
CaptureId_t stream_capture_id;
CaptureId_t stream_capture_id = 0;
AT_CUDA_CHECK(cudaStreamGetCaptureInfo(stream, &status, &stream_capture_id));
return status == cudaStreamCaptureStatus::cudaStreamCaptureStatusActive && stream_capture_id == capture_id_;
});

@@ -362,6 +362,7 @@ static std::tuple<Tensor,Tensor,Tensor> convolution_backward_plumbing(
const Tensor& grad_output_, const Tensor& input_, const Tensor& weight_,
const c10::OptionalArrayRef<SymInt> bias_sizes_opt,
c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed,
// NOLINTNEXTLINE(performance-unnecessary-value-param)
c10::SymIntArrayRef output_padding, c10::SymInt groups, std::array<bool, 3> output_mask) {
const auto maybe_layer = maybeCurrentDynamicLayer();
vmap_check_escaped(maybe_layer, "convolution_backward_plumbing");

@@ -8,7 +8,7 @@
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/functorch/BatchRulesHelper.h>
namespace at { namespace functorch {
namespace at::functorch {
#define OP_DECOMPOSE(op) m.impl(#op, static_cast<decltype(&ATEN_FN(op))>(native::op));
#define OP_DECOMPOSE2(op, overload) m.impl(#op"."#overload, static_cast<decltype(&ATEN_FN2(op, overload))>(native::op));

@@ -20,4 +20,4 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
OP_DECOMPOSE(_unsafe_masked_index_put_accumulate);
}
}}
}
@@ -226,7 +226,7 @@ static Tensor one_hot_decomposition_hack(const Tensor &self, int64_t num_classes
if (num_classes <= 0) {
AT_ERROR("Can not infer total number of classes from empty tensor.");
} else {
shape.push_back(num_classes);
shape.emplace_back(num_classes);
return at::empty_symint(shape, self.options());
}
}

@@ -246,7 +246,7 @@ static Tensor one_hot_decomposition_hack(const Tensor &self, int64_t num_classes
// TORCH_CHECK(num_classes > self.max().item().toLong(), "Class values must be smaller than num_classes.");
// }
shape.push_back(num_classes);
shape.emplace_back(num_classes);
Tensor ret = at::zeros_symint(shape, self.options());
return ret.scatter(-1, self.unsqueeze(-1), 1);
}

@@ -213,7 +213,7 @@ static std::tuple<Tensor,Tensor> native_dropout_batching_rule(const Tensor& tens
return std::make_tuple(output, mask);
}
static Tensor multinomial_batching_rule(const Tensor& self, const int64_t num_samples, const bool replacement, const std::optional<Generator> generator) {
static Tensor multinomial_batching_rule(const Tensor& self, const int64_t num_samples, const bool replacement, std::optional<Generator> generator) {
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchVmapMode);
auto maybe_layer = maybeCurrentDynamicLayer();
const auto cur_level = maybe_layer->layerId();

@@ -237,7 +237,7 @@ static Tensor multinomial_batching_rule(const Tensor& self, const int64_t num_sa
if (is_2D_case) {
self_value = reshape_dim_into(0, 0, self_value);
}
auto out = multinomial(self_value, num_samples, replacement, generator);
auto out = multinomial(self_value, num_samples, replacement, std::move(generator));
if (is_2D_case) {
out = reshape_dim_outof_symint(0, maybe_layer->batchSize(), out);
}

@@ -249,7 +249,7 @@ static Tensor multinomial_batching_rule(const Tensor& self, const int64_t num_sa
// Must be same randomness with unbatched input
// 1D case: S -> multinomial(S) -> S
// 2D case: MS -> multinomial(MS) -> MS
return multinomial(self_value, num_samples, replacement, generator);
return multinomial(self_value, num_samples, replacement, std::move(generator));
}
template <typename A, A a, typename C>

@@ -102,7 +102,7 @@ static Tensor moveDimToFrontAndExpand(Tensor tensor, std::optional<int64_t> dim,
} else {
tensor = tensor.unsqueeze(0);
auto expanded_sizes = tensor.sym_sizes().vec();
expanded_sizes[0] = size;
expanded_sizes[0] = std::move(size);
tensor = tensor.expand_symint(expanded_sizes);
}
return tensor;
@@ -4,7 +4,6 @@
#include <ATen/WrapDimUtils.h>
#include <ATen/functorch/TensorWrapper.h>
#include <ATen/functorch/BatchedTensorImpl.h>
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <c10/util/irange.h>
#include <ATen/NamedTensorUtils.h>

@@ -9,8 +9,6 @@
#include <ATen/NestedTensorImpl.h>
#include <c10/core/DispatchKey.h>
#include <ATen/native/nested/NestedTensorUtils.h>
#include <ATen/native/nested/NestedTensorMath.h>
#include <ATen/native/layer_norm.h>
#include <c10/core/DeviceType.h>
#include <utility>

@@ -13,7 +13,6 @@
#include <ATen/core/Tensor.h>
#include <ATen/core/grad_mode.h>
#include <ATen/native/layer_norm.h>
#include <ATen/native/nested/NestedTensorUtils.h>
namespace at::native {

@@ -13,7 +13,6 @@
#include <ATen/ops/narrow_native.h>
#endif
#include <ATen/native/NonSymbolicBC.h>
#include <ATen/native/nested/NestedTensorTransformerFunctions.h>
#include <ATen/native/nested/NestedTensorTransformerUtils.h>
#include <ATen/native/nested/NestedTensorMath.h>

@@ -114,8 +114,7 @@ class DefaultMobileCPUAllocator final : public at::Allocator {
}
auto alloc_size = PreGuardBytes + nbytes + PostGuardBytes;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
void* data;
void* data = nullptr;
auto allocator_ptr = GetThreadLocalCachingAllocator();
auto profiling_allocator_ptr = GetThreadLocalProfilingAllocator();
if (allocator_ptr != nullptr) {
@@ -88,8 +88,7 @@ static uint64_t readURandomLong() {
* a 32 bit number to 64 bit.
*/
uint64_t getNonDeterministicRandom(bool is_cuda) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
uint64_t s;
uint64_t s = 0;
if (!is_cuda) {
#ifdef _WIN32
s = (uint64_t)std::chrono::high_resolution_clock::now()

@@ -186,7 +186,6 @@ SymBool SymbolicShapeMeta::compute_is_non_overlapping_and_dense_anydim() const {
return is_contiguous() | compute_non_overlapping_and_dense();
}
// NOLINTNEXTLINE(performance-unnecessary-value-param)
void SymbolicShapeMeta::set_numel(SymInt val) const {
std::scoped_lock lock(mutables_);
if (has_numel()) {

@@ -111,7 +111,6 @@ TensorImpl::TensorImpl(
DispatchKeySet key_set,
const caffe2::TypeMeta data_type)
: storage_(std::move(storage)),
numel_(0),
data_type_(data_type),
device_opt_(storage_.device()),

@@ -123,7 +122,6 @@ TensorImpl::TensorImpl(
}
}
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
TensorImpl::TensorImpl(
DispatchKeySet key_set,
const caffe2::TypeMeta data_type,

@@ -137,7 +135,6 @@ TensorImpl::TensorImpl(
const caffe2::TypeMeta data_type,
std::optional<c10::Device> device_opt)
: storage_(std::move(storage)),
numel_(0),
data_type_(data_type),
device_opt_(device_opt) {
@@ -92,8 +92,7 @@ void* alloc_cpu(size_t nbytes) {
"alloc_cpu() seems to have been called with negative number: ",
nbytes);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
void* data;
void* data = nullptr;
#ifdef __ANDROID__
data = memalign(gAlignment, nbytes);
CAFFE_ENFORCE(

@@ -12,8 +12,7 @@ std::mutex CPUCachingAllocator::mutex_;
ska::flat_hash_map<void*, size_t> CPUCachingAllocator::allocation_map_;
inline void* CPUCachingAllocator::allocate_and_cache(const size_t bytes) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
void* ptr;
void* ptr = nullptr;
try {
ptr = c10::alloc_cpu(bytes);
} catch (c10::Error&) {

@@ -152,10 +152,8 @@ std::vector<uint64_t> formulate_greedy_allocation_plan(
create_and_sort_mem_events(allocation_sizes, allocation_lifetimes);
uint64_t max_offset{0};
for (const auto& mem_event : mem_events) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
uint64_t alloc_offset;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
uint64_t new_offset, new_size;
uint64_t alloc_offset = 0;
uint64_t new_offset = 0, new_size = 0;
if (mem_event.type == EventType::Allocate) {
auto it = free_size_to_offset.lower_bound(mem_event.size);
if (it == free_size_to_offset.end()) {

@@ -41,17 +41,15 @@ float halfbits2float(unsigned short h) {
unsigned short float2halfbits(float src) {
unsigned x = c10::detail::fp32_to_bits(src);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables,cppcoreguidelines-avoid-magic-numbers)
unsigned u = (x & 0x7fffffff), remainder, shift, lsb, lsb_s1, lsb_m1;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
unsigned sign, exponent, mantissa;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
unsigned u = (x & 0x7fffffff), shift = 0;
// Get rid of +NaN/-NaN case first.
if (u > 0x7f800000) {
return 0x7fffU;
}
sign = ((x >> 16) & 0x8000);
unsigned sign = ((x >> 16) & 0x8000);
// Get rid of +Inf/-Inf, +0/-0.
if (u > 0x477fefff) {

@@ -61,8 +59,8 @@ unsigned short float2halfbits(float src) {
return (sign | 0x0000);
}
exponent = ((u >> 23) & 0xff);
mantissa = (u & 0x7fffff);
unsigned exponent = ((u >> 23) & 0xff);
unsigned mantissa = (u & 0x7fffff);
if (exponent > 0x70) {
shift = 13;

@@ -72,12 +70,12 @@ unsigned short float2halfbits(float src) {
exponent = 0;
mantissa |= 0x800000;
}
lsb = (1 << shift);
lsb_s1 = (lsb >> 1);
lsb_m1 = (lsb - 1);
unsigned lsb = (1 << shift);
unsigned lsb_s1 = (lsb >> 1);
unsigned lsb_m1 = (lsb - 1);
// Round to nearest even.
remainder = (mantissa & lsb_m1);
unsigned remainder = (mantissa & lsb_m1);
mantissa >>= shift;
if (remainder > lsb_s1 || (remainder == lsb_s1 && (mantissa & 0x1))) {
++mantissa;
@@ -7,17 +7,14 @@
namespace {
float float_from_bytes(uint32_t sign, uint32_t exponent, uint32_t fraction) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
uint32_t bytes;
bytes = 0;
uint32_t bytes = 0;
bytes |= sign;
bytes <<= 8;
bytes |= exponent;
bytes <<= 23;
bytes |= fraction;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
float res;
float res = 0;
std::memcpy(&res, &bytes, sizeof(res));
return res;
}

@@ -160,8 +157,7 @@ TEST(BFloat16Math, NextAfterZero) {
}
float BinaryToFloat(uint32_t bytes) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
float res;
float res = 0;
std::memcpy(&res, &bytes, sizeof(res));
return res;
}

@@ -353,7 +353,6 @@ using Arg = typename invoke_traits<Func>::template arg<i>::type;
template <typename Func, size_t... Is, bool release_gil>
auto wrap_pybind_function_impl_(
// NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
Func&& f,
std::index_sequence<Is...>,
std::bool_constant<release_gil>) {

@@ -482,8 +482,7 @@ static PyObject* THPStorage_get(THPStorage* self, PyObject* index) {
return THPByteUtils_newReal(value);
/* Slice index */
} else if (PySlice_Check(index)) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
Py_ssize_t start, stop, slicelength, step;
Py_ssize_t start = 0, stop = 0, slicelength = 0, step = 0;
if (PySlice_Unpack(index, &start, &stop, &step) < 0) {
return nullptr;
}

@@ -554,8 +553,7 @@ static int THPStorage_set(THPStorage* self, PyObject* index, PyObject* value) {
storage_set(storage, nindex, rvalue);
return 0;
} else if (PySlice_Check(index)) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
Py_ssize_t start, stop, step;
Py_ssize_t start = 0, stop = 0, step = 0;
Py_ssize_t len = static_cast<Py_ssize_t>(storage.nbytes());
if (PySlice_Unpack(index, &start, &stop, &step) < 0) {
return -1;

@@ -313,7 +313,6 @@ static PyObject* THPStorage_shareCuda(PyObject* self, PyObject* noargs) {
THPObjectPtr _event_sync_required(Py_None);
Py_INCREF(Py_None);
if (storage.data()) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
auto shandle =
c10::cuda::CUDACachingAllocator::shareIpcHandle(storage.mutable_data());
_handle = PyBytes_FromStringAndSize(

@@ -470,8 +469,7 @@ static PyObject* THPStorage_newSharedCuda(PyObject* _unused, PyObject* args) {
}
auto ipc_event_handle = reinterpret_cast<const cudaIpcEventHandle_t*>(
s_ipc_event_handle.c_str());
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
cudaEvent_t event;
cudaEvent_t event = nullptr;
cudaIpcOpenEventHandle(&event, *ipc_event_handle);
C10_CUDA_CHECK(
cudaStreamWaitEvent(c10::cuda::getCurrentCUDAStream(device), event, 0));
@@ -110,7 +110,6 @@ struct TensorDataContainer {
// NOTE: For tensors with zero-size dimensions (e.g. `torch::tensor({{},
// {}})`), the innermost empty braced-init-list `{}` matches the default
// constructor of the innermost `TensorDataContainer`.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
TensorDataContainer()
: sizes_({0}),
// NOTE: In Python, the dtype of tensors with zero-size dimensions (e.g.

@@ -125,12 +124,9 @@ struct TensorDataContainer {
scalar_type_(at::k##S), \
type_(TensorDataContainerType::Scalar), \
scalar_(value) {}
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TENSOR)
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
AT_FORALL_COMPLEX_TYPES(TENSOR)
#undef TENSOR
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
TensorDataContainer(std::initializer_list<TensorDataContainer> init_list)
: sizes_(),
scalar_type_(init_list.begin()->scalar_type()),

@@ -157,7 +153,7 @@ struct TensorDataContainer {
elem.scalar_type());
}
sizes_.reserve(first_elem.sizes().size() + 1);
sizes_.push_back(init_list.size());
sizes_.push_back(static_cast<int64_t>(init_list.size()));
sizes_.insert(
sizes_.end(), first_elem.sizes().begin(), first_elem.sizes().end());
}

@@ -174,9 +170,7 @@ struct TensorDataContainer {
tensor_ = at::tensor(values, at::dtype(scalar_type_).device(at::kCPU)); \
} \
}
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TENSOR)
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
AT_FORALL_COMPLEX_TYPES(TENSOR)
#undef TENSOR

@@ -194,9 +188,7 @@ struct TensorDataContainer {
#define TENSOR(T, S) \
TensorDataContainer(const std::vector<T>& values) \
: TensorDataContainer(at::ArrayRef<T>(values)) {}
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
AT_FORALL_SCALAR_TYPES_AND2(Half, BFloat16, TENSOR)
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
AT_FORALL_COMPLEX_TYPES(TENSOR)
#undef TENSOR

@@ -328,7 +320,7 @@ struct TensorDataContainer {
" in its first dimension, but got Tensor with size ",
tensor.sizes()[0],
" in its first dimension");
size_t index = 0;
int64_t index = 0;
for (const auto& elem : init_list_) {
at::Tensor slice = tensor[index];
elem.fill_tensor(slice);
@@ -133,8 +133,7 @@ inline Tensor embedding_bag(
input_.dim());
}
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int mode_enum;
int mode_enum = 0;
if (std::holds_alternative<enumtype::kSum>(mode)) {
mode_enum = 0;
} else if (std::holds_alternative<enumtype::kMean>(mode)) {

@@ -47,8 +47,7 @@ inline Tensor kl_div(
const Tensor& target,
KLDivFuncOptions::reduction_t reduction,
bool log_target = false) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
torch::Reduction::Reduction reduction_enum;
torch::Reduction::Reduction reduction_enum{};
if (std::holds_alternative<enumtype::kMean>(reduction)) {
TORCH_WARN(

@@ -60,8 +60,7 @@ inline Tensor grid_sample(
GridSampleFuncOptions::mode_t mode,
GridSampleFuncOptions::padding_mode_t padding_mode,
std::optional<bool> align_corners) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int64_t mode_enum, padding_mode_enum;
int64_t mode_enum = 0, padding_mode_enum = 0;
if (std::holds_alternative<enumtype::kBilinear>(mode)) {
mode_enum = 0;

@@ -7,8 +7,6 @@
#include <torch/nn/pimpl.h>
#include <torch/types.h>
#include <cstdint>
namespace torch {
namespace nn {

@@ -104,11 +102,8 @@ class BatchNormImplBase : public NormImplBase<D, Derived, BatchNormOptions> {
Tensor forward(const Tensor& input) {
this->_check_input_dim(input);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
double exponential_average_factor;
if (this->options.momentum() == std::nullopt) {
exponential_average_factor = 0.0;
} else {
double exponential_average_factor = 0.0;
if (this->options.momentum().has_value()) {
exponential_average_factor = this->options.momentum().value();
}

@@ -70,10 +70,8 @@ class Embedding : public torch::nn::ModuleHolder<EmbeddingImpl> {
embeddings.dim() == 2,
"Embeddings parameter is expected to be 2-dimensional");
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int64_t rows, cols;
rows = embeddings.size(0);
cols = embeddings.size(1);
auto rows = embeddings.size(0);
auto cols = embeddings.size(1);
Embedding embedding(EmbeddingOptions(rows, cols)
._weight(embeddings)

@@ -149,10 +147,8 @@ class EmbeddingBag : public torch::nn::ModuleHolder<EmbeddingBagImpl> {
embeddings.dim() == 2,
"Embeddings parameter is expected to be 2-dimensional");
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int64_t rows, cols;
rows = embeddings.size(0);
cols = embeddings.size(1);
auto rows = embeddings.size(0);
auto cols = embeddings.size(1);
EmbeddingBag embeddingbag(
EmbeddingBagOptions(rows, cols)
@@ -9,9 +9,7 @@
#include <fstream>
#include <string>
namespace torch {
namespace data {
namespace datasets {
namespace torch::data::datasets {
namespace {
constexpr uint32_t kTrainSize = 60000;
constexpr uint32_t kTestSize = 10000;

@@ -36,18 +34,20 @@ constexpr uint32_t flip_endianness(uint32_t value) {
uint32_t read_int32(std::ifstream& stream) {
static const bool is_little_endian = check_is_little_endian();
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
uint32_t value;
uint32_t value = 0;
AT_ASSERT(stream.read(reinterpret_cast<char*>(&value), sizeof value));
return is_little_endian ? flip_endianness(value) : value;
}
uint32_t expect_int32(std::ifstream& stream, uint32_t expected) {
const auto value = read_int32(stream);
// clang-format off
TORCH_CHECK(value == expected,
"Expected to read number ", expected, " but found ", value, " instead");
// clang-format on
TORCH_CHECK(
value == expected,
"Expected to read number ",
expected,
" but found ",
value,
" instead");
return value;
}

@@ -101,14 +101,15 @@ MNIST::MNIST(const std::string& root, Mode mode)
targets_(read_targets(root, mode == Mode::kTrain)) {}
Example<> MNIST::get(size_t index) {
return {images_[index], targets_[index]};
return {
images_[static_cast<int64_t>(index)],
targets_[static_cast<int64_t>(index)]};
}
std::optional<size_t> MNIST::size() const {
return images_.size(0);
}
// NOLINTNEXTLINE(bugprone-exception-escape)
bool MNIST::is_train() const noexcept {
return images_.size(0) == kTrainSize;
}

@@ -121,6 +122,4 @@ const Tensor& MNIST::targets() const {
return targets_;
}
} // namespace datasets
} // namespace data
} // namespace torch
} // namespace torch::data::datasets
@@ -176,8 +176,7 @@ std::tuple<double, Tensor> LBFGS::_directional_evaluate(
double t,
const Tensor& d) {
_add_grad(t, d);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
double loss;
double loss = 0;
{
torch::AutoGradMode enable_grad(true);
loss = closure().item<double>();

@@ -215,12 +214,9 @@ static double _cubic_interpolate(
auto d1 = (g1 + g2) - (3 * (f1 - f2) / (x1 - x2));
auto d2_square = std::pow(d1, 2) - g1 * g2;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
double d2;
if (d2_square >= 0) {
d2 = std::sqrt(d2_square);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
double min_pos;
auto d2 = std::sqrt(d2_square);
double min_pos = 0;
if (x1 <= x2) {
min_pos = x2 - ((x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2 * d2)));
} else {

@@ -8,8 +8,7 @@
#include <torch/csrc/autograd/python_cpp_function.h>
#include <torch/csrc/autograd/python_function.h>
// NOLINTNEXTLINE(misc-unused-alias-decls)
namespace py = pybind11;
namespace pybind11 {
namespace detail {}
} // namespace pybind11
namespace pybind11::detail {} // namespace pybind11::detail

@@ -13,7 +13,6 @@
#include <optional>
#include <memory>
#include <optional>
#include <vector>
namespace torch::jit {

@@ -14,7 +14,7 @@ CUDAPluggableAllocatorDeleterContext::CUDAPluggableAllocatorDeleterContext(
size_t size,
int device,
cudaStream_t stream)
: free_fn_(free_fn),
: free_fn_(std::move(free_fn)),
data_(data),
size_(size),
device_(device),

@@ -1266,8 +1266,7 @@ static void registerCudaPluggableAllocator(PyObject* module) {
m.def(
"_tensors_data_ptrs_at_indices_equal",
[](py::list& tensors, py::list& data_ptrs, py::list& indices) {
for (size_t i = 0, end = indices.size(); i < end; ++i) {
auto index = indices[i].cast<int64_t>();
for (auto index : indices) {
auto t = tensors[index].cast<at::Tensor>();
auto data_ptr = data_ptrs[index].cast<int64_t>();
if (reinterpret_cast<int64_t>(t.data_ptr()) != data_ptr) {

@@ -1451,7 +1450,6 @@ PyObject* THCPModule_getCurrentBlasHandle_wrap(
PyObject* self,
PyObject* noargs) {
HANDLE_TH_ERRORS
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
return PyLong_FromVoidPtr(handle);
END_HANDLE_TH_ERRORS
@@ -94,7 +94,6 @@ std::vector<Tensor>& broadcast_out(
}
std::vector<Tensor> broadcast(const Tensor& tensor, IntArrayRef devices) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::vector<Tensor> diff_device_dst_tensors;
diff_device_dst_tensors.reserve(devices.size());
for (auto device : devices) {

@@ -109,7 +108,6 @@ std::vector<Tensor> broadcast(const Tensor& tensor, IntArrayRef devices) {
}
}
_broadcast_out_impl(tensor, diff_device_dst_tensors);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::vector<Tensor> dst_tensors;
dst_tensors.reserve(devices.size());
auto it = diff_device_dst_tensors.begin();

@@ -172,7 +170,6 @@ tensor_list2d broadcast_coalesced(
buffer_size = std::min(torch::cuda::nccl::get_max_count(), buffer_size);
#endif
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
tensor_list2d outputs(devices.size());
outputs[0] = tensors.vec();
for (auto& o : outputs)

@@ -239,7 +236,6 @@ std::vector<at::Tensor>& scatter_out(
"Expected at least one output tensor to scatter to");
dim = at::maybe_wrap_dim(dim, tensor);
int64_t total_size = 0;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::vector<int64_t> chunk_sizes;
chunk_sizes.reserve(out_tensors.size());
for (const auto i : c10::irange(out_tensors.size())) {

@@ -374,7 +370,6 @@ static inline at::Tensor& _gather_out_impl(
at::TensorList tensors,
at::Tensor& out_tensor,
int64_t dim) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::vector<int64_t> chunk_sizes;
chunk_sizes.reserve(tensors.size());
for (auto& tensor : tensors) {

@@ -397,7 +392,6 @@ at::Tensor& gather_out(
auto& first = tensors.front();
const auto first_size = first.sizes();
dim = at::maybe_wrap_dim(dim, first);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::vector<int64_t> expected_size(first_size.begin(), first_size.end());
for (const auto i : c10::irange(tensors.size())) {
const auto& tensor = tensors[i];

@@ -452,7 +446,6 @@ at::Tensor gather(
auto& first = tensors.front();
const auto first_size = first.sizes();
dim = at::maybe_wrap_dim(dim, first);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::vector<int64_t> expected_size(first_size.begin(), first_size.end());
auto memory_format = first.suggest_memory_format();
for (const auto i : c10::irange(tensors.size())) {
@@ -263,7 +263,7 @@ struct NcclCommList {
~NcclCommList() {
if (comms) {
for (const auto i : c10::irange(ndevices)) {
int dummy_var;
int dummy_var = 0;
if (C10_CUDA_ERROR_HANDLED(cudaGetDevice(&dummy_var)) != cudaSuccess) {
/* there are cases when this destructor is called after the
CUDA driver is already unloaded from the process.

@@ -366,7 +366,7 @@ void check_inputs(
auto dtype = inputs[0].scalar_type();
for (const auto i : c10::irange(len)) {
auto input = inputs[i];
const auto& input = inputs[i];
auto output = outputs[i];
check_tensor(

@@ -398,7 +398,7 @@ void check_inputs(
auto dtype = inputs[0].scalar_type();
for (const auto i : c10::irange(len)) {
auto input = inputs[i];
const auto& input = inputs[i];
check_tensor(
input,

@@ -421,25 +421,24 @@ void check_inputs(
} // namespace detail
AutoNcclGroup::AutoNcclGroup() {
AutoNcclGroup::AutoNcclGroup() : comm_(nullptr), comm_nonblocking_(false) {
#if defined(NCCL_MAJOR) && (NCCL_MAJOR < 2)
// nccl < 2.0 cannot be called concurrently with cudaFree
(c10::cuda::getFreeMutex())->lock();
#endif
comm_nonblocking_ = false;
comm_ = nullptr;
#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
detail::NCCL_CHECK(ncclGroupStart());
#endif
}
AutoNcclGroup::AutoNcclGroup(ncclComm_t comm, bool comm_nonblocking) {
AutoNcclGroup::AutoNcclGroup(ncclComm_t comm, bool comm_nonblocking)
: comm_(comm), comm_nonblocking_(comm_nonblocking) {
#if defined(NCCL_MAJOR) && (NCCL_MAJOR < 2)
// nccl < 2.0 cannot be called concurrently with cudaFree
(c10::cuda::getFreeMutex())->lock();
#endif
comm_ = comm;
comm_nonblocking_ = comm_nonblocking;
#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
detail::NCCL_CHECK(ncclGroupStart());
#endif

@@ -510,7 +509,7 @@ void get_unique_id(ncclUniqueId& id) {
ncclComm_t comm_init_rank(int nranks, const ncclUniqueId& comm_id, int rank) {
#ifdef USE_NCCL
using namespace torch::cuda::nccl::detail;
ncclComm_t comm;
ncclComm_t comm = nullptr;
ncclUniqueId id = comm_id;
NCCL_CHECK(ncclCommInitRank(
to_nccl_comm(&comm), nranks, *(to_nccl_unique_id(&id)), rank));

@@ -548,7 +547,7 @@ struct GetSecondArgType;
template <typename R, typename Arg0, typename Arg1, typename... Args>
struct GetSecondArgType<R(Arg0, Arg1, Args...)> {
typedef typename std::decay<Arg1>::type type;
typedef std::decay_t<Arg1> type;
};
constexpr auto count_max =
@@ -827,7 +826,7 @@ void all2all_single_equal_split(
((NCCL_MAJOR > 2) || ((NCCL_MAJOR == 2) && (NCCL_MINOR >= 7)))
using namespace torch::cuda::nccl::detail;
int numranks;
int numranks = 0;
auto type = to_nccl_data_type(input);
size_t count = input.numel() / size;
size_t rankdiff = input.nbytes() / size;

@@ -897,7 +896,7 @@ void all2all_single_unequal_split(
comm,
stream.stream()));
#else
int numranks;
int numranks = 0;
NCCL_CHECK(ncclCommCount(comm, &numranks));
NCCL_CHECK(ncclGroupStart());
for (const auto r : c10::irange(numranks)) {

@@ -1109,7 +1108,7 @@ void gather(
using namespace torch::cuda::nccl::detail;
auto comm = to_nccl_comm(_comm);
int numranks, cur_rank;
int numranks = 0, cur_rank = 0;
NCCL_CHECK(ncclCommCount(comm, &numranks));
NCCL_CHECK(ncclCommUserRank(comm, &cur_rank));

@@ -1158,7 +1157,7 @@ void scatter(
using namespace torch::cuda::nccl::detail;
auto comm = to_nccl_comm(_comm);
int numranks, cur_rank;
int numranks = 0, cur_rank = 0;
#ifndef NCCL_HAS_COMM_NONBLOCKING
NCCL_CHECK(ncclCommCount(comm, &numranks));
NCCL_CHECK(ncclCommUserRank(comm, &cur_rank));

@@ -63,10 +63,8 @@ std::unique_ptr<PropagateGradientsReq> PropagateGradientsReq::fromMessage(
bool retainGraph = tupleElements.back().toBool();
// Build AutogradMetadata.
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int64_t autogradContextId, autogradMessageId;
autogradMessageId = tupleElements[tupleElements.size() - 2].toInt();
autogradContextId = tupleElements[tupleElements.size() - 3].toInt();
int64_t autogradMessageId = tupleElements[tupleElements.size() - 2].toInt();
int64_t autogradContextId = tupleElements[tupleElements.size() - 3].toInt();
AutogradMetadata autogradMetadata(autogradContextId, autogradMessageId);

@@ -820,7 +820,7 @@ bool SocketConnectOp::tryConnect(int family) {
deadline_ = Clock::now() + opts_->connect_timeout();
bool retry; // NOLINT(cppcoreguidelines-init-variables)
bool retry = false;
do {
retry = false;

@@ -45,8 +45,7 @@ static std::vector<std::string> splitString(
const std::string& delim) {
std::vector<std::string> tokens;
size_t start = 0;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
size_t end;
size_t end = 0;
// Iterate through each delimiter
while ((end = s.find(delim, start)) != std::string::npos) {
tokens.emplace_back(s.substr(start, end - start));
@@ -19,8 +19,7 @@ class AOTInductorModelContainer {
AOTInductorModelContainer(
size_t num_models,
const std::string& device_str,
const std::optional<std::string>& cubin_dir = std::nullopt)
: use_secondary_(false), constant_folded_(false) {
const std::optional<std::string>& cubin_dir = std::nullopt) {
constants_map_ = std::make_shared<ConstantMap>();
constants_array_ = std::make_shared<std::vector<ConstantHandle>>();

@@ -413,10 +412,10 @@ class AOTInductorModelContainer {
// If true,
// constants_map_secondary/constant_blob_secondary/constants_array_secondary
// is being used.
bool use_secondary_;
bool use_secondary_{false};
// Determine whether we have ran constant folding
bool constant_folded_;
bool constant_folded_{false};
// Holds the mapping of constants to at::Tensor.
// The underlying data of at::Tensor is in either constant_blob_ (for CUDA).

@@ -19,8 +19,6 @@
#include <utility>
#include <ATen/ScalarOps.h>
namespace torch::lazy {
namespace {

@@ -13,8 +13,7 @@
#include <optional>
#include <vector>
namespace torch {
namespace lazy {
namespace torch::lazy {
// Turn clang-format off, as we rely on the whole signature being on one line
// for codegen.
// clang-format off

@@ -120,5 +119,4 @@ TORCH_API std::vector<torch::lazy::Shape> compute_shape_diagonal_scatter(const a
TORCH_API std::vector<torch::lazy::Shape> compute_shape_slice_scatter_symint(const at::Tensor & self, const at::Tensor & src, int64_t dim, ::std::optional<c10::SymInt> start, ::std::optional<c10::SymInt> end, c10::SymInt step);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_as_strided_scatter_symint(const at::Tensor & self, const at::Tensor & src, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional<c10::SymInt> storage_offset);
// clang-format on
} // namespace lazy
} // namespace torch
} // namespace torch::lazy

@@ -18,10 +18,10 @@
#include <torch/csrc/lazy/ts_backend/ts_lowering_context.h>
#endif // FBCODE_CAFFE2 || OVRSOURCE
#include <string>
#include <utility>
#include <vector>
namespace torch {
namespace lazy {
namespace torch::lazy {
// TODO(whc) backend 'device' related APIs are not very clear, this code could
// be simplified but it should probably be done together with
@@ -190,7 +190,7 @@ void initLazyBindings(PyObject* module) {
return torch::lazy::getLTCForceFallback();
});
lazy.def("_set_force_fallback", [](std::string newval) {
torch::lazy::getLTCForceFallback() = newval;
torch::lazy::getLTCForceFallback() = std::move(newval);
});
lazy.def("_clear_ir_cache", []() { TrieCache::Get()->Clear(); });
lazy.def("_dump_ir_cache", [](std::string filename) {

@@ -337,5 +337,4 @@ void initLazyBindings(PyObject* module) {
#endif // USE_DEPLOY
}
} // namespace lazy
} // namespace torch
} // namespace torch::lazy

@@ -3,10 +3,8 @@
#include <torch/csrc/Export.h>
#include <torch/csrc/utils/pybind.h>
namespace torch {
namespace lazy {
namespace torch::lazy {
TORCH_PYTHON_API void initLazyBindings(PyObject* module);
} // namespace lazy
} // namespace torch
} // namespace torch::lazy

@@ -8,8 +8,7 @@
#include <torch/csrc/utils/python_compat.h>
#include <torch/csrc/utils/python_strings.h>
namespace torch {
namespace lazy {
namespace torch::lazy {
std::optional<SourceLocation> GetPythonFrameTop() {
if (!Py_IsInitialized()) {

@@ -51,5 +50,4 @@ std::vector<SourceLocation> GetPythonFrames() {
return frames;
}
} // namespace lazy
} // namespace torch
} // namespace torch::lazy

@@ -4,12 +4,10 @@
#include <optional>
#include <vector>
namespace torch {
namespace lazy {
namespace torch::lazy {
std::optional<SourceLocation> TORCH_PYTHON_API GetPythonFrameTop();
std::vector<SourceLocation> TORCH_PYTHON_API GetPythonFrames();
} // namespace lazy
} // namespace torch
} // namespace torch::lazy

@@ -5,8 +5,7 @@
#include <sstream>
namespace torch {
namespace lazy {
namespace torch::lazy {
DeviceData::DeviceData(std::shared_ptr<BackendData> data)
: TsNode(

@@ -26,7 +25,7 @@ const DeviceData* DeviceData::Cast(const Node* node) {
return NodeCast<DeviceData>(node);
}
NodePtr DeviceData::Create(std::shared_ptr<BackendData> data) {
NodePtr DeviceData::Create(const std::shared_ptr<BackendData>& data) {
NodePtr node = ReuseOrMakeNode<DeviceData>(data);
// ReuseOrMakeNode may return a reused node which has the same shape,
// however, we need to replace the old data_ with the new one.

@@ -38,5 +37,4 @@ NodePtr DeviceData::Create(std::shared_ptr<BackendData> data) {
return node;
}
} // namespace lazy
} // namespace torch
} // namespace torch::lazy
@@ -4,8 +4,9 @@
#include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>
#include <torch/csrc/lazy/ts_backend/ts_node.h>
namespace torch {
namespace lazy {
#include <utility>
namespace torch::lazy {
class TORCH_API DeviceData : public TsNode {
public:

@@ -18,7 +19,7 @@ class TORCH_API DeviceData : public TsNode {
// A DeviceData node can be reused if the shape matches,
// but we will substitute the actual data_ pointer under
// the hood.
bool CanBeReused(std::shared_ptr<BackendData> data) const {
bool CanBeReused(const std::shared_ptr<BackendData>& data) const {
return data_->shape() == data->shape();
}

@@ -29,14 +30,14 @@ class TORCH_API DeviceData : public TsNode {
}
void SetData(std::shared_ptr<BackendData> data) {
data_ = data;
data_ = std::move(data);
}
static const DeviceData* Cast(const Node* node);
// To reuse IR nodes, use this method to create DeviceData nodes
// instead of calling the constructor directly.
static NodePtr Create(std::shared_ptr<BackendData> data);
// instead of calling the constructor directconst ly.
static NodePtr Create(const std::shared_ptr<BackendData>& data);
TSOpVector Lower(
std::shared_ptr<torch::jit::GraphFunction> function,

@@ -46,5 +47,4 @@ class TORCH_API DeviceData : public TsNode {
std::shared_ptr<BackendData> data_;
};
} // namespace lazy
} // namespace torch
} // namespace torch::lazy

@@ -1,7 +1,6 @@
#include <torch/csrc/lazy/ts_backend/ops/generic.h>
namespace torch {
namespace lazy {
namespace torch::lazy {
Generic::Generic(
OpKind op,

@@ -32,5 +31,4 @@ Generic::Generic(OpKind op, Shape shape, size_t num_outputs, hash_t hash_seed)
: TsNode(op, std::move(shape), num_outputs, hash_seed),
hash_seed_(hash_seed) {}
} // namespace lazy
} // namespace torch
} // namespace torch::lazy

@@ -4,8 +4,7 @@
#include <torch/csrc/lazy/core/ir_builder.h>
namespace torch {
namespace lazy {
namespace torch::lazy {
// Generic IR Node implementation for nodes which can simply be described by a
// specific OpKind and a lowering function. IR nodes carrying

@@ -50,5 +49,4 @@ inline NodePtr GenericOp(
op, operands, std::move(shape), num_outputs, hash_seed);
}
} // namespace lazy
} // namespace torch
} // namespace torch::lazy

@@ -2,8 +2,7 @@
#include <torch/csrc/lazy/ts_backend/ts_node.h>
namespace torch {
namespace lazy {
namespace torch::lazy {
// This IR was copied from code-generated output, but the entire _to_copy
// operator cannot be trivially code genereated since it is only desirable to

@@ -123,5 +122,4 @@ class ToCopy : public torch::lazy::TsNode {
std::optional<at::MemoryFormat> memory_format;
};
} // namespace lazy
} // namespace torch
} // namespace torch::lazy

@@ -2,10 +2,8 @@
#include <torch/csrc/python_headers.h>
namespace torch {
namespace multiprocessing {
namespace torch::multiprocessing {
PyMethodDef* python_functions();
} // namespace multiprocessing
} // namespace torch
} // namespace torch::multiprocessing
@@ -12,8 +12,7 @@
namespace {
static inline void swapBytes16(void* ptr) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
uint16_t output;
uint16_t output = 0;
memcpy(&output, ptr, sizeof(uint16_t));
#if defined(_MSC_VER) && !defined(_DEBUG)
output = _byteswap_ushort(output);

@@ -28,8 +27,7 @@ static inline void swapBytes16(void* ptr) {
}
static inline void swapBytes32(void* ptr) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
uint32_t output;
uint32_t output = 0;
memcpy(&output, ptr, sizeof(uint32_t));
#if defined(_MSC_VER) && !defined(_DEBUG)
output = _byteswap_ulong(output);

@@ -46,8 +44,7 @@ static inline void swapBytes32(void* ptr) {
}
static inline void swapBytes64(void* ptr) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
uint64_t output;
uint64_t output = 0;
memcpy(&output, ptr, sizeof(uint64_t));
#if defined(_MSC_VER)
output = _byteswap_uint64(output);

@@ -70,8 +67,7 @@ static inline void swapBytes64(void* ptr) {
}
static inline uint16_t decodeUInt16(const uint8_t* data) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
uint16_t output;
uint16_t output = 0;
memcpy(&output, data, sizeof(uint16_t));
return output;
}

@@ -83,8 +79,7 @@ static inline uint16_t decodeUInt16ByteSwapped(const uint8_t* data) {
}
static inline uint32_t decodeUInt32(const uint8_t* data) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
uint32_t output;
uint32_t output = 0;
memcpy(&output, data, sizeof(uint32_t));
return output;
}

@@ -96,8 +91,7 @@ static inline uint32_t decodeUInt32ByteSwapped(const uint8_t* data) {
}
static inline uint64_t decodeUInt64(const uint8_t* data) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
uint64_t output;
uint64_t output = 0;
memcpy(&output, data, sizeof(uint64_t));
return output;
}

@@ -149,6 +143,7 @@ TORCH_API void THP_decodeBuffer<c10::Half, bool>(
bool do_byte_swap,
size_t len) {
for (const auto i : c10::irange(len)) {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
union {
uint16_t x;
c10::Half f;

@@ -191,6 +186,7 @@ TORCH_API void THP_decodeBuffer<float, bool>(
bool do_byte_swap,
size_t len) {
for (const auto i : c10::irange(len)) {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
union {
uint32_t x;
float f;

@@ -208,6 +204,7 @@ TORCH_API void THP_decodeBuffer<double, bool>(
bool do_byte_swap,
size_t len) {
for (const auto i : c10::irange(len)) {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
union {
uint64_t x;
double d;

@@ -225,10 +222,12 @@ TORCH_API void THP_decodeBuffer<c10::complex<float>, bool>(
bool do_byte_swap,
size_t len) {
for (const auto i : c10::irange(len)) {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
union {
uint32_t x;
float re;
};
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
union {
uint32_t y;
float im;

@@ -250,10 +249,12 @@ TORCH_API void THP_decodeBuffer<c10::complex<double>, bool>(
bool do_byte_swap,
size_t len) {
for (const auto i : c10::irange(len)) {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
union {
uint64_t x;
double re;
};
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
union {
uint64_t y;
double im;
@@ -343,7 +343,7 @@ inline bool array_has_torch_function(PyObject* const* args, Py_ssize_t nargs) {
}
PyObject* THPModule_has_torch_function(PyObject*, PyObject* arg) {
bool result; // NOLINT(cppcoreguidelines-init-variables)
bool result = false;
if (PyTuple_CheckExact(arg) || PyList_CheckExact(arg)) {
// Fast path:
// If we know that we have a tuple or list, we can skip an INCREF and

@@ -39,7 +39,7 @@ void initThroughputBenchmarkBindings(PyObject* module) {
const py::kwargs& kwargs) {
// Depending on this being ScriptModule of nn.Module we will release
// the GIL or not further down in the stack
return self.runOnce(std::move(args), kwargs);
return self.runOnce(args, kwargs);
})
.def(
"benchmark",

@@ -177,8 +177,7 @@ inline bool THPUtils_unpackNumberAsBool(PyObject* obj) {
return !(real_val == 0 && imag_val == 0);
}
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int overflow;
int overflow = 0;
long long value = PyLong_AsLongLongAndOverflow(obj, &overflow);
if (value == -1 && PyErr_Occurred()) {
throw python_error();

@@ -52,7 +52,6 @@ bool is_numpy_dlpack_deleter_bugged() {
#include <ATen/ATen.h>
#include <ATen/TensorUtils.h>
#include <memory>
#include <sstream>
#include <stdexcept>
using namespace at;

@@ -68,8 +67,7 @@ bool is_numpy_available() {
}
// Try to get exception message, print warning and return false
std::string message = "Failed to initialize NumPy";
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
PyObject *type, *value, *traceback;
PyObject *type = nullptr, *value = nullptr, *traceback = nullptr;
PyErr_Fetch(&type, &value, &traceback);
if (auto str = value ? PyObject_Str(value) : nullptr) {
if (auto enc_str = PyUnicode_AsEncodedString(str, "utf-8", "strict")) {

@@ -403,10 +401,8 @@ at::Tensor tensor_from_cuda_array_interface(PyObject* obj) {
}
// Extract the `obj.__cuda_array_interface__['typestr']` attribute
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
ScalarType dtype;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int dtype_size_in_bytes;
ScalarType dtype{};
int dtype_size_in_bytes = 0;
{
PyObject* py_typestr = nullptr;
if (PyDict_GetItemStringRef(cuda_dict, "typestr", &py_typestr) < 0) {

@@ -415,8 +411,7 @@ at::Tensor tensor_from_cuda_array_interface(PyObject* obj) {
if (py_typestr == nullptr) {
throw TypeError("attribute `typestr` must exist");
}
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
PyArray_Descr* descr;
PyArray_Descr* descr = nullptr;
TORCH_CHECK_VALUE(
PyArray_DescrConverter(py_typestr, &descr), "cannot parse `typestr`");
dtype = numpy_dtype_to_aten(descr->type_num);

@@ -429,8 +424,7 @@ at::Tensor tensor_from_cuda_array_interface(PyObject* obj) {
}
// Extract the `obj.__cuda_array_interface__['data']` attribute
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
void* data_ptr;
void* data_ptr = nullptr;
{
PyObject* py_data = nullptr;
if (PyDict_GetItemStringRef(cuda_dict, "data", &py_data) < 0) {