More fixes and improved clang-tidy checkers (#93213)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93213
Approved by: https://github.com/Skylion007
This commit is contained in:
cyy
2023-02-01 14:44:13 +00:00
committed by PyTorch MergeBot
parent 679e869af0
commit 37f7c00a8a
37 changed files with 91 additions and 100 deletions

View File

@ -3,11 +3,14 @@
InheritParentConfig: true
Checks: '
bugprone-*,
-bugprone-easily-swappable-parameters,
-bugprone-forward-declaration-namespace,
-bugprone-macro-parentheses,
-bugprone-lambda-function-name,
-bugprone-reserved-identifier,
-bugprone-swapped-arguments,
cppcoreguidelines-*,
-cppcoreguidelines-avoid-do-while,
-cppcoreguidelines-avoid-magic-numbers,
-cppcoreguidelines-avoid-non-const-global-variables,
-cppcoreguidelines-interfaces-global-init,
@ -30,6 +33,7 @@ misc-unused-alias-decls,
misc-unused-using-decls,
modernize-*,
-modernize-concat-nested-namespaces,
-modernize-macro-to-enum,
-modernize-return-braced-init-list,
-modernize-use-auto,
-modernize-use-default-member-init,
@ -44,5 +48,4 @@ readability-container-size-empty,
HeaderFilterRegex: '^(c10/(?!test)|torch/csrc/(?!deploy/interpreter/cpython)).*$'
AnalyzeTemporaryDtors: false
WarningsAsErrors: '*'
CheckOptions:
...

View File

@ -858,7 +858,7 @@ auto TensorBase::register_hook(T&& hook) const -> TensorBase::hook_return_void_t
template <typename T>
auto TensorBase::register_hook(T&& hook) const -> TensorBase::hook_return_var_t<T> {
return _register_hook(std::move(hook));
return _register_hook(std::forward<T>(hook));
}
namespace detail {

View File

@ -64,7 +64,7 @@ const AnnotatedKernel& OperatorEntry::ambiguousAutogradOtherKernel() const {
return kernel;
}
void OperatorEntry::assertSignatureIsCorrect(const CppSignature call_signature, bool has_symint) const {
void OperatorEntry::assertSignatureIsCorrect(const CppSignature& call_signature, bool has_symint) const {
if (has_symint) {
if (C10_UNLIKELY(sym_cpp_signature_.has_value() && (call_signature != sym_cpp_signature_->signature))) {
reportSignatureError(call_signature, *sym_cpp_signature_);

View File

@ -167,7 +167,7 @@ public:
assertSignatureIsCorrect(CppSignature::make<FuncType>(), fn_has_symint<FuncType>::value);
}
void assertSignatureIsCorrect(const CppSignature call_signature, bool has_symint) const;
void assertSignatureIsCorrect(const CppSignature& call_signature, bool has_symint) const;
[[noreturn]] void reportError(DispatchKey dispatchKey) const;

View File

@ -80,7 +80,7 @@ struct StreamData3Holder : c10::intrusive_ptr_target {
StreamData3Holder(struct c10::StreamData3 d) {
val = d;
}
StreamData3Holder() = default;
StreamData3Holder() = delete;
struct c10::StreamData3 val;
};
@ -1261,12 +1261,12 @@ public:
friend MaybeOwnedTraits<IValue>;
Payload payload;
Tag tag;
Tag tag{IValue::Tag::None};
friend struct WeakIValue;
};
struct TORCH_API WeakIValue final {
WeakIValue() : tag(IValue::Tag::None), is_intrusive_ptr(false) {}
WeakIValue() = default;
WeakIValue(const WeakIValue& rhs)
: payload(rhs.payload),
@ -1378,8 +1378,8 @@ struct TORCH_API WeakIValue final {
private:
using Payload = IValue::Payload::TriviallyCopyablePayload;
Payload payload;
IValue::Tag tag;
bool is_intrusive_ptr;
IValue::Tag tag{IValue::Tag::None};
bool is_intrusive_ptr{false};
};
// An owning pointer to a type. When the type is class type, it requires a pair

View File

@ -1001,8 +1001,8 @@ struct TORCH_API DictType : public SharedType {
std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
std::stringstream ss;
ss << "Dict[" << getKeyType()->annotation_str(printer) << ", "
<< getValueType()->annotation_str(std::move(printer)) << "]";
ss << "Dict[" << getKeyType()->annotation_str(printer) << ", ";
ss << getValueType()->annotation_str(std::move(printer)) << "]";
return ss.str();
}

View File

@ -350,14 +350,12 @@ namespace {
// /aten/src/ATen/native/TensorAdvancedIndexing.cpp#L294-L312
VmapDimVector compute_indexed_shape(const Tensor &src, TensorList indices_list)
{
int64_t dims_before = 0, dims_after = 0, dims_indexed = 0;
int64_t dims_before = 0, dims_indexed = 0;
IntArrayRef replacement_shape;
for (const auto dim : c10::irange(indices_list.size())) {
if (!indices_list[dim].defined()) {
if (dims_indexed == 0) {
dims_before++;
} else {
dims_after++;
}
} else {
dims_indexed++;

View File

@ -153,7 +153,7 @@ class CacheEntry {
// Includes sampling callbacks which are waiting to run.
c10::SmallVector<CallbackAndCounter, kSoftLimitCallbacks> callbacks_;
RecordScope scope_;
RecordScope scope_{RecordScope::FUNCTION};
StepCallbacks active_callbacks_;

View File

@ -207,7 +207,11 @@ void ProfiledCPUMemoryReporter::New(void* ptr, size_t nbytes) {
}
if (profile_memory) {
reportMemoryUsageToProfiler(
ptr, nbytes, allocated, 0, c10::Device(c10::DeviceType::CPU));
ptr,
static_cast<int64_t>(nbytes),
static_cast<int64_t>(allocated),
0,
c10::Device(c10::DeviceType::CPU));
}
}
@ -242,7 +246,11 @@ void ProfiledCPUMemoryReporter::Delete(void* ptr) {
}
if (profile_memory) {
reportMemoryUsageToProfiler(
ptr, -nbytes, allocated, 0, c10::Device(c10::DeviceType::CPU));
ptr,
-static_cast<int64_t>(nbytes),
static_cast<int64_t>(allocated),
0,
c10::Device(c10::DeviceType::CPU));
}
}

View File

@ -130,7 +130,7 @@ Device::Device(const std::string& device_string) : Device(Type::CPU) {
try {
if (!device_index_str.empty()) {
index_ = c10::stoi(device_index_str);
index_ = static_cast<c10::DeviceIndex>(c10::stoi(device_index_str));
}
} catch (const std::exception&) {
TORCH_CHECK(

View File

@ -104,6 +104,7 @@ TensorImpl::TensorImpl(
// the Python and PythonTLSSnapshot dispatch keys will be set and all is well.
// The point is to delay the dispatch key setting until that point.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
TensorImpl::TensorImpl(
ImplType type,
Storage&& storage,
@ -122,12 +123,14 @@ TensorImpl::TensorImpl(
}
}
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
TensorImpl::TensorImpl(
DispatchKeySet key_set,
const caffe2::TypeMeta data_type,
c10::optional<c10::Device> device_opt)
: TensorImpl({}, key_set, data_type, device_opt) {}
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
TensorImpl::TensorImpl(
Storage&& storage,
DispatchKeySet key_set,
@ -864,7 +867,8 @@ void TensorImpl::Extend(int64_t num, float growthPct) {
newCapacity[0] = std::max(
newDims[0],
static_cast<int64_t>(std::ceil(
sizes_and_strides_.size_at_unchecked(0) * (1 + growthPct / 100))));
static_cast<float>(sizes_and_strides_.size_at_unchecked(0)) *
(1 + growthPct / 100))));
auto oldData = std::move(storage_.data_ptr());
auto oldSize = numel_;
Resize(std::move(newCapacity));

View File

@ -26,6 +26,7 @@ PyInterpreter* PyObjectSlot::pyobj_interpreter() {
}
PyObject* PyObjectSlot::_unchecked_untagged_pyobj() const {
// NOLINTNEXTLINE(performance-no-int-to-ptr)
return reinterpret_cast<PyObject*>(
reinterpret_cast<uintptr_t>(pyobj_) & ~0x1ULL);
}
@ -47,10 +48,12 @@ PyInterpreter& PyObjectSlot::load_pyobj_interpreter() const {
}
bool PyObjectSlot::owns_pyobj() {
// NOLINTNEXTLINE(performance-no-int-to-ptr)
return reinterpret_cast<uintptr_t>(pyobj_) & 1;
}
void PyObjectSlot::set_owns_pyobj(bool b) {
// NOLINTNEXTLINE(performance-no-int-to-ptr)
pyobj_ = reinterpret_cast<PyObject*>(
reinterpret_cast<uintptr_t>(_unchecked_untagged_pyobj()) | b);
}

View File

@ -41,7 +41,7 @@ const std::shared_ptr<SafePyObject>& TorchDispatchModeTLS::get_stack_at(
}
int64_t TorchDispatchModeTLS::stack_len() {
return torchDispatchModeState.stack_.size();
return static_cast<int64_t>(torchDispatchModeState.stack_.size());
}
const TorchDispatchModeTLS& TorchDispatchModeTLS::get_state() {

View File

@ -30,8 +30,8 @@ void memset_junk(void* data, size_t num) {
static constexpr int32_t kJunkPattern = 0x7fedbeef;
static constexpr int64_t kJunkPattern64 =
static_cast<int64_t>(kJunkPattern) << 32 | kJunkPattern;
int32_t int64_count = num / sizeof(kJunkPattern64);
int32_t remaining_bytes = num % sizeof(kJunkPattern64);
auto int64_count = num / sizeof(kJunkPattern64);
auto remaining_bytes = num % sizeof(kJunkPattern64);
int64_t* data_i64 = reinterpret_cast<int64_t*>(data);
for (const auto i : c10::irange(int64_count)) {
data_i64[i] = kJunkPattern64;

View File

@ -434,8 +434,7 @@ __device__ __attribute__((noinline)) __attribute__((weak)) void __assert_fail(
// Warning: __has_trivial_copy for GCC may not always detect the non-POD
// correctly. For example, T = std::unique_ptr may evaluate to true and be
// treated as POD. This can cause unexpected behavior.
#if defined(__GNUG__) && __GNUC__ < 5 && \
!(defined(__clang__) && defined(_LIBCPP_VERSION))
#if defined(__GNUG__) && __GNUC__ < 5 && !defined(__clang__)
#define C10_IS_TRIVIALLY_COPYABLE(T) __has_trivial_copy(T)
#else
#define C10_IS_TRIVIALLY_COPYABLE(T) std::is_trivially_copyable<T>::value

View File

@ -501,13 +501,8 @@ class arrayref_optional_base {
: storage_(v) {}
constexpr bool initialized() const noexcept {
typename storage::raw repr;
// Cast to void* to suppress GCC's -Wclass-memaccess.
memcpy(
static_cast<void*>(&repr),
static_cast<const void*>(&storage_),
sizeof(storage_));
return repr.p != nullptr || repr.sz == 0;
return storage_.uninitialized_.p != nullptr ||
storage_.uninitialized_.sz == 0;
}
void setInitialized(bool init) noexcept {

View File

@ -166,7 +166,7 @@ struct Dim : public py::base<Dim> {
return batchtensor_;
}
private:
int64_t size_;
int64_t size_{-1};
at::Tensor range_;
at::Tensor batchtensor_;
};

View File

@ -91,10 +91,10 @@ static PyObject* THPStorage_shareFilename(PyObject* _self, PyObject* noargs) {
"_share_filename_: only available on CPU");
auto self = (THPStorage*)_self;
c10::StorageImpl* storage = self->cdata;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
THManagedMapAllocator* ctx;
THManagedMapAllocator* ctx =
THManagedMapAllocator::fromDataPtr(storage->data_ptr());
// Storage is already in shared memory, just return a handle
if ((ctx = THManagedMapAllocator::fromDataPtr(storage->data_ptr()))) {
if (ctx) {
// done
} else {
// TODO: retry on collision

View File

@ -146,7 +146,7 @@ struct TORCH_API AutogradContext {
// weak_ptr to avoid a refcycle. Since grad_fn_ owns this AutogradContext, it
// will always be alive when we want to use it.
std::weak_ptr<Node> grad_fn_;
bool has_freed_buffers_;
bool has_freed_buffers_{false};
void save_variables();

View File

@ -28,18 +28,11 @@
#include <utility>
#include <vector>
using at::ArrayRef;
using at::Backend;
using at::Device;
using at::DeviceGuard;
using at::Dimname;
using at::DimnameList;
using at::Generator;
using at::IntArrayRef;
using at::Layout;
using at::OptionalDeviceGuard;
using at::Scalar;
using at::ScalarType;
using at::Tensor;
using at::TensorList;
using at::TensorOptions;

View File

@ -97,8 +97,8 @@ struct TORCH_API TensorPipeRpcBackendOptions : public RpcBackendOptions {
"num_worker_threads must be positive, got ",
numWorkerThreads);
if (transports.has_value()) {
for (const std::string& transportName : transports.value()) {
if (this->transports.has_value()) {
for (const std::string& transportName : this->transports.value()) {
TORCH_CHECK(
TensorPipeTransportRegistry()->Has(transportName),
"Unknown transport: ",
@ -106,8 +106,8 @@ struct TORCH_API TensorPipeRpcBackendOptions : public RpcBackendOptions {
}
}
if (channels.has_value()) {
for (const std::string& channelName : channels.value()) {
if (this->channels.has_value()) {
for (const std::string& channelName : this->channels.value()) {
TORCH_CHECK(
TensorPipeChannelRegistry()->Has(channelName),
"Unknown channel: ",

View File

@ -26,7 +26,6 @@ using c10::ListType;
using c10::MemoryFormatType;
using c10::NoneType;
using c10::NumberType;
using c10::OptionalType;
using c10::QSchemeType;
using c10::QuantizerType;
using c10::RRefType;

View File

@ -658,15 +658,15 @@ struct TORCH_API RangeValue : SugaredValue {
}
private:
Value* start_;
Value* end_;
Value* step_;
Value* start_{};
Value* end_{};
Value* step_{};
// a flag to determine if it's a simple range() call with only end_ from
// arguments If true, we will not insert length calculation and index
// derivation nodes to simplify the graph and enable more possible
// optimizations
bool has_only_end_;
c10::optional<int64_t> static_len_ = c10::nullopt;
bool has_only_end_{};
c10::optional<int64_t> static_len_;
};
// Specialized Tree structure to matched against for special handling

View File

@ -179,9 +179,7 @@ inline void warn(const char* _reason, const char* _kind = nullptr) {
TORCH_API void setWarn(warn_fn_type fn);
struct TORCH_API NoWarn {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
NoWarn() : state(getTracingState()) {
// NOLINTNEXTLINE(*.cplusplus.UninitializedObject)
if (state) {
prev = state->warn;
state->warn = false;
@ -193,7 +191,7 @@ struct TORCH_API NoWarn {
}
}
std::shared_ptr<TracingState> state;
bool prev;
bool prev{false};
};
struct WithNestedTracingFrame {

View File

@ -10,11 +10,8 @@
namespace torch {
namespace jit {
using caffe2::serialize::FileAdapter;
using caffe2::serialize::IStreamAdapter;
using caffe2::serialize::PyTorchStreamReader;
using caffe2::serialize::PyTorchStreamWriter;
using caffe2::serialize::ReadAdapterInterface;
const static BackportManager backportManager;

View File

@ -15,11 +15,9 @@
namespace torch {
namespace jit {
using caffe2::serialize::FileAdapter;
using caffe2::serialize::IStreamAdapter;
using caffe2::serialize::PyTorchStreamReader;
using caffe2::serialize::PyTorchStreamWriter;
using caffe2::serialize::ReadAdapterInterface;
// Current support bytecode version
namespace {

View File

@ -16,7 +16,6 @@
#include <stack>
using ::c10::Dispatcher;
using ::c10::DispatchKey;
namespace torch {
namespace jit {
namespace onnx {

View File

@ -101,7 +101,8 @@ struct ArgumentSpec {
// https://github.com/zdevito/pytorch/commit/21e7200a0a0fc456bea2f10e95b1781f83933d10
// show overhead in extra refcounting along this path
const at::Tensor* t = reinterpret_cast<const at::Tensor*>(&input);
if ((arg.defined_ = t->defined())) {
arg.defined_ = t->defined();
if (arg.defined_) {
arg.requires_grad_ = with_grad && autograd::Variable(*t).requires_grad();
arg.dim_ = t->dim();
// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)

View File

@ -13,8 +13,6 @@
#include <fmt/format.h>
#include <stdexcept>
using at::Scalar;
using at::Tensor;
namespace dist_autograd = torch::distributed::autograd;
namespace dist_rpc = torch::distributed::rpc;

View File

@ -549,7 +549,7 @@ static void append_overloaded_arg(
}
}
if (class_not_seen_yet) {
int arg_index = overloaded_args->size();
auto arg_index = overloaded_args->size();
for (const auto j : c10::irange(arg_index)) {
if (PyObject_IsSubclass(
obj_type,
@ -565,7 +565,8 @@ static void append_overloaded_arg(
// add object to overloaded_args. If it's a subclass of another class
// we've already seen it will be inserted before the superclass,
// otherwise it will be inserted at the end of the array
overloaded_args->insert(overloaded_args->begin() + arg_index, obj);
overloaded_args->insert(
overloaded_args->begin() + static_cast<long>(arg_index), obj);
}
}
@ -1204,19 +1205,19 @@ std::string FunctionSignature::toString() const {
[[noreturn]] static void extra_args(
const FunctionSignature& signature,
Py_ssize_t nargs) {
const long max_pos_args = signature.max_pos_args;
const long min_args = signature.min_args;
const auto max_pos_args = signature.max_pos_args;
const auto min_args = signature.min_args;
const long nargs_ = nargs;
if (min_args != max_pos_args) {
throw TypeError(
"%s() takes from %ld to %ld positional arguments but %ld were given",
"%s() takes from %zu to %zu positional arguments but %ld were given",
signature.name.c_str(),
min_args,
max_pos_args,
nargs_);
}
throw TypeError(
"%s() takes %ld positional argument%s but %ld %s given",
"%s() takes %zu positional argument%s but %ld %s given",
signature.name.c_str(),
max_pos_args,
max_pos_args == 1 ? "" : "s",
@ -1302,7 +1303,7 @@ bool FunctionSignature::parse(
PyObject* kwargs,
PyObject* dst[], // NOLINT
bool raise_exception) {
size_t nargs = args ? PyTuple_GET_SIZE(args) : 0;
Py_ssize_t nargs = args ? PyTuple_GET_SIZE(args) : 0;
auto remaining_kwargs = kwargs ? PyDict_Size(kwargs) : 0;
size_t arg_pos = 0;
bool allow_varargs_intlist = false;
@ -1320,7 +1321,7 @@ bool FunctionSignature::parse(
}
}
if (nargs > max_pos_args && !allow_varargs_intlist) {
if (static_cast<size_t>(nargs) > max_pos_args && !allow_varargs_intlist) {
if (raise_exception) {
// foo() takes takes 2 positional arguments but 3 were given
extra_args(*this, nargs);
@ -1339,7 +1340,7 @@ bool FunctionSignature::parse(
for (auto& param : params) {
PyObject* obj = nullptr;
bool is_kwd = false;
if (arg_pos < nargs) {
if (arg_pos < static_cast<size_t>(nargs)) {
// extra positional args given after single positional IntArrayRef arg
if (param.keyword_only) {
if (raise_exception) {

View File

@ -35,7 +35,7 @@ static void recursive_apply(
int64_t dim,
PyObject* fn,
std::array<StridedData, N> strided_data) {
int64_t ndim = sizes.size();
int64_t ndim = static_cast<int64_t>(sizes.size());
if (dim == ndim) {
auto args = THPObjectPtr(PyTuple_New(N));
if (!args)

View File

@ -29,7 +29,7 @@ std::vector<TensorGroup> take_tensors(
tensor_size = tensor.numel() * tensor.element_size();
}
auto& type_group = groups[type_id(tensor)];
auto& type_group = groups[static_cast<int64_t>(type_id(tensor))];
type_group.tensors.push_back(tensor);
if (fine_grained) {

View File

@ -17,8 +17,8 @@ static PyObject* recursive_to_list(
IntArrayRef strides,
int64_t dim,
ScalarType scalarType,
int64_t elementSize) {
int64_t ndim = sizes.size();
size_t elementSize) {
int64_t ndim = static_cast<int64_t>(sizes.size());
if (dim == ndim) {
return torch::utils::load_scalar(data, scalarType);
}

View File

@ -33,19 +33,14 @@
#include <stdexcept>
#include <vector>
using at::Backend;
using at::Device;
using at::IntArrayRef;
using at::kCPU;
using at::kCUDA;
using at::kInt;
using at::kLong;
using at::Scalar;
using at::ScalarType;
using at::Storage;
using at::Tensor;
using at::TensorOptions;
using at::Type;
using c10::optional;
namespace torch {
@ -64,7 +59,7 @@ TensorOptions build_options(
return options;
}
void maybe_initialize_cuda(const Device device) {
void maybe_initialize_cuda(const Device& device) {
if (device.is_cuda()) {
torch::utils::cuda_lazy_init();
}
@ -103,7 +98,7 @@ std::vector<int64_t> compute_sizes(PyObject* seq, ScalarType scalar_type) {
if (length < 0)
throw python_error();
if (is_storage) {
length /= elementSize(scalar_type);
length /= static_cast<int64_t>(elementSize(scalar_type));
}
sizes.push_back(length);
if (sizes.size() > MAX_DIMS) {
@ -205,11 +200,11 @@ void recursive_store(
IntArrayRef strides,
int64_t dim,
ScalarType scalarType,
int elementSize,
size_t elementSize,
PyObject* obj) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(data != nullptr);
int64_t ndim = sizes.size();
int64_t ndim = static_cast<int64_t>(sizes.size());
bool is_symfloat = torch::is_symfloat(obj);
bool is_symint = torch::is_symint(obj);
if (dim == ndim) {
@ -374,7 +369,7 @@ Tensor internal_new_from_data(
at::tracer::impl::NoTracerDispatchMode tracer_guard;
if (isStorage(data)) {
ScalarType storage_scalar_type;
ScalarType storage_scalar_type{ScalarType::Undefined};
bool is_typed_storage = false;
Storage storage =
createStorageGetType(data, storage_scalar_type, is_typed_storage);
@ -562,6 +557,7 @@ Tensor legacy_sparse_tensor_generic_ctor_new(
check_legacy_ctor_device(dispatch_key, deviceOptional);
return at::empty({0}, build_options(options, scalar_type, deviceOptional));
} else if (r.idx == 1) {
// NOLINTNEXTLINE(performance-no-int-to-ptr)
auto cdata = reinterpret_cast<void*>(r.toInt64(0));
return at::unsafeTensorFromTH(cdata, true);
} else if (r.idx == 2) {
@ -608,9 +604,9 @@ c10::TensorOptions typeIdWithDefault(
int64_t device_idx,
c10::DispatchKey dispatch_key) {
auto options = dispatchKeyToTensorOptions(dispatch_key);
if (!r.isNone(device_idx)) {
if (!r.isNone(static_cast<int>(device_idx))) {
// TODO: This line doesn't seem to be exercised at all in tests
options = options.device(r.device(device_idx).type());
options = options.device(r.device(static_cast<int>(device_idx)).type());
}
return options;
}
@ -655,7 +651,7 @@ Tensor legacy_tensor_generic_ctor_new(
at::OptionalDeviceGuard device_guard(deviceOptional);
return at::empty({0}, build_options(options, scalar_type));
} else if (r.idx == 1) {
at::ScalarType storage_scalar_type;
at::ScalarType storage_scalar_type{at::ScalarType::Undefined};
bool is_typed_storage = false;
at::Storage storage = r.storage(0, storage_scalar_type, is_typed_storage);
if (storage_scalar_type != at::ScalarType::Undefined && is_typed_storage) {
@ -669,6 +665,7 @@ Tensor legacy_tensor_generic_ctor_new(
}
return new_with_storage(options, scalar_type, storage);
} else if (r.idx == 2) {
// NOLINTNEXTLINE(performance-no-int-to-ptr)
auto cdata = reinterpret_cast<void*>(r.toInt64(0));
return at::unsafeTensorFromTH(cdata, true);
} else if (r.idx == 3) {
@ -786,9 +783,8 @@ Tensor indexing_tensor_from_data(
class CheckSparseTensorInvariantsContext {
public:
CheckSparseTensorInvariantsContext() {
state = at::globalContext().checkSparseTensorInvariants();
}
CheckSparseTensorInvariantsContext()
: state{at::globalContext().checkSparseTensorInvariants()} {}
~CheckSparseTensorInvariantsContext() {
at::globalContext().setCheckSparseTensorInvariants(state);
}

View File

@ -175,7 +175,7 @@ PyObject* tensor_to_numpy(const at::Tensor& tensor, bool force /*=false*/) {
auto array = THPObjectPtr(PyArray_New(
&PyArray_Type,
prepared_tensor.dim(),
static_cast<int>(prepared_tensor.dim()),
sizes.data(),
dtype,
strides.data(),
@ -382,6 +382,7 @@ at::Tensor tensor_from_cuda_array_interface(PyObject* obj) {
}
// Extract the `obj.__cuda_array_interface__['typestr']` attribute
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
ScalarType dtype;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int dtype_size_in_bytes;

View File

@ -197,8 +197,8 @@ class class_ : public ::torch::detail::class_base {
GetterFunc getter_func,
SetterFunc setter_func,
std::string doc_string = "") {
torch::jit::Function* getter;
torch::jit::Function* setter;
torch::jit::Function* getter{};
torch::jit::Function* setter{};
auto wrapped_getter =
detail::wrap_func<CurClass, GetterFunc>(std::move(getter_func));
@ -218,7 +218,7 @@ class class_ : public ::torch::detail::class_base {
const std::string& name,
GetterFunc getter_func,
std::string doc_string = "") {
torch::jit::Function* getter;
torch::jit::Function* getter{};
auto wrapped_getter =
detail::wrap_func<CurClass, GetterFunc>(std::move(getter_func));
@ -321,7 +321,7 @@ class class_ : public ::torch::detail::class_base {
c10::guts::infer_function_traits_t<std::decay_t<SetStateFn>>;
using SetStateArg = typename c10::guts::typelist::head_t<
typename SetStateTraits::parameter_types>;
auto setstate_wrapper = [set_state = std::move(set_state)](
auto setstate_wrapper = [set_state = std::forward<SetStateFn>(set_state)](
c10::tagged_capsule<CurClass> self,
SetStateArg&& arg) {
c10::intrusive_ptr<CurClass> classObj =

View File

@ -27,7 +27,7 @@ class THManagedMapAllocator : private THManagedMapAllocatorInit,
void close() override;
~THManagedMapAllocator() {
~THManagedMapAllocator() override {
close();
}