Use C10_UNUSED instead of (void)X (#137239)

Summary:
Auto-generated with
```
buck run //scripts/rbarnes/regex_multiline_replacer:regex_multiline_replacer -- --find '^(\s*for\s*\()(const.*\n)\s*\(void\)[A-Za-z]+;\s*//\s*Suppress.*\s*\n(.*)'  --replace '\1C10_UNUSED \2\3' `find caffe2/ -regex ".*\.\(cpp\|h\)"`
```

Differential Revision: D33432600

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137239
Approved by: https://github.com/Skylion007
This commit is contained in:
Richard Barnes
2024-10-15 14:32:59 +00:00
committed by PyTorch MergeBot
parent e7a4ad3b40
commit b7f798caa4
48 changed files with 90 additions and 143 deletions

View File

@ -241,7 +241,7 @@ void batch_iterator_with_broadcasting(const Tensor& a, const Tensor& b, const fu
auto* b_batch_idx_ptr = data[0];
auto* a_batch_idx_ptr = data[1];
for (const auto elem C10_UNUSED : c10::irange(nelems)) {
for (C10_UNUSED const auto elem : c10::irange(nelems)) {
auto b_curr_linear_batch_idx = *reinterpret_cast<int64_t*>(b_batch_idx_ptr);
auto a_curr_linear_batch_idx = *reinterpret_cast<int64_t*>(a_batch_idx_ptr);

View File

@ -875,12 +875,8 @@ TORCH_IMPL_FUNC(index_copy_out)
// See Note [Enabling Deterministic Operations]
if (result.is_cuda() && globalContext().deterministicAlgorithms()){
torch::List<std::optional<Tensor>> indices;
indices.reserve(dim + 1);
for (const auto i: c10::irange(dim)) {
(void)i;
indices.emplace_back();
}
indices.emplace_back(index);
indices.resize(dim + 1);
indices.set(dim, index);
result.index_put_(indices, source, false);
return;
}

View File

@ -150,13 +150,11 @@ static void upsample_bicubic2d_backward_out_frame(
opmath_t t_y;
guard_index_and_lambda(real_y, input_height, input_y, t_y);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
opmath_t x_coeffs[4];
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
opmath_t y_coeffs[4];
std::array<opmath_t, 4> x_coeffs;
std::array<opmath_t, 4> y_coeffs;
get_cubic_upsample_coefficients<opmath_t>(x_coeffs, t_x);
get_cubic_upsample_coefficients<opmath_t>(y_coeffs, t_y);
get_cubic_upsample_coefficients<opmath_t>(x_coeffs.data(), t_x);
get_cubic_upsample_coefficients<opmath_t>(y_coeffs.data(), t_y);
opmath_t out_value = out[output_y * output_width + output_x];
for (const auto ii : c10::irange(4)) {

View File

@ -395,7 +395,7 @@ struct Dist {
const scalar_t * t1_end = t1 + l1_size;
const scalar_t * t2_end = t2 + l2_size;
for (const auto l C10_UNUSED : c10::irange(d)) {
for (C10_UNUSED const auto l : c10::irange(d)) {
for (; t1 != t1_end; t1 += m, res += m) {
const Vec vec_t1 = Vec::loadu(t1, count);
Vec res_vec = Vec::loadu(res, count);

View File

@ -30,7 +30,7 @@ void _compute_linear_combination_cpu_kernel(
auto* RESTRICT in_ptr = data[1];
auto* RESTRICT coeff_ptr = data[2];
for (const auto elem C10_UNUSED : c10::irange(n)) {
for (C10_UNUSED const auto elem : c10::irange(n)) {
auto* RESTRICT out_data = reinterpret_cast<scalar_t*>(out_ptr);
auto* RESTRICT in_data = reinterpret_cast<scalar_t*>(in_ptr);
using primitive_t = typename scalar_value_type<scalar_t>::type;

View File

@ -78,7 +78,7 @@ void cpu_take_put_kernel(
auto loop = [&](char** data, const int64_t* strides, int64_t n) {
auto* iterated_data_bytes = data[0];
auto* index_data_bytes = data[1];
for (const auto elem C10_UNUSED : c10::irange(n)) {
for (C10_UNUSED const auto elem : c10::irange(n)) {
auto idx = *reinterpret_cast<int64_t*>(index_data_bytes);
auto& iterated = *reinterpret_cast<scalar_t*>(iterated_data_bytes);
@ -203,7 +203,7 @@ void index_fill_kernel(
auto handle_nonzero_idx_stride = [&](char** data, const int64_t* strides, int64_t n) {
auto* self_data_bytes = data[0];
auto* index_data_bytes = data[1];
for (const auto elem C10_UNUSED : c10::irange(n)) {
for (C10_UNUSED const auto elem : c10::irange(n)) {
auto* self_data = reinterpret_cast<scalar_t*>(self_data_bytes);
auto idx = *reinterpret_cast<int64_t*>(index_data_bytes);
TORCH_CHECK_INDEX(idx >= -self_dim_size && idx < self_dim_size,
@ -229,7 +229,7 @@ void index_fill_kernel(
if (idx < 0) {
idx += self_dim_size;
}
for (const auto elem C10_UNUSED: c10::irange(n)) {
for (C10_UNUSED const auto elem: c10::irange(n)) {
auto* self_data = reinterpret_cast<scalar_t*>(self_data_bytes);
self_data[idx * self_dim_stride] = fill_val;
@ -262,7 +262,7 @@ void index_copy_kernel(
auto* self_data_bytes = data[0];
auto* index_data_bytes = data[1];
auto* source_data_bytes = data[2];
for (const auto elem C10_UNUSED : c10::irange(n)) {
for (C10_UNUSED const auto elem : c10::irange(n)) {
auto* self_data = reinterpret_cast<scalar_t*>(self_data_bytes);
auto idx = *reinterpret_cast<int64_t*>(index_data_bytes);
auto* source_data = reinterpret_cast<scalar_t*>(source_data_bytes);
@ -285,7 +285,7 @@ void index_copy_kernel(
TORCH_CHECK_INDEX(idx >= 0 && idx < self_dim_size,
"index_copy_(): index ", idx, " is out of bounds for dimension ",
dim, " with size ", self_dim_size);
for (const auto elem C10_UNUSED : c10::irange(n)) {
for (C10_UNUSED const auto elem : c10::irange(n)) {
auto* self_data = reinterpret_cast<scalar_t*>(self_data_bytes);
auto* source_data = reinterpret_cast<scalar_t*>(source_data_bytes);
@ -474,7 +474,7 @@ void cpu_hflip_vec(at::TensorIterator& iter) {
constexpr auto stride = sizeof(scalar_t);
TORCH_INTERNAL_ASSERT(stride == -strides[0] && stride == strides[1]);
for (const auto j C10_UNUSED : c10::irange(size1)) {
for (C10_UNUSED const auto j : c10::irange(size1)) {
// vectorized loop with negative stride for output
char** C10_RESTRICT data_ = data_arr.data();
@ -543,7 +543,7 @@ void cpu_vflip_memcpy(at::TensorIterator& iter) {
TORCH_INTERNAL_ASSERT(strides[0] == strides[1]);
const int64_t stride = strides[0];
for (const auto j C10_UNUSED : c10::irange(size1)) {
for (C10_UNUSED const auto j : c10::irange(size1)) {
char** C10_RESTRICT data_ = data_arr.data();
int64_t n = size0;

View File

@ -70,7 +70,7 @@ inline void vectorized_reduction(char** data, int64_t n, int64_t stride,
template <typename F>
inline void UNARY_OUTER_LOOP(char* data[2], const int64_t strides[2], int64_t n, F f) {
for (const auto j C10_UNUSED : c10::irange(n)) {
for (C10_UNUSED const auto j : c10::irange(n)) {
f();
data[0] += strides[0];
data[1] += strides[1];

View File

@ -62,7 +62,7 @@ static inline void cpu_cum_base_kernel(const Tensor& result,
auto* result_data_bytes = data[0];
const auto* self_data_bytes = data[1];
for (const auto i C10_UNUSED : c10::irange(n)) {
for (C10_UNUSED const auto i : c10::irange(n)) {
f(
(scalar_t*)result_data_bytes, result_dim_stride,
(scalar_t*)self_data_bytes, self_dim_stride, init_val

View File

@ -215,7 +215,7 @@ struct cpu_scatter_gather_base_kernel {
// vs dim-TensorIterator loop order depending on
// whether dim is the last dimension
if (dim== buffer.dim() - 1) {
for (const auto nelem C10_UNUSED : c10::irange(n)) {
for (C10_UNUSED const auto nelem : c10::irange(n)) {
// dim loop is a separate code block
// for better performance
loop_func.template operator()<scalar_t, func_t>(
@ -232,7 +232,7 @@ struct cpu_scatter_gather_base_kernel {
for (const auto i : c10::irange(index_dim_size)) {
auto* self_data = self_data_bytes;
auto* index_data = (char*)((int64_t*)index_data_bytes + i * index_dim_stride);
for (const auto nelem C10_UNUSED : c10::irange(n)) {
for (C10_UNUSED const auto nelem : c10::irange(n)) {
int64_t idx_dim = *(int64_t*)index_data;
// we are not putting idx_dim in the error message because it disables
// loop optimization in clang-7
@ -306,7 +306,7 @@ struct cpu_scatter_gather_base_kernel {
// vs dim-TensorIterator loop order depending on
// whether dim is the last dimension
if (dim== buffer.dim() - 1) {
for (const auto nelem C10_UNUSED : c10::irange(n)) {
for (C10_UNUSED const auto nelem : c10::irange(n)) {
// dim loop is a separate code block
// for better performance
loop_func.template operator()<scalar_t, func_t>(
@ -327,7 +327,7 @@ struct cpu_scatter_gather_base_kernel {
auto* self_data = self_data_bytes;
auto* index_data = (char*)((int64_t*)index_data_bytes + i * index_dim_stride);
auto* src_data = src_data_bytes;
for (const auto nelem C10_UNUSED : c10::irange(n)) {
for (C10_UNUSED const auto nelem : c10::irange(n)) {
int64_t idx_dim = *(int64_t*)index_data;
// we are not putting idx_dim in the error message because it disables
// loop optimization in clang-7
@ -402,7 +402,7 @@ struct cpu_scatter_gather_base_kernel {
// vs dim-TensorIterator loop order depending on
// whether dim is the last dimension
if (dim== buffer.dim() - 1) {
for (const auto nelem C10_UNUSED : c10::irange(n)) {
for (C10_UNUSED const auto nelem : c10::irange(n)) {
// dim loop is a separate code block
// for better performance
loop_func.template operator()<scalar_t, ReduceMean>(
@ -423,7 +423,7 @@ struct cpu_scatter_gather_base_kernel {
auto* self_data = self_data_bytes;
auto* index_data = (char*)((int64_t*)index_data_bytes + i * index_dim_stride);
auto* src_data = src_data_bytes;
for (const auto nelem C10_UNUSED : c10::irange(n)) {
for (C10_UNUSED const auto nelem : c10::irange(n)) {
int64_t idx_dim = *(int64_t*)index_data;
// we are not putting idx_dim in the error message because it disables
// loop optimization in clang-7
@ -497,7 +497,7 @@ struct cpu_scatter_gather_base_kernel {
// vs dim-TensorIterator loop order depending on
// whether dim is the last dimension
if (dim== buffer.dim() - 1) {
for (const auto nelem C10_UNUSED : c10::irange(n)) {
for (C10_UNUSED const auto nelem : c10::irange(n)) {
// dim loop is a separate code block
// for better performance
loop_func.template operator()<scalar_t, ReduceMaximum>(
@ -518,7 +518,7 @@ struct cpu_scatter_gather_base_kernel {
auto* self_data = self_data_bytes;
auto* index_data = (char*)((int64_t*)index_data_bytes + i * index_dim_stride);
auto* src_data = src_data_bytes;
for (const auto nelem C10_UNUSED : c10::irange(n)) {
for (C10_UNUSED const auto nelem : c10::irange(n)) {
int64_t idx_dim = *(int64_t*)index_data;
// we are not putting idx_dim in the error message because it disables
// loop optimization in clang-7
@ -593,7 +593,7 @@ struct cpu_scatter_gather_base_kernel {
// vs dim-TensorIterator loop order depending on
// whether dim is the last dimension
if (dim== buffer.dim() - 1) {
for (const auto nelem C10_UNUSED : c10::irange(n)) {
for (C10_UNUSED const auto nelem : c10::irange(n)) {
// dim loop is a separate code block
// for better performance
loop_func.template operator()<scalar_t, ReduceMinimum>(
@ -614,7 +614,7 @@ struct cpu_scatter_gather_base_kernel {
auto* self_data = self_data_bytes;
auto* index_data = (char*)((int64_t*)index_data_bytes + i * index_dim_stride);
auto* src_data = src_data_bytes;
for (const auto nelem C10_UNUSED : c10::irange(n)) {
for (C10_UNUSED const auto nelem : c10::irange(n)) {
int64_t idx_dim = *(int64_t*)index_data;
// we are not putting idx_dim in the error message because it disables
// loop optimization in clang-7

View File

@ -53,7 +53,7 @@ void _dim_apply(
return;
}
for (const auto i C10_UNUSED : c10::irange(n)) {
for (C10_UNUSED const auto i : c10::irange(n)) {
f(
reinterpret_cast<scalar_t*>(values_data_bytes),
values_dim_stride,

View File

@ -83,7 +83,7 @@ static inline void compare_base_kernel(const Tensor& result1, const Tensor& resu
auto* result1_data_bytes = data[0];
auto* result2_data_bytes = data[1];
const auto* self_data_bytes = data[2];
for (const auto i C10_UNUSED : c10::irange(n)) {
for (C10_UNUSED const auto i : c10::irange(n)) {
f((scalar_t*)result1_data_bytes,
(scalar_t_2*)result2_data_bytes,
(scalar_t*)self_data_bytes,
@ -253,7 +253,7 @@ static void mode_kernel_impl(
std::vector<std::pair<scalar_t, int64_t>> elements(self_dim_size);
for (const auto k C10_UNUSED : c10::irange(n)) {
for (C10_UNUSED const auto k : c10::irange(n)) {
scalar_t* values_data = (scalar_t*)values_data_bytes;
int64_t* indices_data = (int64_t*)indices_data_bytes;
const scalar_t* self_data = (scalar_t*)self_data_bytes;

View File

@ -733,7 +733,7 @@ struct HelperInterpBase {
auto new_shape = std::vector<int64_t>(ndims, 1);
new_shape[reshape_dim] = output_size;
for (const auto j C10_UNUSED : c10::irange(interp_size)) {
for (C10_UNUSED const auto j : c10::irange(interp_size)) {
output.emplace_back(empty(new_shape, CPU(c10::CppTypeToScalarType<int64_t>())));
output.emplace_back(empty(new_shape, CPU(output_type)));
}
@ -1047,7 +1047,7 @@ struct HelperInterpNearest : public HelperInterpBase {
auto new_shape = std::vector<int64_t>(ndims, 1);
new_shape[reshape_dim] = output_size;
for (const auto j C10_UNUSED : c10::irange(interp_size)) {
for (C10_UNUSED const auto j : c10::irange(interp_size)) {
output.emplace_back(empty(new_shape, CPU(c10::CppTypeToScalarType<int64_t>())));
// Defines weights for consistency, but not used
output.emplace_back(at::ones(new_shape, CPU(output_type)));

View File

@ -71,7 +71,7 @@ inline int can_vectorize_up_to(const KernelDescriptor &desc, c10::ArrayRef<char*
// Deals with output
auto result_size = c10::scalarTypeToTypeMeta(desc.result_type).itemsize();
int result = can_vectorize_up_to(result_size, pointers[0]);
auto result = can_vectorize_up_to(result_size, pointers[0]);
// Incorporates input(s)
auto input_size = c10::scalarTypeToTypeMeta(desc.f_inputs_type).itemsize();

View File

@ -51,7 +51,7 @@ static void layer_norm_with_mean_rstd_out(
for (const auto idx : c10::irange(axis)) {
stat_shape.emplace_back(input_shape[idx]);
}
for (const auto idx C10_UNUSED : c10::irange(axis, input.dim())) {
for (C10_UNUSED const auto idx : c10::irange(axis, input.dim())) {
stat_shape.emplace_back(1);
}
@ -256,7 +256,7 @@ std::tuple<Tensor, Tensor, Tensor> math_native_layer_norm(
for (const auto idx : c10::irange(axis)) {
stat_shape.push_back(input_shape[idx]);
}
for (const auto idx C10_UNUSED : c10::irange(axis, input.dim())) {
for (C10_UNUSED const auto idx : c10::irange(axis, input.dim())) {
stat_shape.push_back(1);
}
mean = mean.view(stat_shape);

View File

@ -124,7 +124,7 @@ static void upsample_bilinear2d_out_frame(
const auto* pos1 = i_ptr + h1 * input_width + w1;
float result = h0lambda * (w0lambda * pos1[0] + w1lambda * pos1[w1p]) +
const float result = h0lambda * (w0lambda * pos1[0] + w1lambda * pos1[w1p]) +
h1lambda *
(w0lambda * pos1[h1p * input_width] +
w1lambda * pos1[h1p * input_width + w1p]) - input_q_zero_point;

View File

@ -160,7 +160,6 @@ void _csr_matmult(
}
for (C10_UNUSED const auto jj : c10::irange(length)) {
// NOTE: the linked list that encodes col indices
// is not guaranteed to be sorted.
Cj[nnz] = head;

View File

@ -35,14 +35,12 @@ dict_int_int test_dict(dict_int_int& dict) {
// erase via iterators
auto begin = dict.begin();
for (const auto i : c10::irange(20)) {
(void)i; // Suppress unused variable warning
for (C10_UNUSED const auto i : c10::irange(20)) {
begin++;
}
auto end = begin;
for (const auto i : c10::irange(20)) {
(void)i; // Suppress unused variable warning
for (C10_UNUSED const auto i : c10::irange(20)) {
erase_set.insert(end->first);
end++;
}
@ -136,13 +134,11 @@ TEST(OrderedPreservingDictTest, DictCollisions) {
// erase a few entries via iterator
auto begin = dict.begin();
for (const auto j : c10::irange(10)) {
(void)j; // Suppress unused variable warning
for (C10_UNUSED const auto j : c10::irange(10)) {
begin++;
}
auto end = begin;
for (const auto j : c10::irange(7)) {
(void)j; // Suppress unused variable warning
for (C10_UNUSED const auto j : c10::irange(7)) {
erase_set.insert(end->first);
end++;
}

View File

@ -2220,8 +2220,7 @@ TEST(DataLoaderTest, ChunkDatasetCrossChunkShuffle) {
for (const auto i : c10::irange(
(chunk_count + cross_chunk_shuffle_count - 1) /
cross_chunk_shuffle_count)) {
for (const auto j : c10::irange(chunk_size)) {
(void)j; // Suppress unused variable warning
for (C10_UNUSED const auto j : c10::irange(chunk_size)) {
for (const auto k : c10::irange(cross_chunk_shuffle_count)) {
if (i * cross_chunk_shuffle_count + k < chunk_count) {
expected_result.push_back(i * cross_chunk_shuffle_count + k);

View File

@ -1343,8 +1343,7 @@ TEST_F(FunctionalTest, GumbelSoftmax) {
auto counts = torch::zeros_like(logits);
torch::Tensor y_draw;
for (const auto i : c10::irange(num_draws)) {
(void)i; // Suppress unused variable warning
for (C10_UNUSED const auto i : c10::irange(num_draws)) {
y_draw =
F::gumbel_softmax(logits, F::GumbelSoftmaxFuncOptions().hard(true));
counts += y_draw;

View File

@ -123,8 +123,7 @@ bool test_mnist(
torch::Device device(with_cuda ? torch::kCUDA : torch::kCPU);
model->to(device);
for (const auto epoch : c10::irange(number_of_epochs)) {
(void)epoch; // Suppress unused variable warning
for (C10_UNUSED const auto epoch : c10::irange(number_of_epochs)) {
// NOLINTNEXTLINE(performance-for-range-copy)
for (torch::data::Example<> batch : *data_loader) {
auto data = batch.data.to(device);

View File

@ -3511,8 +3511,7 @@ void _multihead_attn_test_helper(
std::uniform_int_distribution<int> d_2_10(2, 10);
std::uniform_int_distribution<int> d_3_10(3, 10);
bool registration_checked = false;
for (const auto i : c10::irange(100)) {
(void)i; // Suppress unused variable warning
for (C10_UNUSED const auto i : c10::irange(100)) {
const auto batch_sz = d_2_10(generator);
const auto seq_len = d_2_10(generator);
const auto d_head = d_3_10(generator);

View File

@ -398,8 +398,7 @@ std::vector<torch::Tensor> PackedSequenceTest_ordered_sequence(
torch::ScalarType tensor_type) {
std::vector<torch::Tensor> seqs;
seqs.reserve(PackedSequenceTest_batch_size);
for (const auto i : c10::irange(PackedSequenceTest_batch_size)) {
(void)i; // Suppress unused variable warning
for (C10_UNUSED const auto i : c10::irange(PackedSequenceTest_batch_size)) {
seqs.emplace_back(torch::empty(
{torch::randint(1, PackedSequenceTest_max_length, {1}).item<int64_t>()},
tensor_type));

View File

@ -12,8 +12,7 @@ struct OperationTest : torch::test::SeedingFixture {
};
TEST_F(OperationTest, Lerp) {
for (const auto i : c10::irange(TEST_AMOUNT)) {
(void)i; // Suppress unused variable warning
for (C10_UNUSED const auto i : c10::irange(TEST_AMOUNT)) {
// test lerp_kernel_scalar
auto start = torch::rand({3, 5});
auto end = torch::rand({3, 5});
@ -37,8 +36,7 @@ TEST_F(OperationTest, Lerp) {
}
TEST_F(OperationTest, Cross) {
for (const auto i : c10::irange(TEST_AMOUNT)) {
(void)i; // Suppress unused variable warning
for (C10_UNUSED const auto i : c10::irange(TEST_AMOUNT)) {
// input
auto a = torch::rand({10, 3});
auto b = torch::rand({10, 3});

View File

@ -157,8 +157,7 @@ void check_exact_values(
TEST(OptimTest, OptimizerAccessors) {
auto options = AdagradOptions(1.0);
std::vector<torch::Tensor> params;
for (const auto i : c10::irange(3)) {
(void)i; // Suppress unused variable warning
for (C10_UNUSED const auto i : c10::irange(3)) {
params.push_back(torch::randn(10));
}
auto optimizer = Adagrad(params, options);

View File

@ -1043,8 +1043,7 @@ TEST(Reductions, ReduceSplitRfactor) {
SimpleIREvaluator cg(s, {b, c});
cg.call({in, out});
for (const auto i : c10::irange(M)) {
(void)i; // Suppress unused variable warning
for (C10_UNUSED const auto i : c10::irange(M)) {
ASSERT_EQ(out[0], 4950);
}
}

View File

@ -3884,8 +3884,7 @@ TEST(Simplify, SimplifyEliminateEmptyFor) {
{
// Flatten many layers around an empty block to an empty block.
StmtPtr last = alloc<Block>(std::vector<StmtPtr>({}));
for (const auto i : c10::irange(11)) {
(void)i; // Suppress unused variable warning
for (C10_UNUSED const auto i : c10::irange(11)) {
VarHandle loopVar("loopVar", kInt);
last = For::make(loopVar, 0, 10, last);
}
@ -3969,8 +3968,7 @@ TEST(Simplify, SimplifyFlattenBlock) {
{
// Flatten many layers around an empty block to an empty block.
StmtPtr last = alloc<Block>(std::vector<StmtPtr>({}));
for (const auto i : c10::irange(11)) {
(void)i; // Suppress unused variable warning
for (C10_UNUSED const auto i : c10::irange(11)) {
last = alloc<Block>(std::vector<StmtPtr>({last}));
}

View File

@ -12,8 +12,7 @@ torch::List<torch::Tensor> custom_op(
int64_t repeat) {
torch::List<torch::Tensor> output;
output.reserve(repeat);
for (const auto i : c10::irange(repeat)) {
(void)i; // Suppress unused variable warning
for (C10_UNUSED const auto i : c10::irange(repeat)) {
output.push_back(tensor * scalar);
}
return output;

View File

@ -82,8 +82,7 @@ class DataLoaderBase {
// Send one 'quit' message per worker. Since a worker dies (exits its
// thread) after receiving this message, each `QuitWorker()` message will be
// read by exactly one worker.
for (const auto w : c10::irange(options_.workers)) {
(void)w; // Suppress unused variable warning
for (C10_UNUSED const auto w : c10::irange(options_.workers)) {
push_job(QuitWorker());
}
for (auto& worker : workers_) {
@ -146,8 +145,7 @@ class DataLoaderBase {
/// Schedules `requested_jobs` many new batches to be fetched. The actual
/// number of jobs scheduled may be less if the DataLoader exhausts.
void prefetch(size_t requested_jobs) {
for (const auto r : c10::irange(requested_jobs)) {
(void)r; // Suppress unused variable
for (C10_UNUSED const auto r : c10::irange(requested_jobs)) {
if (auto batch_request = get_batch_request()) {
this->push_job(std::move(*batch_request));
} else {

View File

@ -23,8 +23,7 @@ inline std::vector<int64_t> _reverse_repeat_vector(
std::vector<int64_t> ret;
ret.reserve(t.size() * n);
for (auto rit = t.rbegin(); rit != t.rend(); ++rit) {
for (const auto i : c10::irange(n)) {
(void)i; // Suppress unused variable
for (C10_UNUSED const auto i : c10::irange(n)) {
ret.emplace_back(*rit);
}
}

View File

@ -222,8 +222,7 @@ TransformerEncoderImpl::TransformerEncoderImpl(
void TransformerEncoderImpl::reset() {
layers = this->register_module("layers", ModuleList());
for (const auto i : c10::irange(options.num_layers())) {
(void)i; // Suppress unused variable warning
for (C10_UNUSED const auto i : c10::irange(options.num_layers())) {
layers->push_back(options.encoder_layer()->clone());
}
@ -289,8 +288,7 @@ TransformerDecoderImpl::TransformerDecoderImpl(
void TransformerDecoderImpl::reset() {
layers = this->register_module("layers", ModuleList());
for (const auto i : c10::irange(options.num_layers())) {
(void)i; // Suppress unused variable warning
for (C10_UNUSED const auto i : c10::irange(options.num_layers())) {
layers->push_back(options.decoder_layer()->clone());
}

View File

@ -1693,8 +1693,7 @@ Tensor repeat_backward(
}
const auto input_dims = input_shape.size();
auto num_unsqueezed = grad.dim() - input_dims;
for (const auto i : c10::irange(num_unsqueezed)) {
(void)i; // Suppress unused variable warning
for (C10_UNUSED const auto i : c10::irange(num_unsqueezed)) {
grad = grad.sum(0, false);
}

View File

@ -44,7 +44,7 @@ struct TORCH_API NotImplemented : public Error {
// @once_differentiable
struct TORCH_API DelayedError : public Node {
DelayedError(std::string msg, int64_t num_inputs) : msg(std::move(msg)) {
for (const auto _ [[maybe_unused]] : c10::irange(num_inputs)) {
for (C10_UNUSED const auto _ [[maybe_unused]] : c10::irange(num_inputs)) {
add_input_metadata(Node::undefined_input());
}
}

View File

@ -154,8 +154,7 @@ struct CollectiveFingerPrint {
// tensor>]
std::vector<at::Tensor> outputs;
outputs.reserve(backend->getSize());
for (const auto i : c10::irange(backend->getSize())) {
std::ignore = i; // Suppress unused variable warning
for (C10_UNUSED const auto i : c10::irange(backend->getSize())) {
outputs.emplace_back(at::zeros_like(tensor_shape));
}
output_tensors.emplace_back(outputs);

View File

@ -543,7 +543,7 @@ struct slot_list_impl {
size_t size() const {
if (!size_) {
size_ = size_t(0);
for ([[maybe_unused]] const value_type& _ : *(this)) {
for (C10_UNUSED const value_type& _ : *(this)) {
++*size_;
}
}

View File

@ -292,8 +292,7 @@ SourceRange Node::sourceRange() const {
}
static std::ostream& indent(std::ostream& out, size_t level) {
for (const auto i : c10::irange(level)) {
(void)i; // Suppress unused variable warning
for (C10_UNUSED const auto i : c10::irange(level)) {
out << " ";
}
return out;
@ -1768,8 +1767,7 @@ Node* Graph::createTupleSlice(
new_vals.reserve(num_values);
int64_t i = beg;
for (const auto j : c10::irange(num_values)) {
(void)j; // Suppress unused variable warning
for (C10_UNUSED const auto j : c10::irange(num_values)) {
auto idx = insertConstant(IValue(static_cast<int64_t>(i)));
auto tupleIndex = insertNode(createTupleIndex(tup, idx, tt->elements()[i]));
@ -1817,8 +1815,7 @@ Node* Graph::createListUnpack(Value* v, size_t size) {
ListTypePtr list_type = v->type()->expect<ListType>();
TypePtr elem_type = list_type->getElementType();
auto n = create(prim::ListUnpack, {v}, 0);
for (const auto i : c10::irange(size)) {
(void)i; // Suppress unused variable warning
for (C10_UNUSED const auto i : c10::irange(size)) {
n->addOutput()->setType(elem_type);
}
return n;

View File

@ -190,8 +190,7 @@ void toList(Stack& stack) {
"Output annotation list dimension and runtime tensor dimension must match for tolist()");
// Wrap out_ty in a ListType dim times.
for (const auto i : c10::irange(dim_val)) {
(void)i; // Suppress unused variable warning
for (C10_UNUSED const auto i : c10::irange(dim_val)) {
out_ty = at::ListType::create(out_ty);
}

View File

@ -327,8 +327,8 @@ bool FoldFrozenConvMulOrDiv(Block* b) {
// channels-out resize it to the shape that will broadcast to
// weight_tensor when the op is run so we dont change weight size
std::vector<int64_t> weight_compatible_size = {out_channels};
for (const auto i : c10::irange(1, weight_tensor.ndimension())) {
(void)i; // Suppress unused variable warning
for (C10_UNUSED const auto i :
c10::irange(1, weight_tensor.ndimension())) {
weight_compatible_size.push_back(1);
}

View File

@ -829,8 +829,7 @@ struct GraphFuser {
}
bchunk->removeInput(producer_index);
for (const auto i : c10::irange(nchunks)) {
(void)i; // Suppress unused variable warning
for (C10_UNUSED const auto i : c10::irange(nchunks)) {
bchunk->eraseOutput(nchunks * producer_index);
}

View File

@ -128,8 +128,7 @@ void repeatBody(Block* body, size_t times, Block* dest) {
std::vector<Value*> io = dest->inputs().vec();
TORCH_INTERNAL_ASSERT(
!body->inputs().at(0)->hasUses(), "loop counter should be unused");
for (const auto i : c10::irange(times)) {
(void)i; // Suppress unused variable warning
for (C10_UNUSED const auto i : c10::irange(times)) {
io[0] = body->inputs().at(0);
io = insertBlockCopy(*graph, body, io);
}

View File

@ -324,23 +324,19 @@ void unpackQuantizedWeightsHelper(
const int64_t kSpatialDim = config_vals.at(0);
// skip kSpatialDim
unsigned idx = 1;
for (const auto i : c10::irange(kSpatialDim)) {
(void)i; // Suppress unused variable warning
for (C10_UNUSED const auto i : c10::irange(kSpatialDim)) {
stride_int.emplace_back(config_vals.at(idx));
idx++;
}
for (const auto i : c10::irange(kSpatialDim)) {
(void)i; // Suppress unused variable warning
for (C10_UNUSED const auto i : c10::irange(kSpatialDim)) {
padding_int.emplace_back(config_vals.at(idx));
idx++;
}
for (const auto i : c10::irange(kSpatialDim)) {
(void)i; // Suppress unused variable warning
for (C10_UNUSED const auto i : c10::irange(kSpatialDim)) {
dilation_int.emplace_back(config_vals.at(idx));
idx++;
}
for (const auto i : c10::irange(kSpatialDim)) {
(void)i; // Suppress unused variable warning
for (C10_UNUSED const auto i : c10::irange(kSpatialDim)) {
output_padding_int.emplace_back(config_vals.at(idx));
idx++;
}

View File

@ -75,8 +75,7 @@ std::string getQuantizeForScalar(const std::string& value) {
)" +
value + "_tensor : Tensor = aten::scalar_tensor(" + value + ", " + value +
"_float_scalar_type";
for (const auto i : c10::irange(3)) {
(void)i; // Suppress unused variable warning
for (C10_UNUSED const auto i : c10::irange(3)) {
quantize_pattern += ", " + value + "_none";
}
quantize_pattern += ")";

View File

@ -134,7 +134,8 @@ void initScriptListBindings(PyObject* module) {
auto seq = std::make_shared<ScriptList>(self->type());
for (const auto i [[maybe_unused]] : c10::irange(slicelength)) {
for (C10_UNUSED const auto i [[maybe_unused]] :
c10::irange(slicelength)) {
seq->append(self->getItem(static_cast<ptrdiff_t>(start)));
start += step;
}

View File

@ -311,8 +311,7 @@ void listMulIntLeftInPlace(Stack& stack) {
list.clear();
} else if (n > 1) {
size_t list_size = list.size();
for (const auto i : c10::irange(1, n)) {
(void)i; // Suppress unused variable warning
for (C10_UNUSED const auto i : c10::irange(1, n)) {
for (const auto j : c10::irange(list_size)) {
list.push_back(list.get(j));
}
@ -330,8 +329,7 @@ void listMulIntLeft(Stack& stack) {
const auto size = list.size() * n;
ret.reserve(size);
for (const auto i : c10::irange(n)) {
(void)i; // Suppress unused variable warning
for (C10_UNUSED const auto i : c10::irange(n)) {
for (IValue e : list) {
ret.push_back(std::move(e));
}
@ -348,8 +346,7 @@ void listMulIntRight(Stack& stack) {
const auto size = list.size() * n;
ret.reserve(size);
for (const auto i : c10::irange(n)) {
(void)i; // Suppress unused variable warning
for (C10_UNUSED const auto i : c10::irange(n)) {
for (IValue e : list) {
ret.push_back(std::move(e));
}
@ -382,8 +379,7 @@ void listSlice(Stack& stack) {
sliced_list.reserve(num_values);
int i = start;
for (const auto j : c10::irange(num_values)) {
(void)j; // Suppress unused variable warning
for (C10_UNUSED const auto j : c10::irange(num_values)) {
sliced_list.push_back(list.get(i));
i += step;
}

View File

@ -38,8 +38,7 @@ std::string stringSlice(
int64_t i = start_val;
std::string result = "";
for (const auto j : c10::irange(num_vals)) {
(void)j; // Suppress unused variable warning
for (C10_UNUSED const auto j : c10::irange(num_vals)) {
result += string[i];
i += step;
}

View File

@ -1581,8 +1581,7 @@ float BlockRunner::benchmark_model(
const bool is_kwargs_empty = kwargs_list.empty();
const KeywordArgs empty_kwargs;
for (const auto _n_run : c10::irange(warmup_runs)) {
(void)_n_run; // Suppress unused variable warning
for (C10_UNUSED const auto _n_run : c10::irange(warmup_runs)) {
const auto num_args = static_cast<uint32_t>(args_list.size());
for (const auto j : c10::irange(num_args)) {
operator()(args_list[j], is_kwargs_empty ? empty_kwargs : kwargs_list[j]);
@ -1592,8 +1591,7 @@ float BlockRunner::benchmark_model(
}
}
caffe2::Timer timer;
for (const auto _n_run : c10::irange(main_runs)) {
(void)_n_run; // Suppress unused variable warning
for (C10_UNUSED const auto _n_run : c10::irange(main_runs)) {
const auto num_args = static_cast<uint32_t>(args_list.size());
for (const auto j : c10::irange(num_args)) {
operator()(args_list[j], is_kwargs_empty ? empty_kwargs : kwargs_list[j]);
@ -1745,8 +1743,7 @@ BlockRunner::IndividualMetrics BlockRunner::benchmark_individual_ops(
results.first_iter_time = timer.MilliSeconds();
// warmup runs
for (const auto _n_run : c10::irange(warmup_runs)) {
(void)_n_run; // Suppress unused variable warning
for (C10_UNUSED const auto _n_run : c10::irange(warmup_runs)) {
const auto num_args = static_cast<uint32_t>(args_list.size());
for (const auto j : c10::irange(num_args)) {
operator()(args_list[j], is_kwargs_empty ? empty_kwargs : kwargs_list[j]);
@ -1757,8 +1754,7 @@ BlockRunner::IndividualMetrics BlockRunner::benchmark_individual_ops(
}
// main runs
for (const auto i : c10::irange(main_runs)) {
(void)i; // Suppress unused variable warning
for (C10_UNUSED const auto i : c10::irange(main_runs)) {
const auto num_args = static_cast<uint32_t>(args_list.size());
for (const auto j : c10::irange(num_args)) {
set_inputs(args_list[j], is_kwargs_empty ? empty_kwargs : kwargs_list[j]);

View File

@ -436,8 +436,7 @@ struct PythonPrintImpl {
size_t level = 0;
// indent to the current indent level
TaggedStringStream& indent() {
for (const auto i : c10::irange(level)) {
(void)i; // Suppress unused variable warning
for (C10_UNUSED const auto i : c10::irange(level)) {
body_ << " ";
}
return body_;
@ -1299,8 +1298,7 @@ struct PythonPrintImpl {
IValue createBroadList(dtype value, const int64_t& N) {
c10::List<dtype> repeated;
repeated.reserve(N);
for (const auto i : c10::irange(N)) {
(void)i; // Suppress unused variable warning
for (C10_UNUSED const auto i : c10::irange(N)) {
repeated.push_back(value);
}
return repeated;

View File

@ -19,8 +19,7 @@ class ThreadPool {
public:
explicit ThreadPool(size_t num_threads) {
threads_.reserve(num_threads);
for (const auto i : c10::irange(num_threads)) {
(void)i; // Suppress unused variable warning
for (C10_UNUSED const auto i : c10::irange(num_threads)) {
threads_.emplace_back([this]() {
c10::setThreadName("pt_thread_pool");
Worker();

View File

@ -53,8 +53,7 @@ static void recursive_apply(
}
auto n = sizes[dim];
for (const auto i : c10::irange(n)) {
(void)i; // Suppress unused variable warning
for (C10_UNUSED const auto i : c10::irange(n)) {
recursive_apply(sizes, scalarType, dim + 1, fn, strided_data);
for (auto& td : strided_data) {
td.step(dim);