[14/N] Fix extra warnings brought by clang-tidy-17 (#141644)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141644
Approved by: https://github.com/ezyang

Author: cyy
Date: 2024-12-11 18:40:38 +00:00
Committed by: PyTorch MergeBot
Parent: be27dbf2b8
Commit: 24a5a2ef25
24 changed files with 89 additions and 65 deletions
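
The changes below fall into two buckets: code that is adjusted so the new clang-tidy-17 checks pass, and NOLINTNEXTLINE comments that silence a specific check where the existing code is intentional. A minimal stand-alone sketch of the suppression form (hypothetical helper, not code from this PR):

#include <string>

// NOLINTNEXTLINE applies only to the line that follows it and must name the
// exact check being silenced.
static bool parse_flag(const std::string& s) noexcept {
  try {
    return std::stoi(s) != 0;
    // NOLINTNEXTLINE(bugprone-empty-catch)
  } catch (...) { /* malformed input simply means "false" */
  }
  return false;
}

int main() {
  return parse_flag("1") ? 0 : 1;
}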

View File

@@ -92,8 +92,8 @@ class MatrixRef {
/// The declaration here is extra complicated so that "arrayRef = {}"
/// continues to select the move assignment operator.
template <typename U>
// NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
std::enable_if_t<std::is_same_v<U, T>, MatrixRef<T>>& operator=(
// NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
U&& Temporary) = delete;
/// Disallow accidental assignment from a temporary.
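
cppcoreguidelines-missing-std-forward fires when a forwarding-reference parameter is never passed to std::forward; the suppression moves next to the parameter here, presumably because clang-tidy-17 reports the diagnostic at that location. A condensed sketch of the pattern (hypothetical Ref type, not the real MatrixRef):

#include <type_traits>

template <typename T>
class Ref {
 public:
  Ref() = default;
  // Deleted so that assigning from a temporary fails to compile. A deleted
  // overload has no body, so there is nothing to std::forward; the finding is
  // suppressed rather than "fixed".
  template <typename U>
  std::enable_if_t<std::is_same_v<U, T>, Ref<T>>& operator=(
      // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
      U&& temporary) = delete;
};

int main() {
  Ref<int> a;  // assigning from a Ref<int> temporary would not compile
  (void)a;
  return 0;
}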

View File

@@ -106,6 +106,7 @@ static hipblasStatus_t rocBLASStatusToHIPStatus(rocblas_status error)
namespace {
static cublasOperation_t _cublasOpFromChar(char op) {
// NOLINTNEXTLINE(bugprone-switch-missing-default-case)
switch (op) {
case 'n':
case 'N':

View File

@@ -8,17 +8,16 @@
#include <cmath>
#include <cstdlib>
#include <functional>
#include <iostream>
#include <memory>
#include <random>
#include <utility>
#include <vector>
using namespace torch::nn;
using namespace torch::optim;
template <typename OptimizerClass, typename Options>
bool test_optimizer_xor(Options options) {
static bool test_optimizer_xor(Options options) {
torch::manual_seed(0);
Sequential model(
@@ -30,9 +29,9 @@ bool test_optimizer_xor(Options options) {
const int64_t kBatchSize = 200;
const int64_t kMaximumNumberOfEpochs = 3000;
OptimizerClass optimizer(model->parameters(), options);
OptimizerClass optimizer(model->parameters(), std::move(options));
float running_loss = 1;
double running_loss = 1;
int epoch = 0;
while (running_loss > 0.1) {
auto inputs = torch::empty({kBatchSize, 2});
@@ -46,8 +45,8 @@ bool test_optimizer_xor(Options options) {
auto step = [&](OptimizerClass& optimizer,
Sequential model,
torch::Tensor inputs,
torch::Tensor labels) {
const torch::Tensor& inputs,
const torch::Tensor& labels) {
auto closure = [&]() {
optimizer.zero_grad();
auto x = model->forward(inputs);
@@ -60,11 +59,10 @@ bool test_optimizer_xor(Options options) {
torch::Tensor loss = step(optimizer, model, inputs, labels);
// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,bugprone-narrowing-conversions)
running_loss = running_loss * 0.99 + loss.item<float>() * 0.01;
running_loss = running_loss * 0.99 + loss.item<double>() * 0.01;
if (epoch > kMaximumNumberOfEpochs) {
std::cout << "Loss is too high after epoch " << epoch << ": "
<< running_loss << std::endl;
<< running_loss << '\n';
return false;
}
epoch++;
@@ -73,10 +71,10 @@ bool test_optimizer_xor(Options options) {
}
template <typename Parameters>
void assign_parameter(
static void assign_parameter(
const Parameters& parameters,
const char* name,
torch::Tensor new_tensor) {
const torch::Tensor& new_tensor) {
auto parameter = parameters[name];
parameter.set_requires_grad(false);
parameter.flatten().copy_(new_tensor);
@@ -84,7 +82,7 @@ void assign_parameter(
}
template <typename OptimizerClass, typename Options>
void check_exact_values(
static void check_exact_values(
Options options,
std::vector<std::vector<torch::Tensor>> expected_parameters) {
const size_t kIterations = 1001;
@@ -119,7 +117,7 @@ void check_exact_values(
assign_parameter(
parameters, "2.bias", torch::tensor({-0.0711}, torch::kFloat64));
auto optimizer = OptimizerClass(parameters.values(), options);
auto optimizer = OptimizerClass(parameters.values(), std::move(options));
torch::Tensor input =
torch::tensor({0.1, 0.2, 0.3, 0.4, 0.5, 0.6}, torch::kFloat64)
.reshape({3, 2});
@@ -145,8 +143,7 @@ void check_exact_values(
expected_parameters.at(i / kSampleEvery).at(p).to(torch::kFloat64);
if (!computed.allclose(expected, /*rtol=*/1e-3, /*atol=*/5e-4)) {
std::cout << "Iteration " << i << ": " << computed
<< " != " << expected << " (parameter " << p << ")"
<< std::endl;
<< " != " << expected << " (parameter " << p << ")" << '\n';
ASSERT_TRUE(false);
}
}
@@ -166,8 +163,7 @@ TEST(OptimTest, OptimizerAccessors) {
ASSERT_TRUE(options == options_);
// test for param_groups() with non-const reference return
auto& params_groups = optimizer.param_groups();
// NOLINTNEXTLINE(modernize-use-emplace)
params_groups.push_back(OptimizerParamGroup(params));
params_groups.emplace_back(params);
auto& params_1 = params_groups[1].params();
for (const auto i : c10::irange(params_1.size())) {
torch::equal(params[i], params_1[i]);
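
modernize-use-emplace is behind the change above: emplace_back forwards its arguments to the element constructor, so the param-group object is built in place instead of being constructed first and then moved into the vector. A small sketch with a hypothetical ParamGroup type:

#include <string>
#include <utility>
#include <vector>

struct ParamGroup {
  explicit ParamGroup(std::vector<std::string> names) : names_(std::move(names)) {}
  std::vector<std::string> names_;
};

int main() {
  std::vector<ParamGroup> groups;
  std::vector<std::string> params{"weight", "bias"};
  // Flagged form: groups.push_back(ParamGroup(params));
  groups.emplace_back(params);  // constructs the ParamGroup directly in place
  return groups.size() == 1 ? 0 : 1;
}
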
@@ -204,7 +200,7 @@ TEST(OptimTest, OptimizerAccessors) {
struct MyOptimizerOptions
: public OptimizerCloneableOptions<MyOptimizerOptions> {
MyOptimizerOptions(double lr = 1.0) : lr_(lr){};
MyOptimizerOptions(double lr = 1.0) : lr_(lr) {}
TORCH_ARG(double, lr) = 1.0;
};
@@ -216,18 +212,16 @@ TEST(OptimTest, OldInterface) {
}
explicit MyOptimizer(
std::vector<at::Tensor> params,
MyOptimizerOptions defaults = {})
: // NOLINTNEXTLINE(performance-move-const-arg)
Optimizer(
{std::move(OptimizerParamGroup(params))},
const MyOptimizerOptions& defaults = {})
: Optimizer(
std::move(params),
std::make_unique<MyOptimizerOptions>(defaults)) {}
};
std::vector<torch::Tensor> parameters = {
torch::ones({2, 3}), torch::zeros({2, 3}), torch::rand({2, 3})};
{
MyOptimizer optimizer(parameters);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
size_t size;
size_t size = 0;
OLD_INTERFACE_WARNING_CHECK(size = optimizer.size());
ASSERT_EQ(size, parameters.size());
}
@@ -235,8 +229,7 @@ TEST(OptimTest, OldInterface) {
std::vector<at::Tensor> params;
MyOptimizer optimizer(params);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
size_t size;
size_t size = 0;
OLD_INTERFACE_WARNING_CHECK(size = optimizer.size());
ASSERT_EQ(size, 0);
@@ -255,8 +248,7 @@ TEST(OptimTest, OldInterface) {
Linear linear(3, 4);
MyOptimizer optimizer(linear->parameters());
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
size_t size;
size_t size = 0;
OLD_INTERFACE_WARNING_CHECK(size = optimizer.size());
ASSERT_EQ(size, linear->parameters().size());
}
@@ -480,7 +472,7 @@ TEST(OptimTest, AddParameter_LBFGS) {
// Check whether the learning rate of the parameter groups in the optimizer are
// the same as the expected learning rates given in the epoch:learning rate map
void check_lr_change(
static void check_lr_change(
Optimizer& optimizer,
LRScheduler& lr_scheduler,
std::map<unsigned, double> expected_epoch_lrs) {
@@ -512,7 +504,7 @@ void check_lr_change(
// Very similar to check_lr_change, but for ReduceLROnPlateauScheduler
// which does not inherit from LRScheduler and requires a metrics
// input to step().
void check_lr_change_for_reduce_on_plateau(
static void check_lr_change_for_reduce_on_plateau(
Optimizer& optimizer,
ReduceLROnPlateauScheduler& lr_scheduler,
std::map<unsigned, double> expected_epoch_lrs) {
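
The recurring changes in this test file are of two kinds: helpers used only in this translation unit are given internal linkage with static, and parameters are either taken by const reference when they are only read or kept by value and std::moved when they are consumed, the usual fixes for performance-unnecessary-value-param. A condensed hypothetical version:

#include <utility>
#include <vector>

struct Options {
  double lr = 1.0;
};

// static: the helper is not referenced outside this file.
// inputs is read-only, so it is taken by const reference; options is consumed
// exactly once, so it stays by value and is moved into place.
static double run_step(const std::vector<double>& inputs, Options options) {
  Options local = std::move(options);
  double sum = 0;
  for (double v : inputs) {
    sum += v * local.lr;
  }
  return sum;
}

int main() {
  std::vector<double> xs{1.0, 2.0, 3.0};
  return run_step(xs, Options{0.5}) == 3.0 ? 0 : 1;
}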

View File

@@ -36,6 +36,10 @@ struct CudaIPCGlobalEntities {
CudaIPCGlobalEntities() {
alive = true;
}
CudaIPCGlobalEntities(const CudaIPCGlobalEntities&) = delete;
CudaIPCGlobalEntities(CudaIPCGlobalEntities&&) = delete;
CudaIPCGlobalEntities& operator=(const CudaIPCGlobalEntities&) = delete;
CudaIPCGlobalEntities& operator=(CudaIPCGlobalEntities&&) = delete;
~CudaIPCGlobalEntities() {
CudaIPCSentDataLimbo_.collect();
safe_clean_current_file();
@@ -202,6 +206,7 @@ CudaIPCSentData::~CudaIPCSentData() {
}
cuda_ipc_global_entities.sync_events_used_--;
}
// NOLINTNEXTLINE(bugprone-empty-catch)
} catch (...) { /* No throw */
}
#endif
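
The four deleted operations above are the usual answer to cppcoreguidelines-special-member-functions: once a class defines a destructor, the copy and move operations should be spelled out, and for a process-wide singleton deleting all four is the natural choice. A minimal sketch (hypothetical GlobalRegistry, not the real CUDA IPC type):

struct GlobalRegistry {
  GlobalRegistry() = default;
  GlobalRegistry(const GlobalRegistry&) = delete;
  GlobalRegistry(GlobalRegistry&&) = delete;
  GlobalRegistry& operator=(const GlobalRegistry&) = delete;
  GlobalRegistry& operator=(GlobalRegistry&&) = delete;
  ~GlobalRegistry() { /* release process-wide resources here */ }
};

int main() {
  static GlobalRegistry registry;  // single instance, never copied or moved
  (void)registry;
  return 0;
}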

View File

@@ -30,7 +30,7 @@ using namespace torch;
PyObject* THPGeneratorClass = nullptr;
PyObject* THPGenerator_initDefaultGenerator(const at::Generator& cdata) {
PyObject* THPGenerator_initDefaultGenerator(at::Generator cdata) {
auto type = (PyTypeObject*)THPGeneratorClass;
auto self = THPObjectPtr{type->tp_alloc(type, 0)};
if (!self)
@@ -401,8 +401,7 @@ PyObject* THPGenerator_Wrap(const Generator& gen) {
return obj;
}
return THPGenerator_NewWithVar(
(PyTypeObject*)THPGeneratorClass, std::move(gen));
return THPGenerator_NewWithVar((PyTypeObject*)THPGeneratorClass, gen);
}
at::Generator THPGenerator_Unwrap(PyObject* state) {
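
Two related cleanups meet in this file: std::move applied to a const reference only ever copies (what performance-move-const-arg warns about), so the misleading std::move is dropped, and a parameter that the callee stores is taken by value so callers can hand over ownership when they want to. A hypothetical handle-based sketch:

#include <memory>
#include <utility>

struct Handle {
  std::shared_ptr<int> impl = std::make_shared<int>(0);
};

struct Wrapper {
  Handle handle;
};

// By-value parameter: lvalue callers pay one copy, rvalue callers none.
static Wrapper wrap(Handle h) {
  return Wrapper{std::move(h)};
}

int main() {
  Handle h;
  Wrapper kept = wrap(h);              // copies the shared_ptr (refcount bump)
  Wrapper taken = wrap(std::move(h));  // moves, no refcount traffic
  return (kept.handle.impl && taken.handle.impl) ? 0 : 1;
}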

View File

@@ -14,7 +14,7 @@ struct THPGenerator {
// is borrowed. The caller should ensure that the at::Generator object lifetime
// last at least as long as the Python wrapper.
TORCH_PYTHON_API PyObject* THPGenerator_initDefaultGenerator(
const at::Generator& cdata);
at::Generator cdata);
#define THPGenerator_Check(obj) PyObject_IsInstance(obj, THPGeneratorClass)

View File

@@ -159,6 +159,10 @@ class PyInterpreterHolder {
is_main_interpreter_(
at::impl::PythonOpRegistrationTrampoline::registerInterpreter(
impl_)) {}
PyInterpreterHolder(const PyInterpreterHolder&) = delete;
PyInterpreterHolder(PyInterpreterHolder&&) = delete;
PyInterpreterHolder& operator=(const PyInterpreterHolder&) = delete;
PyInterpreterHolder& operator=(PyInterpreterHolder&&) = delete;
// NB: intentionally leaks the PyInterpreter, as there may still be
// references to it that are live, living in objects that aren't being
// destructed while Python is being cleaned up.

View File

@@ -1,14 +1,14 @@
#pragma once
#include <torch/csrc/Export.h>
#include <c10/core/Device.h>
#include <c10/macros/Export.h>
#include <cstddef>
#include <cstdint>
namespace torch::cuda {
/// Returns the number of CUDA devices available.
size_t TORCH_API device_count();
c10::DeviceIndex TORCH_API device_count();
/// Returns true if at least one CUDA device is available.
bool TORCH_API is_available();

View File

@@ -37,6 +37,10 @@ class DataLoaderBase {
main_thread_dataset_(std::move(main_thread_dataset)),
sequencer_(new_sequencer()) {}
DataLoaderBase(const DataLoaderBase&) = delete;
DataLoaderBase(DataLoaderBase&&) = delete;
DataLoaderBase& operator=(const DataLoaderBase&) = delete;
DataLoaderBase& operator=(DataLoaderBase&&) = delete;
// NOLINTNEXTLINE(bugprone-exception-escape)
virtual ~DataLoaderBase() {
join();

View File

@@ -21,6 +21,7 @@ class AnyValue {
/// behavior of move for `std::unique_ptr`.
AnyValue(AnyValue&&) = default;
AnyValue& operator=(AnyValue&&) = default;
~AnyValue() = default;
/// Copy construction and assignment is allowed.
AnyValue(const AnyValue& other) : content_(other.content_->clone()) {}
@@ -89,6 +90,8 @@ class AnyValue {
: type_info(type_info_) {}
Placeholder(const Placeholder&) = default;
Placeholder(Placeholder&&) = default;
Placeholder& operator=(const Placeholder&) = delete;
Placeholder& operator=(Placeholder&&) = delete;
virtual ~Placeholder() = default;
virtual std::unique_ptr<Placeholder> clone() const {
TORCH_CHECK(false, "clone() should only be called on `AnyValue::Holder`");
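
The explicitly deleted assignment operators above follow a common pattern for a polymorphic base: copy and move construction stay available to derived classes, while assignment is deleted, presumably because assigning through a base reference risks slicing derived state. A condensed sketch with hypothetical names:

#include <memory>

struct Placeholder {
  Placeholder() = default;
  Placeholder(const Placeholder&) = default;
  Placeholder(Placeholder&&) = default;
  Placeholder& operator=(const Placeholder&) = delete;
  Placeholder& operator=(Placeholder&&) = delete;
  virtual ~Placeholder() = default;
  virtual std::unique_ptr<Placeholder> clone() const = 0;
};

struct IntHolder : Placeholder {
  explicit IntHolder(int v) : value(v) {}
  std::unique_ptr<Placeholder> clone() const override {
    return std::make_unique<IntHolder>(value);
  }
  int value;
};

int main() {
  IntHolder holder{7};
  std::unique_ptr<Placeholder> copy = holder.clone();
  return copy != nullptr ? 0 : 1;
}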

View File

@@ -41,11 +41,6 @@ struct TORCH_API AdagradParamState
TORCH_ARG(int64_t, step) = 0;
public:
AdagradParamState() = default;
AdagradParamState(const AdagradParamState&) = default;
AdagradParamState& operator=(const AdagradParamState&) = default;
AdagradParamState(AdagradParamState&&) noexcept = default;
AdagradParamState& operator=(AdagradParamState&&) noexcept = default;
void serialize(torch::serialize::InputArchive& archive) override;
void serialize(torch::serialize::OutputArchive& archive) const override;
TORCH_API friend bool operator==(

View File

@@ -85,6 +85,7 @@ class TORCH_API OptimizerParamGroup {
options_(
param_group.has_options() ? param_group.options().clone()
: nullptr) {}
OptimizerParamGroup(OptimizerParamGroup&& param_group) = default;
OptimizerParamGroup(std::vector<Tensor> params)
: params_(std::move(params)) {}
OptimizerParamGroup(
@@ -94,6 +95,9 @@ class TORCH_API OptimizerParamGroup {
OptimizerParamGroup& operator=(const OptimizerParamGroup& param_group) =
delete;
OptimizerParamGroup& operator=(OptimizerParamGroup&& param_group) noexcept =
default;
~OptimizerParamGroup() = default;
bool has_options() const;
OptimizerOptions& options();
const OptimizerOptions& options() const;
@@ -112,6 +116,8 @@ class TORCH_API Optimizer {
// `state_dict` / `load_state_dict` API to copy an optimizer instead.
Optimizer(const Optimizer& optimizer) = delete;
Optimizer(Optimizer&& optimizer) = default;
Optimizer& operator=(const Optimizer& optimizer) = delete;
Optimizer& operator=(Optimizer&& optimizer) = default;
explicit Optimizer(
const std::vector<OptimizerParamGroup>& param_groups,

View File

@@ -32,6 +32,7 @@ class TORCH_API LRScheduler {
private:
void set_optimizer_lrs(const std::vector<double>& learning_rates);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
torch::optim::Optimizer& optimizer_;
};
} // namespace torch::optim
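
cppcoreguidelines-avoid-const-or-ref-data-members warns here because a reference data member makes the class non-copy-assignable; the scheduler is deliberately bound to one optimizer for its whole lifetime, so the diff suppresses the warning instead of switching to a pointer. A hypothetical minimal version:

struct Engine {
  double lr = 0.1;
};

class Scheduler {
 public:
  explicit Scheduler(Engine& engine) : engine_(engine) {}
  void step() { engine_.lr *= 0.5; }  // mutates the referenced engine

 private:
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  Engine& engine_;
};

int main() {
  Engine engine;
  Scheduler scheduler(engine);
  scheduler.step();
  return engine.lr < 0.1 ? 0 : 1;
}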

View File

@@ -379,7 +379,7 @@ Value& OrderedDict<Key, Value>::insert(Key key, Value&& value) {
template <typename Key, typename Value>
void OrderedDict<Key, Value>::update(OrderedDict&& other) {
reserve(size() + other.size());
for (auto& item : other) {
for (auto&& item : std::move(other)) {
// We want to call `insert()` to prevent duplicate keys.
insert(std::move(item.key()), std::move(item.value()));
}
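
Ranging over std::move(other) with auto&& states explicitly that the elements of the rvalue-reference parameter are about to be moved from, in the same spirit as the rvalue-reference-param-not-moved suppressions elsewhere in this commit. A stand-alone sketch with a plain std::vector:

#include <string>
#include <utility>
#include <vector>

static void append_all(std::vector<std::string>& dst,
                       std::vector<std::string>&& other) {
  dst.reserve(dst.size() + other.size());
  for (auto&& item : std::move(other)) {
    dst.push_back(std::move(item));  // pillages each element of other
  }
}

int main() {
  std::vector<std::string> dst{"a"};
  append_all(dst, {"b", "c"});
  return dst.size() == 3 ? 0 : 1;
}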

View File

@@ -4,11 +4,9 @@
#include <c10/core/DeviceGuard.h>
#include <c10/util/irange.h>
#include <cstddef>
namespace torch::cuda {
size_t device_count() {
c10::DeviceIndex device_count() {
return at::detail::getCUDAHooks().deviceCount();
}
@@ -54,7 +52,7 @@ void synchronize(int64_t device_index) {
TORCH_CHECK(is_available(), "No CUDA GPUs are available");
auto num_gpus = cuda::device_count();
TORCH_CHECK(
device_index < 0 || static_cast<size_t>(device_index) < num_gpus,
device_index < 0 || device_index < num_gpus,
"Device index out of range: ",
device_index);
at::detail::getCUDAHooks().deviceSynchronize(
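
Returning a dedicated device-index type instead of size_t is what lets the synchronize() check above compare the user-supplied signed index directly and drop the static_cast. A hedged sketch with a stand-in index type (the real c10::DeviceIndex is a small signed integer; its exact width is not assumed here):

#include <cstdint>
#include <stdexcept>

using DeviceIndex = int16_t;  // stand-in for c10::DeviceIndex

static DeviceIndex device_count() { return 2; }  // pretend two devices exist

static void check_index(int64_t device_index) {
  const auto num_gpus = device_count();
  // Both operands are signed, so no cast is needed and no sign-compare
  // warning is triggered.
  if (!(device_index < 0 || device_index < num_gpus)) {
    throw std::out_of_range("Device index out of range");
  }
}

int main() {
  check_index(1);   // valid index
  check_index(-1);  // "current device" convention, also accepted
  return 0;
}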

View File

@@ -89,6 +89,7 @@ struct WarnNotImplemented : public Node {
size_t num_outputs;
};
// NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved)
auto WarnNotImplemented::apply(variable_list&& inputs) -> variable_list {
auto inputsLocal = std::move(inputs);
warnAutogradNotImplemented(op_name);

View File

@@ -122,6 +122,9 @@ struct TORCH_API AutogradContext {
AutogradContext() = default;
AutogradContext(const AutogradContext& other) = delete;
AutogradContext& operator=(const AutogradContext& other) = delete;
AutogradContext(AutogradContext&& other) = delete;
AutogradContext& operator=(AutogradContext&& other) = delete;
~AutogradContext() = default;
/// Can be used to save non-variable data for `backward`.
ska::flat_hash_map<std::string, at::IValue> saved_data;

View File

@@ -21,6 +21,7 @@ AccumulateGrad::AccumulateGrad(Variable variable_)
add_input_metadata(variable);
}
// NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved)
auto AccumulateGrad::apply(variable_list&& grads) -> variable_list {
check_input_variables("AccumulateGrad", grads, 1, 0);

View File

@@ -109,13 +109,13 @@ variable_list Gather::apply(variable_list&& inputs) {
}
std::vector<at::Tensor> tensors;
tensors.reserve(inputs.size());
for (auto& variable : inputs) {
if (unsqueeze_scalars) {
if (unsqueeze_scalars) {
tensors.reserve(inputs.size());
for (auto& variable : inputs) {
tensors.push_back(variable.view(1));
} else {
tensors.push_back(std::move(variable));
}
} else {
tensors = std::move(inputs);
}
// Disable the autograd during the actual computation
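
The rewrite above hoists the loop-invariant unsqueeze_scalars test out of the loop and, in the branch that needs no per-element work, moves the whole input vector in one go. A stand-alone sketch of the same shape:

#include <string>
#include <utility>
#include <vector>

static std::vector<std::string> collect(std::vector<std::string>&& inputs,
                                        bool decorate) {
  std::vector<std::string> out;
  if (decorate) {
    out.reserve(inputs.size());
    for (auto& item : inputs) {
      out.push_back("[" + item + "]");  // per-element work only in this branch
    }
  } else {
    out = std::move(inputs);  // cheap wholesale move, no loop needed
  }
  return out;
}

int main() {
  auto decorated = collect({"x", "y"}, true);
  auto raw = collect({"x", "y"}, false);
  return (decorated.size() == 2 && raw.size() == 2) ? 0 : 1;
}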

View File

@@ -21,21 +21,21 @@ static variable_list CopyBackwards_apply_functional(
std::array<bool, 2> needs_input_grad,
const c10::TensorOptions& src_options) {
check_input_variables("CopyBackwards", grads, 1, -1, true);
auto grad = c10::MaybeOwned<at::Tensor>::borrowed(grads[0]);
auto& grad = std::move(grads)[0];
variable_list grad_inputs(2);
if (grad->defined()) {
if (grad.defined()) {
if (needs_input_grad[0]) {
grad_inputs[0] = at::zeros_like(*grad, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
grad_inputs[0] = at::zeros_like(grad, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
}
if (needs_input_grad[1]) {
// Handle R->C copies without raising a warning
const auto src_type = src_options.dtype().toScalarType();
if (!c10::isComplexType(src_type) && grad->is_complex()) {
grad = c10::MaybeOwned<at::Tensor>::owned(at::real(grads[0]));
if (!c10::isComplexType(src_type) && grad.is_complex()) {
grad = at::real(grad);
}
at::DeviceGuard device_guard(src_options.device());
grad_inputs[1] = grad->to(src_options);
grad_inputs[1] = grad.to(src_options);
}
}
return grad_inputs;
@@ -87,7 +87,7 @@ inline variable_list CopySlices::apply_impl(
variable_list&& inputs,
const T& call_fn) {
check_input_variables("CopySlices", inputs, 1, -1, true);
auto& grad = inputs[0];
auto& grad = std::move(inputs)[0];
if (!grad.defined()) {
return variable_list(num_outputs());
}

View File

@@ -1,16 +1,21 @@
#include <autograd/grad_mode.h>
#include <c10/util/irange.h>
#include <torch/csrc/autograd/functions/utils.h>
#include <cstddef>
#include <stdexcept>
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/variable.h>
#include <sstream>
#include <utility>
namespace torch::autograd {
variable_list wrap_outputs(
const variable_list& inputs,
// NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved)
tensor_list&& outputs,
const function_constructor& ctr) {
variable_list result;
@@ -18,7 +23,8 @@ variable_list wrap_outputs(
if (!any_variable_requires_grad(inputs)) {
for (auto& output : outputs) {
if (output.defined()) {
result.push_back(make_variable(output, /*requires_grad=*/false));
result.push_back(
make_variable(std::move(output), /*requires_grad=*/false));
} else {
result.emplace_back();
}
@@ -29,7 +35,7 @@ for (auto& output : outputs) {
for (auto& output : outputs) {
if (output.defined()) {
auto variable =
autograd::make_variable(output, /*requires_grad=*/false);
autograd::make_variable(std::move(output), /*requires_grad=*/false);
autograd::create_gradient_edge(variable, grad_fn);
result.push_back(std::move(variable));
} else {
@@ -50,7 +56,7 @@ void check_input_variables(
if (required_args == -1) {
required_args = args;
}
if (inputs.size() != (size_t)args) {
if (inputs.size() != static_cast<size_t>(args)) {
std::stringstream ss;
ss << name << ": expected " << args << " arguments (got " << inputs.size();
ss << ")";
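
The (size_t)args change swaps a C-style cast for the named cast preferred by the Core Guidelines checks: a C-style cast can silently degrade into a reinterpret_cast or const_cast if the surrounding types change, while static_cast only ever performs the intended value conversion. A tiny hypothetical example:

#include <cstddef>
#include <vector>

static bool sizes_match(const std::vector<int>& inputs, int args) {
  return inputs.size() == static_cast<std::size_t>(args);
}

int main() {
  std::vector<int> v{1, 2, 3};
  return sizes_match(v, 3) ? 0 : 1;
}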

View File

@@ -277,6 +277,10 @@ class ValueCache {
public:
ValueCache() = default;
ValueCache(const ValueCache&) = delete;
ValueCache& operator=(const ValueCache&) = delete;
ValueCache(ValueCache&&) = default;
ValueCache& operator=(ValueCache&&) = delete;
~ValueCache() = default;
template <CallType C>
void store(const typename Config<C>::key_t&, typename Config<C>::ephemeral_t);

View File

@@ -883,7 +883,7 @@ inline Variable make_variable(
} else {
data_impl_copy->set_autograd_meta(nullptr);
}
return Variable(data_impl_copy);
return Variable(std::move(data_impl_copy));
}
}
return Variable();

View File

@@ -20,6 +20,7 @@ class TORCH_API DistAutogradContext {
using GradCallback = std::function<bool(torch::Tensor&)>;
explicit DistAutogradContext(int64_t contextId);
~DistAutogradContext() = default;
// Retrieves the autograd context id for this context.
int64_t contextId() const;