[BE] Enforce sign-compare (#96723)

A number of OSS PRs were reverted because they introduced new signed-unsigned comparison warnings, which are treated as errors in some internal builds.
It is not clear how those selective rules are applied, but this PR removes `-Wno-sign-compare` from the PyTorch codebase.
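The warnings come from comparing a signed index or count against an unsigned `size()`. The fixes in this diff follow two recurring patterns, sketched below (illustrative only, not code from the PR):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

#include <c10/util/irange.h>  // c10::irange, used for most of the loop fixes in this PR

// Illustrative only: the two fix patterns applied throughout this diff.
void example(const std::vector<int>& v, int64_t n) {
  // Before: `int i` vs. `v.size()` (an unsigned size_t) triggers -Wsign-compare,
  // which some internal builds promote to an error:
  //   for (int i = 0; i < v.size(); ++i) { ... }

  // Fix 1: iterate with c10::irange, which deduces the index type from the bound.
  for (const auto i : c10::irange(v.size())) {
    (void)v[i];
  }

  // Fix 2: make the comparison explicit with static_cast when the signed value
  // is known to be non-negative.
  if (static_cast<size_t>(n) == v.size()) {
    // ...
  }
}
```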

The only tricky part of this PR is making sure that non-ASCII character detection works for both signed and unsigned chars, here:
6e3d51b08a/torch/csrc/jit/serialization/python_print.cpp (L926)
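For context, a minimal sketch of the signed/unsigned `char` pitfall handled there (an illustration, not the actual code at the linked line):

```cpp
#include <string>

// Illustrative helper, not the function in python_print.cpp: whether plain `char`
// is signed or unsigned is implementation-defined, so a naive `c >= 0x80` check is
// either always false (signed char) or a signed/unsigned comparison. Casting
// through unsigned char keeps the check correct and warning-free in both cases.
bool has_non_ascii(const std::string& s) {
  for (const char c : s) {
    if (static_cast<unsigned char>(c) > 0x7F) {
      return true;
    }
  }
  return false;
}
```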

Exclude several files from sign-compare checks when flash attention is used, due to a violation in cutlass that is to be fixed by https://github.com/NVIDIA/cutlass/pull/869.
Do not try to fix sign-compare violations in the caffe2 codebase; those sources keep `-Wno-sign-compare`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96723
Approved by: https://github.com/albanD
Nikita Shulga
2023-03-15 06:04:16 +00:00
committed by PyTorch MergeBot
parent 96c745dfdc
commit a229e78544
78 changed files with 194 additions and 188 deletions

View File

@ -39,7 +39,7 @@ namespace native {
c10::SmallVector<std::string> get_extra_args_typenames(const c10::SmallVector<at::Scalar>& extra_args) { c10::SmallVector<std::string> get_extra_args_typenames(const c10::SmallVector<at::Scalar>& extra_args) {
c10::SmallVector<std::string> args_typenames(extra_args.size()); c10::SmallVector<std::string> args_typenames(extra_args.size());
for (auto i = 0; i < extra_args.size(); ++i) { for (const auto i : c10::irange(extra_args.size())) {
args_typenames[i] = at::cuda::jit::typeName(extra_args[i].type()); args_typenames[i] = at::cuda::jit::typeName(extra_args[i].type());
} }
return args_typenames; return args_typenames;

View File

@ -333,7 +333,7 @@ void conv_depthwise_shape_check(
if (grad_output.defined()) { if (grad_output.defined()) {
auto expected_output_size = conv_output_size(input.sizes(), weight.sizes(), auto expected_output_size = conv_output_size(input.sizes(), weight.sizes(),
padding, stride, dilation); padding, stride, dilation);
TORCH_CHECK(grad_output.dim() == expected_output_size.size(), TORCH_CHECK(static_cast<size_t>(grad_output.dim()) == expected_output_size.size(),
"Expect grad_output to be ", "Expect grad_output to be ",
expected_output_size.size(), "D, got ", expected_output_size.size(), "D, got ",
grad_output.dim(), "D."); grad_output.dim(), "D.");

View File

@ -132,7 +132,7 @@ FOREACH_BINARY_OP_SCALARLIST(all_types_complex_half_bfloat16, pow, power_functor
// In the case of subtraction, we dont allow scalar to be boolean following the torch.sub logic // In the case of subtraction, we dont allow scalar to be boolean following the torch.sub logic
void foreach_tensor_sub_scalarlist_kernel_cuda_(TensorList tensors, at::ArrayRef<Scalar> scalars) { void foreach_tensor_sub_scalarlist_kernel_cuda_(TensorList tensors, at::ArrayRef<Scalar> scalars) {
check_foreach_api_restrictions(tensors, scalars); check_foreach_api_restrictions(tensors, scalars);
for (int i = 0; i < tensors.size(); i++) { for (const auto i: c10::irange(tensors.size())) {
sub_check(tensors[i], scalars[i]); sub_check(tensors[i], scalars[i]);
} }
@ -147,7 +147,7 @@ void foreach_tensor_sub_scalarlist_kernel_cuda_(TensorList tensors, at::ArrayRef
std::vector<Tensor> foreach_tensor_sub_scalarlist_kernel_cuda(TensorList tensors, at::ArrayRef<Scalar> scalars) { std::vector<Tensor> foreach_tensor_sub_scalarlist_kernel_cuda(TensorList tensors, at::ArrayRef<Scalar> scalars) {
check_foreach_api_restrictions(tensors, scalars); check_foreach_api_restrictions(tensors, scalars);
for (int i = 0; i < tensors.size(); i++) { for (const auto i: c10::irange(tensors.size())) {
sub_check(tensors[i], scalars[i]); sub_check(tensors[i], scalars[i]);
} }

View File

@ -53,7 +53,7 @@ static void launch_kernel(int64_t N, const func_t& f) {
template <typename func_t> template <typename func_t>
void gpu_index_kernel(TensorIteratorBase& iter, IntArrayRef index_size, IntArrayRef index_stride, const func_t& f) { void gpu_index_kernel(TensorIteratorBase& iter, IntArrayRef index_size, IntArrayRef index_stride, const func_t& f) {
int num_indices = index_size.size(); int num_indices = index_size.size();
AT_ASSERT(num_indices == index_stride.size()); AT_ASSERT(static_cast<size_t>(num_indices) == index_stride.size());
AT_ASSERT(num_indices == iter.ntensors() - 2); AT_ASSERT(num_indices == iter.ntensors() - 2);
if (iter.numel() == 0) { if (iter.numel() == 0) {

View File

@ -226,8 +226,8 @@ std::tuple<Tensor, Tensor> ctc_loss_gpu_template(const Tensor& log_probs, const
int64_t batch_size = log_probs.size(1); int64_t batch_size = log_probs.size(1);
int64_t num_labels = log_probs.size(2); int64_t num_labels = log_probs.size(2);
TORCH_CHECK((0 <= BLANK) && (BLANK < num_labels), "blank must be in label range"); TORCH_CHECK((0 <= BLANK) && (BLANK < num_labels), "blank must be in label range");
TORCH_CHECK(input_lengths.size() == batch_size, "input_lengths must be of size batch_size"); TORCH_CHECK(input_lengths.size() == static_cast<size_t>(batch_size), "input_lengths must be of size batch_size");
TORCH_CHECK(target_lengths.size() == batch_size, "target_lengths must be of size batch_size"); TORCH_CHECK(target_lengths.size() == static_cast<size_t>(batch_size), "target_lengths must be of size batch_size");
int64_t tg_target_stride; int64_t tg_target_stride;

View File

@ -174,7 +174,7 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
// Now we loop // Now we loop
int batchCounter = 0; int batchCounter = 0;
int64_t offset = 0; int64_t offset = 0;
for (int i = 0; i < inputs.size() ; i += batch_size) { for (unsigned i = 0; i < inputs.size() ; i += batch_size) {
for (batchCounter = 0; for (batchCounter = 0;
batchCounter < batch_size && batchCounter < batch_size &&
(i+batchCounter) < inputs.size(); (i+batchCounter) < inputs.size();

View File

@ -44,7 +44,7 @@ struct HermitianSymmetryOffsetCalculator {
} }
mirror_dim_ = 0; mirror_dim_ = 0;
for (int64_t i = 0; i < dim.size(); ++i) { for (const auto i: c10::irange(dim.size())) {
mirror_dim_ |= (uint32_t{1} << dim[i]); mirror_dim_ |= (uint32_t{1} << dim[i]);
} }
} }

View File

@ -258,7 +258,7 @@ bool CUDA_tensor_histogram(
memType = CUDAHistogramMemoryType::SHARED; memType = CUDAHistogramMemoryType::SHARED;
} else if ( } else if (
nbins < THRESH_NUMBER_BINS_FOR_GLOBAL_MEM && nbins < THRESH_NUMBER_BINS_FOR_GLOBAL_MEM &&
multiBlockMem < (maxGlobalMem / 2)) { multiBlockMem < static_cast<size_t>(maxGlobalMem / 2)) {
// check against half of free mem to be extra safe // check against half of free mem to be extra safe
// due to cached allocator, we may anyway have slightly more free mem // due to cached allocator, we may anyway have slightly more free mem
memType = CUDAHistogramMemoryType::MULTI_BLOCK; memType = CUDAHistogramMemoryType::MULTI_BLOCK;

View File

@ -141,7 +141,7 @@ void calculate_mode(
// to calculate the mode for --> we do this by manually doing the stride // to calculate the mode for --> we do this by manually doing the stride
// calculations to get an offset // calculations to get an offset
scalar_t* data = self.data_ptr<scalar_t>(); scalar_t* data = self.data_ptr<scalar_t>();
for (int64_t i = 0; i < position.size(); i++) { for (int64_t i = 0; i < static_cast<int64_t>(position.size()); i++) {
data += position[i] * ensure_nonempty_stride(self, i); data += position[i] * ensure_nonempty_stride(self, i);
} }
@ -159,7 +159,7 @@ void calculate_mode(
scalar_t* values_data = values.data_ptr<scalar_t>(); scalar_t* values_data = values.data_ptr<scalar_t>();
int64_t* indices_data = indices.data_ptr<int64_t>(); int64_t* indices_data = indices.data_ptr<int64_t>();
for (int64_t i = 0; i < position.size(); i++) { for (int64_t i = 0; i < static_cast<int64_t>(position.size()); i++) {
int64_t pos = position[i]; int64_t pos = position[i];
values_data += ensure_nonempty_stride(values, i) * pos; values_data += ensure_nonempty_stride(values, i) * pos;
indices_data += ensure_nonempty_stride(indices, i) * pos; indices_data += ensure_nonempty_stride(indices, i) * pos;

View File

@ -796,7 +796,7 @@ void LayerNormKernelImplInternal(
constexpr int num_vec_elems = vec_size; constexpr int num_vec_elems = vec_size;
constexpr int alignment = num_vec_elems * sizeof(T); constexpr int alignment = num_vec_elems * sizeof(T);
if ((std::is_same<T, float>::value || std::is_same<T, at::Half>::value || std::is_same<T, at::BFloat16>::value) && if ((std::is_same<T, float>::value || std::is_same<T, at::Half>::value || std::is_same<T, at::BFloat16>::value) &&
N <= 1ULL << std::numeric_limits<float>::digits && N % num_vec_elems == 0 && N <= static_cast<int64_t>(1ULL << std::numeric_limits<float>::digits) && N % num_vec_elems == 0 &&
can_vectorize(X_data, alignment) && can_vectorize(Y_data, alignment)) { can_vectorize(X_data, alignment) && can_vectorize(Y_data, alignment)) {
launch_vectorized_layer_norm_kernel(static_cast<int>(N), M, eps, X_data, gamma_data, beta_data, Y_data, mean_data, rstd_data); launch_vectorized_layer_norm_kernel(static_cast<int>(N), M, eps, X_data, gamma_data, beta_data, Y_data, mean_data, rstd_data);
} else { } else {
@ -1356,10 +1356,10 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_cuda(
const size_t axis = input.dim() - normalized_shape.size(); const size_t axis = input.dim() - normalized_shape.size();
std::vector<int64_t> stat_shape; std::vector<int64_t> stat_shape;
for (size_t idx = 0; idx < axis; ++idx) { for (const auto idx: c10::irange(axis)) {
stat_shape.push_back(input_shape[idx]); stat_shape.push_back(input_shape[idx]);
} }
for (size_t idx = axis; idx < input.dim(); ++idx) { for (const auto C10_UNUSED idx: c10::irange(axis, input.dim())) {
stat_shape.push_back(1); stat_shape.push_back(1);
} }

View File

@ -373,7 +373,7 @@ void generate_and_filter_plans(const cudnnHandle_t handle, cudnn_frontend::Opera
if (remove_invalid) { if (remove_invalid) {
cudnn_frontend::executionPlans_t new_valid_plans; cudnn_frontend::executionPlans_t new_valid_plans;
for (auto &plan : valid_plans) { for (auto &plan : valid_plans) {
if (plan.getWorkspaceSize() <= max_workspace_size) { if (static_cast<size_t>(plan.getWorkspaceSize()) <= max_workspace_size) {
new_valid_plans.emplace_back(std::move(plan)); new_valid_plans.emplace_back(std::move(plan));
} }
} }

View File

@ -35,7 +35,7 @@ size_t compute_strided_size(const at::Tensor& t) {
} }
bool is_strided_contiguous(const at::Tensor& t) { bool is_strided_contiguous(const at::Tensor& t) {
return compute_strided_size(t) == t.numel(); return compute_strided_size(t) == static_cast<size_t>(t.numel());
} }
// Copy sourceBuffer into destBuffer, casting sourceBuffer to src.scalar_type(). // Copy sourceBuffer into destBuffer, casting sourceBuffer to src.scalar_type().

View File

@ -156,11 +156,11 @@ static void validateInputData(const TensorIteratorBase& iter,
bool accumulate) { bool accumulate) {
using namespace mps; using namespace mps;
int64_t num_indices = index_size.size(); const auto num_indices = index_size.size();
TORCH_CHECK(num_indices <= 16, "Current limit allows up to 16 indices to be used in MPS indexing kernels"); TORCH_CHECK(num_indices <= 16, "Current limit allows up to 16 indices to be used in MPS indexing kernels");
AT_ASSERT(num_indices == index_stride.size()); AT_ASSERT(num_indices == index_stride.size());
AT_ASSERT(num_indices == iter.ntensors() - 2); AT_ASSERT(static_cast<int>(num_indices) == iter.ntensors() - 2);
const Tensor& inputTensor = iter.tensor(1); const Tensor& inputTensor = iter.tensor(1);
if (accumulate) { if (accumulate) {
@ -589,8 +589,8 @@ Tensor index_select_mps(const Tensor& self, int64_t dim, const Tensor& index) {
std::vector<int64_t> shape_data(num_input_dims); std::vector<int64_t> shape_data(num_input_dims);
// Calculate new shape // Calculate new shape
for (auto i : c10::irange(num_input_dims)) { for (const auto i : c10::irange(num_input_dims)) {
if (i == dim) { if (i == static_cast<decltype(i)>(dim)) {
shape_data[i] = num_indices; shape_data[i] = num_indices;
} else { } else {
shape_data[i] = input_shape[i]; shape_data[i] = input_shape[i];

View File

@ -1000,21 +1000,21 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_backward_mps(const Tensor& grad_ou
NSMutableArray<NSNumber*>* gamma_axes = [NSMutableArray<NSNumber*> arrayWithCapacity:num_channel_dims]; NSMutableArray<NSNumber*>* gamma_axes = [NSMutableArray<NSNumber*> arrayWithCapacity:num_channel_dims];
for (int i = 0; i < num_channel_dims; i++) for (const auto i : c10::irange(num_channel_dims))
gamma_axes[i] = [NSNumber numberWithInt:i]; gamma_axes[i] = [NSNumber numberWithInt:static_cast<int>(i)];
// Axes along which to reduce to get "batch norm" gradient // Axes along which to reduce to get "batch norm" gradient
// This will be applied on shape [1, M, -1] // This will be applied on shape [1, M, -1]
NSMutableArray<NSNumber*>* bn_axes = [NSMutableArray<NSNumber*> arrayWithCapacity:num_normalized_dims]; NSMutableArray<NSNumber*>* bn_axes = [NSMutableArray<NSNumber*> arrayWithCapacity:num_normalized_dims];
for (int i = 0; i < num_normalized_dims; i++) for (const auto i : c10::irange(num_normalized_dims))
bn_axes[i] = [NSNumber numberWithInt:(1 + 1 + i)]; bn_axes[i] = [NSNumber numberWithInt:static_cast<int>(1 + 1 + i)];
// Shape of input to do "batch norm" backward // Shape of input to do "batch norm" backward
// This is [1, M, -1] // This is [1, M, -1]
NSMutableArray<NSNumber*>* bn_shape = [NSMutableArray<NSNumber*> arrayWithCapacity:(num_normalized_dims + 2)]; NSMutableArray<NSNumber*>* bn_shape = [NSMutableArray<NSNumber*> arrayWithCapacity:(num_normalized_dims + 2)];
bn_shape[0] = [NSNumber numberWithInt:1]; bn_shape[0] = [NSNumber numberWithInt:1];
bn_shape[1] = [NSNumber numberWithInt:M]; bn_shape[1] = [NSNumber numberWithInt:M];
for (int i = 0; i < num_normalized_dims; i++) for (const auto i : c10::irange(num_normalized_dims))
bn_shape[i + 2] = input_shape[i + num_channel_dims]; bn_shape[i + 2] = input_shape[i + num_channel_dims];
// Shape of mean to do "batch norm" backward // Shape of mean to do "batch norm" backward
@ -1023,7 +1023,7 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_backward_mps(const Tensor& grad_ou
[NSMutableArray<NSNumber*> arrayWithCapacity:(num_normalized_dims + 2)]; [NSMutableArray<NSNumber*> arrayWithCapacity:(num_normalized_dims + 2)];
bn_mean_shape[0] = [NSNumber numberWithInt:1]; bn_mean_shape[0] = [NSNumber numberWithInt:1];
bn_mean_shape[1] = [NSNumber numberWithInt:M]; bn_mean_shape[1] = [NSNumber numberWithInt:M];
for (int i = 0; i < num_normalized_dims; i++) for (const auto i : c10::irange(num_normalized_dims))
bn_mean_shape[i + 2] = [NSNumber numberWithInt:1]; bn_mean_shape[i + 2] = [NSNumber numberWithInt:1];
// Shape of gamma to multiply with "batch norm" backward // Shape of gamma to multiply with "batch norm" backward
@ -1032,7 +1032,7 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_backward_mps(const Tensor& grad_ou
[NSMutableArray<NSNumber*> arrayWithCapacity:(num_normalized_dims + 2)]; [NSMutableArray<NSNumber*> arrayWithCapacity:(num_normalized_dims + 2)];
bn_gamma_shape[0] = [NSNumber numberWithInt:1]; bn_gamma_shape[0] = [NSNumber numberWithInt:1];
bn_gamma_shape[1] = [NSNumber numberWithInt:1]; bn_gamma_shape[1] = [NSNumber numberWithInt:1];
for (int i = 0; i < num_normalized_dims; i++) for (const auto i : c10::irange(num_normalized_dims))
bn_gamma_shape[i + 2] = input_shape[i + num_channel_dims]; bn_gamma_shape[i + 2] = input_shape[i + num_channel_dims];
string key = "layer_norm_backward_mps:" + std::to_string(has_weight) + ":" + string key = "layer_norm_backward_mps:" + std::to_string(has_weight) + ":" +

View File

@ -136,8 +136,9 @@ void reduction_out_mps(const Tensor& input_t,
IntArrayRef dim = opt_dim.value(); IntArrayRef dim = opt_dim.value();
for (const auto dim_val : dim) { for (const auto dim_val : dim) {
auto wrap_dim = maybe_wrap_dim(dim_val, input_shape.size()); auto wrap_dim = maybe_wrap_dim(dim_val, input_shape.size());
TORCH_CHECK(wrap_dim < (input_shape.size() == 0 ? input_t.numel() : input_shape.size()), TORCH_CHECK(
func_name + ": reduction dim must be in the range of input shape") wrap_dim < static_cast<decltype(wrap_dim)>(input_shape.size() == 0 ? input_t.numel() : input_shape.size()),
func_name + ": reduction dim must be in the range of input shape")
} }
} }
@ -395,7 +396,8 @@ void impl_func_norm_mps(const Tensor& input_tensor,
for (const auto dim_val : dim) { for (const auto dim_val : dim) {
auto wrap_dim = maybe_wrap_dim(dim_val, input_shape.size()); auto wrap_dim = maybe_wrap_dim(dim_val, input_shape.size());
TORCH_CHECK(wrap_dim < input_shape.size(), "norm_out_mps: reduction dim must be in the range of input shape") TORCH_CHECK(wrap_dim < static_cast<decltype(wrap_dim)>(input_shape.size()),
"norm_out_mps: reduction dim must be in the range of input shape")
} }
auto cache_ = MPSGraphCache::getInstance(); auto cache_ = MPSGraphCache::getInstance();
@ -663,8 +665,8 @@ Tensor std_var_common_impl_mps(const Tensor& input_t,
string errMessage = (stdVarType == STANDARD_DEVIATION) ? "std_mps" : "var_mps"; string errMessage = (stdVarType == STANDARD_DEVIATION) ? "std_mps" : "var_mps";
errMessage += ": reduction dim must be in the range of input shape"; errMessage += ": reduction dim must be in the range of input shape";
for (const auto dim : dim_value) { for (const auto dim : dim_value) {
auto wrap_dim = maybe_wrap_dim(dim, input_shape.size()); auto wrap_dim = maybe_wrap_dim(dim, num_input_dims);
TORCH_CHECK(wrap_dim < input_shape.size(), errMessage.c_str()) TORCH_CHECK(wrap_dim < static_cast<decltype(wrap_dim)>(input_shape.size()), errMessage.c_str())
} }
} }

View File

@ -207,7 +207,7 @@ void computeRepeatIndices(index_t* repeat_ptr,
[computeEncoder setBytes:&size length:sizeof(size) atIndex:3]; [computeEncoder setBytes:&size length:sizeof(size) atIndex:3];
MTLSize gridSize = MTLSizeMake(size, 1, 1); MTLSize gridSize = MTLSizeMake(size, 1, 1);
NSUInteger threadsPerThreadgroup_ = pipelineState.maxTotalThreadsPerThreadgroup; NSUInteger threadsPerThreadgroup_ = pipelineState.maxTotalThreadsPerThreadgroup;
if (threadsPerThreadgroup_ > size) { if (threadsPerThreadgroup_ > static_cast<NSUInteger>(size)) {
threadsPerThreadgroup_ = size; threadsPerThreadgroup_ = size;
} }
MTLSize threadsPerThreadgroup = MTLSizeMake(threadsPerThreadgroup_, 1, 1); MTLSize threadsPerThreadgroup = MTLSizeMake(threadsPerThreadgroup_, 1, 1);

View File

@ -17,7 +17,7 @@ namespace at::native {
std::vector<long long> getTensorShape(MPSGraphTensor* mpsTensor) { std::vector<long long> getTensorShape(MPSGraphTensor* mpsTensor) {
std::vector<long long> output_dimensions = {}; std::vector<long long> output_dimensions = {};
auto dims = mpsTensor.shape; auto dims = mpsTensor.shape;
for (int i = 0; i < [dims count]; i++) { for (NSUInteger i = 0; i < [dims count]; i++) {
output_dimensions.push_back([dims[i] intValue]); output_dimensions.push_back([dims[i] intValue]);
} }
return output_dimensions; return output_dimensions;

View File

@ -97,7 +97,7 @@ std::array<MPSGraphTensor*, 4> buildUniqueGraph(const Tensor& self,
if (dimOpt.has_value() && [shape count] != 1) { if (dimOpt.has_value() && [shape count] != 1) {
NSMutableArray* axes = [[NSMutableArray alloc] initWithCapacity:[shape count] - 1]; NSMutableArray* axes = [[NSMutableArray alloc] initWithCapacity:[shape count] - 1];
for (const auto axis : c10::irange([shape count])) { for (const auto axis : c10::irange([shape count])) {
if (axis != dim) { if (static_cast<decltype(dim)>(axis) != dim) {
[axes addObject:[NSNumber numberWithUnsignedInteger:axis]]; [axes addObject:[NSNumber numberWithUnsignedInteger:axis]];
} }
} }

View File

@ -70,7 +70,7 @@ static Tensor& runViewGraph(ViewCachedGraph* cachedGraph, const at::Tensor& src,
feeds[cachedGraph->storageOffsetTensor] = getMPSGraphTensorFromScalar(stream, storageOffsetScalar); feeds[cachedGraph->storageOffsetTensor] = getMPSGraphTensorFromScalar(stream, storageOffsetScalar);
std::vector<MPSScalar> strideScalars(sizes.size()); std::vector<MPSScalar> strideScalars(sizes.size());
for (int i = 0; i < sizes.size(); i++) { for (const auto i : c10::irange(sizes.size())) {
strideScalars[i] = getMPSScalar(strides[i], ScalarType::Int); strideScalars[i] = getMPSScalar(strides[i], ScalarType::Int);
feeds[cachedGraph->strideTensors[i]] = getMPSGraphTensorFromScalar(stream, strideScalars[i]); feeds[cachedGraph->strideTensors[i]] = getMPSGraphTensorFromScalar(stream, strideScalars[i]);
} }
@ -133,7 +133,7 @@ NSDictionary* getStrideToDimLengthOffsetDict(MPSGraphTensor* tensor, NSUInteger
// Detect only expand dims, allows for duplicate strides // Detect only expand dims, allows for duplicate strides
MPSGraphTensor* asStridedLayer_expandDimsPattern(MPSGraph* graph, MPSGraphTensor* asStridedLayer_expandDimsPattern(MPSGraph* graph,
MPSGraphTensor* inputTensor, MPSGraphTensor* inputTensor,
int dstRank, size_t dstRank,
const IntArrayRef& dstSizes, const IntArrayRef& dstSizes,
const IntArrayRef& dstStrides, const IntArrayRef& dstStrides,
int offset) { int offset) {
@ -185,7 +185,7 @@ MPSGraphTensor* asStridedLayer_expandDimsPattern(MPSGraph* graph,
// Detect contiguous reshapes, no slicing // Detect contiguous reshapes, no slicing
MPSGraphTensor* asStridedLayer_reshapePattern(MPSGraph* graph, MPSGraphTensor* asStridedLayer_reshapePattern(MPSGraph* graph,
MPSGraphTensor* inputTensor, MPSGraphTensor* inputTensor,
int dstRank, size_t dstRank,
const IntArrayRef& dstSizes, const IntArrayRef& dstSizes,
const IntArrayRef& dstStrides, const IntArrayRef& dstStrides,
int offset) { int offset) {
@ -228,7 +228,7 @@ MPSGraphTensor* asStridedLayer_reshapePattern(MPSGraph* graph,
MPSGraphTensor* asStridedLayer_genericPattern(MPSGraph* graph, MPSGraphTensor* asStridedLayer_genericPattern(MPSGraph* graph,
MPSGraphTensor* inputTensor, MPSGraphTensor* inputTensor,
int dstRank, size_t dstRank,
const IntArrayRef& dstSizes, const IntArrayRef& dstSizes,
const IntArrayRef& dstStrides, const IntArrayRef& dstStrides,
int offset) { int offset) {
@ -236,7 +236,7 @@ MPSGraphTensor* asStridedLayer_genericPattern(MPSGraph* graph,
{ {
BOOL allUnique = YES; BOOL allUnique = YES;
NSMutableSet* uniqueStrides = [[NSMutableSet alloc] init]; NSMutableSet* uniqueStrides = [[NSMutableSet alloc] init];
for (NSInteger dstDim = 0; (dstDim < dstRank) && allUnique; dstDim++) { for (NSUInteger dstDim = 0; (dstDim < dstRank) && allUnique; dstDim++) {
int stride = dstStrides[dstDim]; int stride = dstStrides[dstDim];
NSNumber* strideObj = [NSNumber numberWithInt:stride]; NSNumber* strideObj = [NSNumber numberWithInt:stride];
allUnique &= (stride == 0 || ![uniqueStrides containsObject:strideObj]); allUnique &= (stride == 0 || ![uniqueStrides containsObject:strideObj]);
@ -247,7 +247,7 @@ MPSGraphTensor* asStridedLayer_genericPattern(MPSGraph* graph,
return nil; return nil;
// Skip for zero in dst shape // Skip for zero in dst shape
for (NSInteger dstDim = 0; dstDim < dstRank; dstDim++) for (NSUInteger dstDim = 0; dstDim < dstRank; dstDim++)
if (dstSizes[dstDim] == 0) { if (dstSizes[dstDim] == 0) {
return nil; return nil;
} }
@ -277,7 +277,7 @@ MPSGraphTensor* asStridedLayer_genericPattern(MPSGraph* graph,
std::vector<int32_t> dstDimToSliceOffset(dstRank); std::vector<int32_t> dstDimToSliceOffset(dstRank);
bool needsBroadcast = false; bool needsBroadcast = false;
{ {
for (NSInteger dstDim = dstRank - 1; dstDim >= 0; dstDim--) { for (auto dstDim = dstRank - 1; dstDim >= 0; dstDim--) {
if (dstStrides[dstDim] == 0) { if (dstStrides[dstDim] == 0) {
// This dimension should be a broadcast // This dimension should be a broadcast
needsBroadcast = true; needsBroadcast = true;
@ -318,7 +318,7 @@ MPSGraphTensor* asStridedLayer_genericPattern(MPSGraph* graph,
[missingSrcStrides addObject:[NSNumber numberWithInteger:stride]]; [missingSrcStrides addObject:[NSNumber numberWithInteger:stride]];
stride *= [[flatInputTensor shape][srcDim] integerValue]; stride *= [[flatInputTensor shape][srcDim] integerValue];
} }
for (NSInteger dstDim = 0; dstDim < dstRank; dstDim++) { for (NSUInteger dstDim = 0; dstDim < dstRank; dstDim++) {
[missingSrcStrides removeObject:[NSNumber numberWithInteger:dstStrides[dstDim]]]; [missingSrcStrides removeObject:[NSNumber numberWithInteger:dstStrides[dstDim]]];
} }
} }
@ -344,7 +344,7 @@ MPSGraphTensor* asStridedLayer_genericPattern(MPSGraph* graph,
// TODO: Use Transpose API // TODO: Use Transpose API
BOOL needsTranspose = NO; BOOL needsTranspose = NO;
for (NSUInteger dstDim = 0; dstDim < [dstDimOrder count] && !needsTranspose; dstDim++) for (NSUInteger dstDim = 0; dstDim < [dstDimOrder count] && !needsTranspose; dstDim++)
needsTranspose |= ([dstDimOrder[dstDim] intValue] != dstDim); needsTranspose |= ([dstDimOrder[dstDim] intValue] != static_cast<int>(dstDim));
if (needsTranspose) if (needsTranspose)
transposedTensor = permuteTensor(graph, transposedTensor, dstDimOrder); transposedTensor = permuteTensor(graph, transposedTensor, dstDimOrder);
} }
@ -385,7 +385,7 @@ MPSGraphTensor* asStridedLayer_genericPattern(MPSGraph* graph,
if (needsBroadcast) { if (needsBroadcast) {
NSMutableArray* broadcastShape = [[NSMutableArray alloc] init]; NSMutableArray* broadcastShape = [[NSMutableArray alloc] init];
NSMutableArray* expandAxes = [[NSMutableArray alloc] init]; NSMutableArray* expandAxes = [[NSMutableArray alloc] init];
for (NSInteger dstDim = 0; dstDim < dstRank; dstDim++) { for (NSUInteger dstDim = 0; dstDim < dstRank; dstDim++) {
[broadcastShape addObject:[NSNumber numberWithInt:dstSizes[dstDim]]]; [broadcastShape addObject:[NSNumber numberWithInt:dstSizes[dstDim]]];
if (dstStrides[dstDim] == 0) if (dstStrides[dstDim] == 0)
[expandAxes addObject:[NSNumber numberWithInt:dstDim]]; [expandAxes addObject:[NSNumber numberWithInt:dstDim]];
@ -408,7 +408,7 @@ MPSGraphTensor* asStridedLayer_genericPattern(MPSGraph* graph,
MPSGraphTensor* asStridedLayer_pattern(MPSGraph* graph, MPSGraphTensor* asStridedLayer_pattern(MPSGraph* graph,
MPSGraphTensor* inputTensor, MPSGraphTensor* inputTensor,
int dstRank, size_t dstRank,
const IntArrayRef& dstSizes, const IntArrayRef& dstSizes,
const IntArrayRef& dstStrides, const IntArrayRef& dstStrides,
int offset) { int offset) {
@ -503,7 +503,7 @@ MPSGraphTensorData* getMPSGraphTensorDataForView(const Tensor& src, MPSShape* mp
MPSNDArrayDescriptor* srcTensorNDArrayDesc = nil; MPSNDArrayDescriptor* srcTensorNDArrayDesc = nil;
MPSNDArray* srcTensorNDArray = nil; MPSNDArray* srcTensorNDArray = nil;
id<MTLCommandBuffer> commandBuffer = getCurrentMPSStream()->commandBuffer(); id<MTLCommandBuffer> commandBuffer = getCurrentMPSStream()->commandBuffer();
int64_t base_idx = 0; size_t base_idx = 0;
std::vector<int64_t> src_base_shape_vec; std::vector<int64_t> src_base_shape_vec;
@ -574,7 +574,7 @@ static MPSGraphTensor* chainViewOperation(ViewCachedGraph* cachedGraph,
@autoreleasepool { @autoreleasepool {
std::vector<int32_t> sizeArray(shape_size); std::vector<int32_t> sizeArray(shape_size);
const int64_t int_max = std::numeric_limits<int32_t>::max(); const int64_t int_max = std::numeric_limits<int32_t>::max();
for (int i = 0; i < shape_size; i++) { for (const auto i : c10::irange(shape_size)) {
TORCH_CHECK(size[i] <= int_max); TORCH_CHECK(size[i] <= int_max);
sizeArray[i] = static_cast<int32_t>(size[i]); sizeArray[i] = static_cast<int32_t>(size[i]);
} }
@ -584,7 +584,7 @@ static MPSGraphTensor* chainViewOperation(ViewCachedGraph* cachedGraph,
dataType:MPSDataTypeInt32]; dataType:MPSDataTypeInt32];
MPSGraphTensor* indicesTensor = nil; MPSGraphTensor* indicesTensor = nil;
// create stride Tensors for each rank of the input tensor // create stride Tensors for each rank of the input tensor
for (int i = 0; i < shape_size; i++) { for (int i = 0; i < static_cast<int>(shape_size); i++) {
MPSGraphTensor* rangeTensor = [mpsGraph coordinateAlongAxis:(-i - 1) withShapeTensor:shapeTensor name:nil]; MPSGraphTensor* rangeTensor = [mpsGraph coordinateAlongAxis:(-i - 1) withShapeTensor:shapeTensor name:nil];
MPSGraphTensor* strideTensor = cachedGraph->strideTensors[shape_size - i - 1]; MPSGraphTensor* strideTensor = cachedGraph->strideTensors[shape_size - i - 1];
MPSGraphTensor* indexTensor = [mpsGraph multiplicationWithPrimaryTensor:rangeTensor MPSGraphTensor* indexTensor = [mpsGraph multiplicationWithPrimaryTensor:rangeTensor
@ -702,7 +702,7 @@ static ViewCachedGraph* createViewGraph(const Tensor& self,
// Self is the input tensor we are creating view of // Self is the input tensor we are creating view of
newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, inputType, getMPSShape(base_shape)); newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, inputType, getMPSShape(base_shape));
newCachedGraph->storageOffsetTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[ @1 ]); newCachedGraph->storageOffsetTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[ @1 ]);
for (int i = 0; i < size.size(); i++) { for (const auto C10_UNUSED i : c10::irange(size.size())) {
newCachedGraph->strideTensors.push_back(mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[ @1 ])); newCachedGraph->strideTensors.push_back(mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[ @1 ]));
} }
if (needsScatter) { if (needsScatter) {
@ -837,7 +837,7 @@ Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst) {
if (kernel_size == 0) { if (kernel_size == 0) {
src_sizes[0] = src_strides[0] = 1; src_sizes[0] = src_strides[0] = 1;
} else { } else {
for (int i = 0; i < kernel_size; i++) { for (const auto i : c10::irange(kernel_size)) {
src_sizes[i] = (uint32_t)(src.sizes()[i]); src_sizes[i] = (uint32_t)(src.sizes()[i]);
src_strides[i] = (uint32_t)(src.strides()[i]); src_strides[i] = (uint32_t)(src.strides()[i]);
} }

View File

@ -656,6 +656,15 @@ if(USE_CUDA)
set_source_files_properties(${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/interface.cpp PROPERTIES COMPILE_FLAGS "-DUSE_CUDA=1") set_source_files_properties(${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/interface.cpp PROPERTIES COMPILE_FLAGS "-DUSE_CUDA=1")
endif() endif()
if(USE_FLASH_ATTENTION AND NOT MSVC)
# Cutlass contains a sign-compare violation in its codebase to be fixed by https://github.com/NVIDIA/cutlass/pull/869
set_source_files_properties(
"${PROJECT_SOURCE_DIR}/aten/src/ATen/native/nested/cuda/NestedTensorMatmul.cu"
"${PROJECT_SOURCE_DIR}/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cu"
"${PROJECT_SOURCE_DIR}/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_bwd_hdim128.cu"
PROPERTIES COMPILE_FLAGS "-Wno-sign-compare")
endif()
if(BUILD_ONEDNN_GRAPH) if(BUILD_ONEDNN_GRAPH)
list(APPEND Caffe2_CPU_SRCS list(APPEND Caffe2_CPU_SRCS
${TORCH_SRC_DIR}/csrc/jit/codegen/onednn/LlgaTensorImpl.cpp ${TORCH_SRC_DIR}/csrc/jit/codegen/onednn/LlgaTensorImpl.cpp
@ -811,9 +820,10 @@ if(HAVE_SOVERSION)
VERSION ${TORCH_VERSION} SOVERSION ${TORCH_SOVERSION}) VERSION ${TORCH_VERSION} SOVERSION ${TORCH_SOVERSION})
endif() endif()
torch_compile_options(torch_cpu) # see cmake/public/utils.cmake torch_compile_options(torch_cpu) # see cmake/public/utils.cmake
if(HAS_WERROR_SIGN_COMPARE AND WERROR) if(BUILD_CAFFE2 AND NOT MSVC)
# target_compile_options(torch_cpu PRIVATE "-Werror=sign-compare") # Caffe2 has too many signed-unsigned violation, but the framework is dead
set_property(SOURCE ${ATen_CORE_SRCS} ${ATen_CPU_SRCS} APPEND PROPERTY COMPILE_OPTIONS "-Werror=sign-compare") # So no point in fixing those
target_compile_options(torch_cpu PRIVATE "-Wno-sign-compare")
endif() endif()
set_property(SOURCE ${ATen_CORE_SRCS} APPEND set_property(SOURCE ${ATen_CORE_SRCS} APPEND

View File

@ -17,8 +17,8 @@ void BoxCoxNaive(
T* output_ptr) { T* output_ptr) {
constexpr T k_eps = static_cast<T>(1e-6); constexpr T k_eps = static_cast<T>(1e-6);
for (int64_t i = 0; i < N; i++) { for (std::size_t i = 0; i < N; i++) {
for (int64_t j = 0; j < D; j++, data_ptr++, output_ptr++) { for (std::size_t j = 0; j < D; j++, data_ptr++, output_ptr++) {
T lambda1_v = lambda1_ptr[j]; T lambda1_v = lambda1_ptr[j];
T lambda2_v = lambda2_ptr[j]; T lambda2_v = lambda2_ptr[j];
T tmp = std::max(*data_ptr + lambda2_v, k_eps); T tmp = std::max(*data_ptr + lambda2_v, k_eps);

View File

@ -36,8 +36,7 @@ void FloatToFused8BitRowwiseQuantized__base(
output_row_scale_bias[0] = range / 255.0f; output_row_scale_bias[0] = range / 255.0f;
output_row_scale_bias[1] = minimum_element; output_row_scale_bias[1] = minimum_element;
const auto inverse_scale = 255.0f / (range + kEpsilon); const auto inverse_scale = 255.0f / (range + kEpsilon);
// NOLINTNEXTLINE(clang-diagnostic-sign-compare) for (std::size_t col = 0; col < static_cast<size_t>(input_columns); ++col) {
for (std::size_t col = 0; col < input_columns; ++col) {
output_row[col] = output_row[col] =
std::lrintf((input_row[col] - minimum_element) * inverse_scale); std::lrintf((input_row[col] - minimum_element) * inverse_scale);
} }
@ -58,8 +57,7 @@ void Fused8BitRowwiseQuantizedToFloat__base(
reinterpret_cast<const float*>(input_row + output_columns); reinterpret_cast<const float*>(input_row + output_columns);
float* output_row = output + row * output_columns; float* output_row = output + row * output_columns;
// NOLINTNEXTLINE(clang-diagnostic-sign-compare) for (std::size_t col = 0; col < static_cast<std::size_t>(output_columns); ++col) {
for (std::size_t col = 0; col < output_columns; ++col) {
output_row[col] = output_row[col] =
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
input_row[col] * input_row_scale_bias[0] + input_row_scale_bias[1]; input_row[col] * input_row_scale_bias[0] + input_row_scale_bias[1];
@ -136,8 +134,7 @@ void FloatToFusedNBitRowwiseQuantizedSBHalf__base(
output_row_scale_bias[0] = scale; output_row_scale_bias[0] = scale;
output_row_scale_bias[1] = minimum_element; output_row_scale_bias[1] = minimum_element;
// NOLINTNEXTLINE(clang-diagnostic-sign-compare) for (std::size_t col = 0; col < static_cast<size_t>(input_columns); ++col) {
for (std::size_t col = 0; col < input_columns; ++col) {
float X = input_row[col]; float X = input_row[col];
std::uint8_t quantized = std::max( std::uint8_t quantized = std::max(
0, 0,
@ -165,7 +162,7 @@ void FusedNBitRowwiseQuantizedSBHalfToFloat__base(
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
(input_columns - 2 * sizeof(at::Half)) * num_elem_per_byte; (input_columns - 2 * sizeof(at::Half)) * num_elem_per_byte;
for (std::size_t row = 0; row < input_rows; ++row) { for (std::size_t row = 0; row < static_cast<size_t>(input_rows); ++row) {
const std::uint8_t* input_row = input + row * input_columns; const std::uint8_t* input_row = input + row * input_columns;
const at::Half* input_row_scale_bias = reinterpret_cast<const at::Half*>( const at::Half* input_row_scale_bias = reinterpret_cast<const at::Half*>(
input_row + input_row +
@ -174,8 +171,7 @@ void FusedNBitRowwiseQuantizedSBHalfToFloat__base(
float bias = input_row_scale_bias[1]; float bias = input_row_scale_bias[1];
float* output_row = output + row * output_columns; float* output_row = output + row * output_columns;
// NOLINTNEXTLINE(clang-diagnostic-sign-compare) for (std::size_t col = 0; col < static_cast<std::size_t>(output_columns); ++col) {
for (std::size_t col = 0; col < output_columns; ++col) {
std::uint8_t quantized = input_row[col / num_elem_per_byte]; std::uint8_t quantized = input_row[col / num_elem_per_byte];
quantized >>= (col % num_elem_per_byte) * bit_rate; quantized >>= (col % num_elem_per_byte) * bit_rate;
quantized &= (1 << bit_rate) - 1; quantized &= (1 << bit_rate) - 1;

View File

@ -150,8 +150,7 @@ void PyTorchStreamReader::init() {
version, version,
" as Long Long."); " as Long Long.");
} }
// NOLINTNEXTLINE(clang-diagnostic-sign-compare) if (version_ < static_cast<decltype(version_)>(kMinSupportedFileFormatVersion)) {
if (version_ < kMinSupportedFileFormatVersion) {
CAFFE_THROW( CAFFE_THROW(
"Attempted to read a PyTorch file with version ", "Attempted to read a PyTorch file with version ",
c10::to_string(version_), c10::to_string(version_),
@ -161,8 +160,7 @@ void PyTorchStreamReader::init() {
" with latest version of PyTorch to mitigate this issue."); " with latest version of PyTorch to mitigate this issue.");
} }
// NOLINTNEXTLINE(clang-diagnostic-sign-compare) if (version_ > static_cast<decltype(version_)>(kMaxSupportedFileFormatVersion)) {
if (version_ > kMaxSupportedFileFormatVersion) {
CAFFE_THROW( CAFFE_THROW(
"Attempted to read a PyTorch file with version ", "Attempted to read a PyTorch file with version ",
version_, version_,

View File

@ -443,7 +443,6 @@ function(torch_compile_options libname)
-Wno-type-limits -Wno-type-limits
-Wno-array-bounds -Wno-array-bounds
-Wno-unknown-pragmas -Wno-unknown-pragmas
-Wno-sign-compare
-Wno-strict-overflow -Wno-strict-overflow
-Wno-strict-aliasing -Wno-strict-aliasing
-Wno-error=deprecated-declarations -Wno-error=deprecated-declarations

View File

@ -718,7 +718,7 @@ FW_DERIVATIVE_SETTER_TENSOR_LIST = CodeTemplate(
if (${out_arg}_new_fw_grad_opt.has_value()) { if (${out_arg}_new_fw_grad_opt.has_value()) {
auto ${out_arg}_new_fw_grad = ${out_arg}_new_fw_grad_opt.value(); auto ${out_arg}_new_fw_grad = ${out_arg}_new_fw_grad_opt.value();
TORCH_INTERNAL_ASSERT(${out_arg}.size() == ${out_arg}_new_fw_grad.size()); TORCH_INTERNAL_ASSERT(${out_arg}.size() == ${out_arg}_new_fw_grad.size());
for (auto i=0; i<${out_arg}.size(); ++i) { for (const auto i : c10::irange(${out_arg}.size())) {
if (${out_arg}_new_fw_grad[i].defined() && ${out_arg}[i].defined()) { if (${out_arg}_new_fw_grad[i].defined() && ${out_arg}[i].defined()) {
// The hardcoded 0 here will need to be updated once we support multiple levels. // The hardcoded 0 here will need to be updated once we support multiple levels.
${out_arg}[i]._set_fw_grad(${out_arg}_new_fw_grad[i], /* level */ 0, /* is_inplace_op */ ${is_inplace}); ${out_arg}[i]._set_fw_grad(${out_arg}_new_fw_grad[i], /* level */ 0, /* is_inplace_op */ ${is_inplace});

View File

@ -178,12 +178,10 @@ std::vector<int64_t> ConvTransposeNdImpl<D, Derived>::_output_padding(
ret = at::IntArrayRef(this->options.output_padding()).vec(); ret = at::IntArrayRef(this->options.output_padding()).vec();
} else { } else {
auto k = input.dim() - 2; auto k = input.dim() - 2;
// NOLINTNEXTLINE(clang-diagnostic-sign-compare) if (output_size_.value().size() == static_cast<size_t>(k + 2)) {
if (output_size_.value().size() == k + 2) {
output_size_ = output_size_.value().slice(2); output_size_ = output_size_.value().slice(2);
} }
// NOLINTNEXTLINE(clang-diagnostic-sign-compare) if (output_size_.value().size() != static_cast<size_t>(k)) {
if (output_size_.value().size() != k) {
TORCH_CHECK( TORCH_CHECK(
false, false,
"output_size must have ", "output_size must have ",

View File

@ -192,7 +192,7 @@ struct AddGenericMetadata : public MetadataBase {
if (config_ && !config_->experimental_config.performance_events.empty()) { if (config_ && !config_->experimental_config.performance_events.empty()) {
auto& event_names = config_->experimental_config.performance_events; auto& event_names = config_->experimental_config.performance_events;
for (auto i = 0; i < op_event.perf_event_counters_->size(); ++i) { for (const auto i : c10::irange(op_event.perf_event_counters_->size())) {
addMetadata( addMetadata(
event_names[i], event_names[i],
std::to_string((*op_event.perf_event_counters_)[i])); std::to_string((*op_event.perf_event_counters_)[i]));

View File

@ -251,8 +251,7 @@ std::vector<at::Tensor>& scatter_out(
out_tensors[i].device(), out_tensors[i].device(),
"'"); "'");
auto out_sizes = out_tensors[i].sizes().vec(); auto out_sizes = out_tensors[i].sizes().vec();
// NOLINTNEXTLINE(clang-diagnostic-sign-compare) bool same_ndim = out_sizes.size() == static_cast<size_t>(tensor.dim());
bool same_ndim = out_sizes.size() == tensor.dim();
if (same_ndim) { if (same_ndim) {
total_size += out_sizes[dim]; total_size += out_sizes[dim];
chunk_sizes.emplace_back(out_sizes[dim]); chunk_sizes.emplace_back(out_sizes[dim]);

View File

@ -265,7 +265,7 @@ void check_inputs(
int root, int root,
int input_multiplier, int input_multiplier,
int output_multiplier) { int output_multiplier) {
size_t len = inputs.size(); auto len = inputs.size();
if (len <= 0) { if (len <= 0) {
throw std::runtime_error("input sequence can't be empty"); throw std::runtime_error("input sequence can't be empty");
@ -280,7 +280,8 @@ void check_inputs(
check_tensor( check_tensor(
input, input,
i == root ? at::optional<at::Tensor>{output} : at::nullopt, i == static_cast<decltype(i)>(root) ? at::optional<at::Tensor>{output}
: at::nullopt,
input_multiplier, input_multiplier,
output_multiplier, output_multiplier,
numel, numel,
@ -482,7 +483,7 @@ void reduce(
ncclComm_t comm = comms_ref[i]; ncclComm_t comm = comms_ref[i];
NCCL_CHECK(ncclReduce( NCCL_CHECK(ncclReduce(
inputs[i].data_ptr(), inputs[i].data_ptr(),
root == i ? output.data_ptr() : nullptr, static_cast<decltype(i)>(root) == i ? output.data_ptr() : nullptr,
count, count,
data_type, data_type,
to_nccl_red_op(op), to_nccl_red_op(op),

View File

@ -124,8 +124,7 @@ std::unique_ptr<RpcWithProfilingResp> RpcWithProfilingResp::fromMessage(
for (const auto i : c10::irange( for (const auto i : c10::irange(
kProfileEventsStartIdx, kProfileEventsStartIdx,
kProfileEventsStartIdx + profiledEventsSize)) { kProfileEventsStartIdx + profiledEventsSize)) {
// NOLINTNEXTLINE(clang-diagnostic-sign-compare) TORCH_CHECK(static_cast<size_t>(i) < tupleElements.size());
TORCH_CHECK(i < tupleElements.size());
// Reconstruct remote event from the ivalues. // Reconstruct remote event from the ivalues.
torch::autograd::profiler::LegacyEvent fromIvalueEvent = torch::autograd::profiler::LegacyEvent fromIvalueEvent =
torch::autograd::profiler::LegacyEvent::fromIValue(tupleElements[i]); torch::autograd::profiler::LegacyEvent::fromIValue(tupleElements[i]);

View File

@ -1985,7 +1985,7 @@ c10::intrusive_ptr<Work> ProcessGroupGloo::allgather_coalesced(
invalidArgument("requires non-empty input tensor list"); invalidArgument("requires non-empty input tensor list");
} }
if (output_lists.size() != getSize()) { if (output_lists.size() != static_cast<size_t>(getSize())) {
invalidArgument("output lists should be equal to world size"); invalidArgument("output lists should be equal to world size");
} }
@ -2813,7 +2813,8 @@ void ProcessGroupGloo::monitoredBarrier(
// some ranks have not responded. // some ranks have not responded.
// Ensure all ranks from 1, ... WORLD_SIZE -1 have been successfully // Ensure all ranks from 1, ... WORLD_SIZE -1 have been successfully
// processed. // processed.
auto rankFailure = (processedRanks.size() != size_ - 1); auto rankFailure =
(processedRanks.size() != static_cast<size_t>(size_ - 1));
if (waitAllRanks && rankFailure) { if (waitAllRanks && rankFailure) {
std::vector<int> failedRanks; std::vector<int> failedRanks;
for (const auto i : c10::irange(1, size_)) { for (const auto i : c10::irange(1, size_)) {

View File

@ -774,10 +774,10 @@ c10::intrusive_ptr<Work> ProcessGroupMPI::alltoall(
std::vector<at::Tensor>& inputTensors, std::vector<at::Tensor>& inputTensors,
const AllToAllOptions& opts) { const AllToAllOptions& opts) {
TORCH_CHECK( TORCH_CHECK(
inputTensors.size() == size_, inputTensors.size() == static_cast<size_t>(size_),
"Number of input tensors are not equal to group size"); "Number of input tensors are not equal to group size");
TORCH_CHECK( TORCH_CHECK(
outputTensors.size() == size_, outputTensors.size() == static_cast<size_t>(size_),
"Number of output tensors are not equal to group size"); "Number of output tensors are not equal to group size");
std::function<void(std::unique_ptr<WorkEntry>&)> runFunc = std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
[this](std::unique_ptr<WorkEntry>& entry) { [this](std::unique_ptr<WorkEntry>& entry) {

View File

@ -154,7 +154,7 @@ struct CollectiveFingerPrint {
for (const auto i : c10::irange(output_tensors.size())) { for (const auto i : c10::irange(output_tensors.size())) {
const std::vector<at::Tensor> gathered_tensors = output_tensors[i]; const std::vector<at::Tensor> gathered_tensors = output_tensors[i];
const at::Tensor reference_tensor = tensors_to_verify[i]; const at::Tensor reference_tensor = tensors_to_verify[i];
for (int rank = 0; rank < gathered_tensors.size(); rank++) { for (const auto rank : c10::irange(gathered_tensors.size())) {
const auto& rank_tensor = gathered_tensors[rank]; const auto& rank_tensor = gathered_tensors[rank];
if (!rank_tensor.equal(reference_tensor)) { if (!rank_tensor.equal(reference_tensor)) {
CollectiveFingerPrint rank_fingerprint = CollectiveFingerPrint rank_fingerprint =

View File

@ -1035,7 +1035,7 @@ void TCPStore::waitForWorkers() {
auto buf = reinterpret_cast<const char*>(value.data()); auto buf = reinterpret_cast<const char*>(value.data());
auto len = value.size(); auto len = value.size();
int numWorkersCompleted = std::stoi(std::string(buf, len)); int numWorkersCompleted = std::stoi(std::string(buf, len));
if (numWorkersCompleted >= *numWorkers_) { if (numWorkersCompleted >= static_cast<int>(*numWorkers_)) {
break; break;
} }
const auto elapsed = std::chrono::duration_cast<std::chrono::seconds>( const auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(

View File

@ -2149,7 +2149,7 @@ void verify_params_across_processes(
std::vector<std::vector<at::Tensor>> param_size_output_tensors; std::vector<std::vector<at::Tensor>> param_size_output_tensors;
param_size_output_tensors.emplace_back(); param_size_output_tensors.emplace_back();
auto world_size = process_group->getSize(); auto world_size = process_group->getSize();
for (size_t i = 0; i < world_size; ++i) { for (C10_UNUSED const auto i : c10::irange(world_size)) {
param_size_output_tensors.front().emplace_back( param_size_output_tensors.front().emplace_back(
at::empty_like(param_size_tensor)); at::empty_like(param_size_tensor));
} }
@ -2157,10 +2157,10 @@ void verify_params_across_processes(
std::vector<at::Tensor> param_size_vec{param_size_tensor}; std::vector<at::Tensor> param_size_vec{param_size_tensor};
process_group->allgather(param_size_output_tensors, param_size_vec)->wait(); process_group->allgather(param_size_output_tensors, param_size_vec)->wait();
auto result_size_tensors = param_size_output_tensors.front(); auto result_size_tensors = param_size_output_tensors.front();
for (size_t i = 0; i < world_size; ++i) { for (const auto i : c10::irange(world_size)) {
auto param_size_for_rank = result_size_tensors[i][0].item<int>(); auto param_size_for_rank = result_size_tensors[i][0].item<int>();
TORCH_CHECK( TORCH_CHECK(
param_size_for_rank == params.size(), static_cast<size_t>(param_size_for_rank) == params.size(),
c10::str( c10::str(
"DDP expects same model across all ranks, but Rank ", "DDP expects same model across all ranks, but Rank ",
process_group->getRank(), process_group->getRank(),

View File

@ -502,8 +502,8 @@ std::vector<at::IValue> readWrappedPayload(
payload.resize(indexToRead); payload.resize(indexToRead);
TORCH_INTERNAL_ASSERT( TORCH_INTERNAL_ASSERT(
// NOLINTNEXTLINE(clang-diagnostic-sign-compare) static_cast<decltype(additionalPayloadSize)>(payload.size()) >
payload.size() > additionalPayloadSize, additionalPayloadSize,
"Wrong payload sizes: payload.size() is ", "Wrong payload sizes: payload.size() is ",
payload.size(), payload.size(),
" but additional payload size is ", " but additional payload size is ",

View File

@ -84,7 +84,7 @@ ArgSpecs LlgaKernel::initializeInputSpecs(const TensorArgs& inputs) {
GRAPH_DEBUG("Initializing graph input logical tensors"); GRAPH_DEBUG("Initializing graph input logical tensors");
std::map<size_t, int64_t> tensorIdToOccurence = std::map<size_t, int64_t> tensorIdToOccurence =
initializeTensorIdToOccurence(); initializeTensorIdToOccurence();
for (size_t i = 0; i < nGraphInputs_; i++) { for (const auto i : c10::irange(nGraphInputs_)) {
auto spec = ArgSpec(graph_->inputs()[i]).supplementTensorInfo(inputs[i]); auto spec = ArgSpec(graph_->inputs()[i]).supplementTensorInfo(inputs[i]);
initializedInputIds_.insert(spec.tid()); initializedInputIds_.insert(spec.tid());
int64_t occurence = tensorIdToOccurence[spec.tid()]; int64_t occurence = tensorIdToOccurence[spec.tid()];
@ -95,7 +95,8 @@ ArgSpecs LlgaKernel::initializeInputSpecs(const TensorArgs& inputs) {
initializeConstantInputs(); initializeConstantInputs();
TORCH_CHECK( TORCH_CHECK(
inputSpecs.size() + constantValues_.size() == nPartitionInputs_, inputSpecs.size() + constantValues_.size() ==
static_cast<size_t>(nPartitionInputs_),
"Partition inputs are missing"); "Partition inputs are missing");
GRAPH_DEBUG( GRAPH_DEBUG(
"Concatenating constant input logical tensors to graph input " "Concatenating constant input logical tensors to graph input "
@ -111,7 +112,7 @@ ArgSpecs LlgaKernel::initializeInputSpecs(const TensorArgs& inputs) {
ArgSpecs LlgaKernel::initializeOutputSpecs() const { ArgSpecs LlgaKernel::initializeOutputSpecs() const {
ArgSpecs outputSpecs; ArgSpecs outputSpecs;
outputSpecs.reserve(nOutputs_); outputSpecs.reserve(nOutputs_);
for (size_t i = 0; i < nOutputs_; i++) { for (const auto i : c10::irange(nOutputs_)) {
auto spec = ArgSpec(graph_->outputs()[i]); auto spec = ArgSpec(graph_->outputs()[i]);
if (useOpaqueLayout(i)) { if (useOpaqueLayout(i)) {
spec = spec.any(); spec = spec.any();
@ -126,7 +127,7 @@ std::tuple<RunArgs, RunArgs> LlgaKernel::prepareRunArgs(
TensorArgs& outputs) const { TensorArgs& outputs) const {
RunArgs runInputs, runOutputs; RunArgs runInputs, runOutputs;
auto numInputs = runArgsIdx_.size(); auto numInputs = runArgsIdx_.size();
for (size_t i = 0; i < numInputs; i++) { for (const auto i : c10::irange(numInputs)) {
auto spec = inputSpecs_[i]; auto spec = inputSpecs_[i];
auto input = inputs[runArgsIdx_[i]]; auto input = inputs[runArgsIdx_[i]];
runInputs.push_back( runInputs.push_back(
@ -143,7 +144,7 @@ std::tuple<RunArgs, RunArgs> LlgaKernel::prepareRunArgs(
constantInputs_[i].data_ptr()}); constantInputs_[i].data_ptr()});
} }
for (size_t i = 0; i < nOutputs_; i++) { for (const auto i : c10::irange(nOutputs_)) {
auto spec = outputSpecs_[i]; auto spec = outputSpecs_[i];
auto opt = c10::TensorOptions(spec.aten_scalar_type()).device(device_); auto opt = c10::TensorOptions(spec.aten_scalar_type()).device(device_);
@ -215,7 +216,7 @@ compiled_partition LlgaKernel::compile(const partition& partition) {
// Since layouts of opaque outputs would be known after compilation, // Since layouts of opaque outputs would be known after compilation,
// we need to query them out from compilation and update outputSpecs // we need to query them out from compilation and update outputSpecs
for (size_t i = 0; i < nOutputs_; i++) { for (const auto i : c10::irange(nOutputs_)) {
auto tid = outputSpecs_[i].tid(); auto tid = outputSpecs_[i].tid();
outputSpecs_[i] = compilation.query_logical_tensor(tid); outputSpecs_[i] = compilation.query_logical_tensor(tid);
} }

View File

@ -520,7 +520,7 @@ void IRParser::parseOperator(Block* b) {
const FunctionSchema* schema = n->maybeSchema(); const FunctionSchema* schema = n->maybeSchema();
// Register outputs. // Register outputs.
int idx = 0; unsigned idx = 0;
for (const VarWithType& v : outs) { for (const VarWithType& v : outs) {
vmap[v.name] = n->outputs()[idx]; vmap[v.name] = n->outputs()[idx];
if (schema && !schema->is_varret()) { if (schema && !schema->is_varret()) {

View File

@ -733,7 +733,7 @@ void FlatbufferLoader::extractJitSourceAndConstants(
} }
} }
const auto* jit_constants = module_->jit_constants(); const auto* jit_constants = module_->jit_constants();
for (auto i = 0; i < jit_constants->size(); ++i) { for (const auto i : c10::irange(jit_constants->size())) {
constants->emplace_back(getIValue(jit_constants->Get(i))); constants->emplace_back(getIValue(jit_constants->Get(i)));
} }
parseExtraFilesFromVector(module_->jit_sources(), jit_sources); parseExtraFilesFromVector(module_->jit_sources(), jit_sources);

View File

@ -68,7 +68,7 @@ bool Function::initialize_operators(bool should_check_operators) {
std::unordered_set<std::string> unsupported_op_names; std::unordered_set<std::string> unsupported_op_names;
code_.operators_.resize(code_.op_names_.size()); code_.operators_.resize(code_.op_names_.size());
bool all_ops_supported = true; bool all_ops_supported = true;
for (int i = 0; i < code_.op_names_.size(); i++) { for (unsigned i = 0; i < code_.op_names_.size(); i++) {
const auto& opname = code_.op_names_[i]; const auto& opname = code_.op_names_[i];
int num_args = code_.operator_input_sizes_[i]; int num_args = code_.operator_input_sizes_[i];
c10::optional<int> num_specified_args = c10::optional<int> num_specified_args =
@ -212,7 +212,7 @@ c10::optional<std::function<void(Stack&)>> makeOperatorFunction(
stack.pop_back(); stack.pop_back();
} }
TORCH_CHECK( TORCH_CHECK(
num_specified_args.value() >= out_args.size(), static_cast<size_t>(num_specified_args.value()) >= out_args.size(),
"The number of output arguments is: ", "The number of output arguments is: ",
out_args.size(), out_args.size(),
", which is more then the number of specified arguments: ", ", which is more then the number of specified arguments: ",

View File

@ -196,8 +196,7 @@ bool InterpreterState::run(Stack& stack) {
auto userObj = pop(stack).toObject(); auto userObj = pop(stack).toObject();
// Mobile only: since the number of slots is not known, resize the // Mobile only: since the number of slots is not known, resize the
// numAttributes before setSlot. // numAttributes before setSlot.
// NOLINTNEXTLINE(clang-diagnostic-sign-compare) while (static_cast<int>(userObj->type()->numAttributes()) <= inst.X) {
while (userObj->type()->numAttributes() <= inst.X) {
std::stringstream ss; std::stringstream ss;
ss << userObj->type()->numAttributes(); ss << userObj->type()->numAttributes();
userObj->type()->addAttribute(ss.str(), c10::NoneType::get()); userObj->type()->addAttribute(ss.str(), c10::NoneType::get());

View File

@ -63,7 +63,7 @@ bool Module::compareMethodSchemas(
void Module::unsafeRemoveMethod(const std::string& basename) { void Module::unsafeRemoveMethod(const std::string& basename) {
int64_t i = 0; int64_t i = 0;
for (; i < cu_->methods().size(); ++i) { for (; i < static_cast<int64_t>(cu_->methods().size()); ++i) {
if ((cu_->methods()[i])->name() == basename) { if ((cu_->methods()[i])->name() == basename) {
break; break;
} }
@ -334,7 +334,7 @@ TORCH_API ModuleInfo get_module_info(const mobile::Module& module) {
std::vector<std::string> type_name_list; std::vector<std::string> type_name_list;
for (const auto& func_ptr : module.compilation_unit().methods()) { for (const auto& func_ptr : module.compilation_unit().methods()) {
const auto& function = *func_ptr; const auto& function = *func_ptr;
for (int i = 0; i < function.get_code().op_names_.size(); i++) { for (const auto i : c10::irange(function.get_code().op_names_.size())) {
const auto& op = function.get_code().op_names_[i]; const auto& op = function.get_code().op_names_[i];
minfo.opname_to_num_args[mobile::operator_str(op)] = minfo.opname_to_num_args[mobile::operator_str(op)] =
function.get_code().operator_input_sizes_[i]; function.get_code().operator_input_sizes_[i];

View File

@ -53,7 +53,7 @@ std::vector<mobile::nnc::InputSpec> toInputSpecs(
auto num_inputs = auto num_inputs =
g->inputs().size() - kernel->getSymbolicShapeInputs().size(); g->inputs().size() - kernel->getSymbolicShapeInputs().size();
for (int i = 0; i < num_inputs; i++) { for (const auto i : c10::irange(num_inputs)) {
auto v = g->inputs()[i]; auto v = g->inputs()[i];
const auto& t = v->type(); const auto& t = v->type();
mobile::nnc::InputSpec spec; mobile::nnc::InputSpec spec;
@ -373,7 +373,7 @@ std::vector<c10::optional<at::Tensor>> generateExampleInputs(
const std::vector<at::MemoryFormat>& inputMemoryFormats) { const std::vector<at::MemoryFormat>& inputMemoryFormats) {
std::vector<c10::optional<at::Tensor>> example_inputs; std::vector<c10::optional<at::Tensor>> example_inputs;
example_inputs.reserve(inputShapes.size()); example_inputs.reserve(inputShapes.size());
for (int i = 0; i < inputShapes.size(); ++i) { for (const auto i : c10::irange(inputShapes.size())) {
const auto dtype = at::dtype(inputTypes[i]); const auto dtype = at::dtype(inputTypes[i]);
const auto memory_format = inputMemoryFormats[i]; const auto memory_format = inputMemoryFormats[i];
example_inputs.emplace_back( example_inputs.emplace_back(

View File

@ -45,7 +45,7 @@ bool InputSpec::validate(const at::Tensor& input) const {
return false; return false;
} }
auto spec_sizes = sizes_; auto spec_sizes = sizes_;
for (int i = 0; i < spec_sizes.size(); i++) { for (const auto i : c10::irange(spec_sizes.size())) {
// InputSpec size 0 means that the dimension is dynamic // InputSpec size 0 means that the dimension is dynamic
if (spec_sizes[i] != 0 && spec_sizes[i] != input.sizes()[i]) { if (spec_sizes[i] != 0 && spec_sizes[i] != input.sizes()[i]) {
return false; return false;

View File

@ -89,8 +89,8 @@ void applyUpgrader(mobile::Function* function, uint64_t operator_version) {
// algorithm, because the number of upgrader per operator will be just a // algorithm, because the number of upgrader per operator will be just a
// few and tend to keep the code light-weight from binary size concern. // few and tend to keep the code light-weight from binary size concern.
for (const auto& upgrader : upgrader_list) { for (const auto& upgrader : upgrader_list) {
if (operator_version <= upgrader.max_version && if (static_cast<int>(operator_version) <= upgrader.max_version &&
operator_version >= upgrader.min_version) { static_cast<int>(operator_version) >= upgrader.min_version) {
// If there exists a valid upgrader, change the instruction OP to // If there exists a valid upgrader, change the instruction OP to
// CALL, and the index will point to the according upgrader // CALL, and the index will point to the according upgrader
// function. All upgrader function are available in // function. All upgrader function are available in
@ -102,7 +102,7 @@ void applyUpgrader(mobile::Function* function, uint64_t operator_version) {
// new_inst.X = upgrader.index; // new_inst.X = upgrader.index;
// code->instructions_[i] = new_inst; // code->instructions_[i] = new_inst;
TORCH_CHECK( TORCH_CHECK(
upgrader.index < code.functions_.size(), upgrader.index < static_cast<int>(code.functions_.size()),
"upgrader index is, ", "upgrader index is, ",
upgrader.index, upgrader.index,
" and it's larger than the upgrader function list length ", " and it's larger than the upgrader function list length ",

View File

@ -22,7 +22,7 @@ c10::optional<UpgraderEntry> findUpgrader(
upgraders_for_schema.begin(), upgraders_for_schema.begin(),
upgraders_for_schema.end(), upgraders_for_schema.end(),
[current_version](const UpgraderEntry& entry) { [current_version](const UpgraderEntry& entry) {
return entry.bumped_at_version > current_version; return entry.bumped_at_version > static_cast<int>(current_version);
}); });
if (pos != upgraders_for_schema.end()) { if (pos != upgraders_for_schema.end()) {
@ -36,7 +36,7 @@ bool isOpCurrentBasedOnUpgraderEntries(
size_t current_version) { size_t current_version) {
auto latest_update = auto latest_update =
upgraders_for_schema[upgraders_for_schema.size() - 1].bumped_at_version; upgraders_for_schema[upgraders_for_schema.size() - 1].bumped_at_version;
if (latest_update > current_version) { if (latest_update > static_cast<int>(current_version)) {
return false; return false;
} }
return true; return true;

View File

@ -504,7 +504,7 @@ size_t determineUsageIdx(Value* value, Node* user) {
const auto idx = const auto idx =
std::find(user->inputs().begin(), user->inputs().end(), value) - std::find(user->inputs().begin(), user->inputs().end(), value) -
user->inputs().begin(); user->inputs().begin();
TORCH_CHECK(idx != user->inputs().size()); TORCH_CHECK(idx != static_cast<decltype(idx)>(user->inputs().size()));
return idx; return idx;
} }
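The `decltype` cast is the variant used when the signed type comes from iterator arithmetic: `std::find(...) - inputs().begin()` yields a `std::ptrdiff_t`, so the unsigned `size()` is cast to `decltype(idx)` instead of hard-coding a type. A small standalone sketch:

```cpp
// Sketch: iterator subtraction gives a signed ptrdiff_t index; casting the
// unsigned size() to decltype(idx) keeps the check signed and stays correct
// if the index type ever changes.
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

std::size_t find_index(const std::vector<int>& inputs, int value) {
  const auto idx =
      std::find(inputs.begin(), inputs.end(), value) - inputs.begin();
  assert(idx != static_cast<decltype(idx)>(inputs.size()) && "value not found");
  return static_cast<std::size_t>(idx);
}

int main() {
  std::vector<int> inputs{7, 8, 9};
  return find_index(inputs, 9) == 2 ? 0 : 1;
}
```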

View File

@ -80,7 +80,7 @@ bool isZerodimCPUTensor(std::shared_ptr<TensorType> tensor_type) {
bool propWithNoDevice(Node* n) { bool propWithNoDevice(Node* n) {
// Propagate if we can verify that all input devices match, // Propagate if we can verify that all input devices match,
// except CPU zerodim, which any other type can overwrite // except CPU zerodim, which any other type can overwrite
int input_num = 0; size_t input_num = 0;
for (; input_num < n->inputs().size(); input_num++) { for (; input_num < n->inputs().size(); input_num++) {
if (n->inputs()[input_num]->type()->cast<TensorType>()) { if (n->inputs()[input_num]->type()->cast<TensorType>()) {
@ -124,7 +124,7 @@ bool defaultDeviceProp(Node* n) {
return false; return false;
} }
auto arguments = schema->arguments(); auto arguments = schema->arguments();
for (int i = 0; i < arguments.size(); i++) { for (size_t i = 0; i < arguments.size(); i++) {
Argument& argument = arguments[i]; Argument& argument = arguments[i];
if (DeviceObjType::get()->isSubtypeOf(argument.type())) { if (DeviceObjType::get()->isSubtypeOf(argument.type())) {
// Optional args are filled in by torchscript with default val // Optional args are filled in by torchscript with default val

View File

@ -52,7 +52,7 @@ void insertPrePackedConvOpForNode(Node* n) {
Symbol::fromQualString("mkldnn_prepacked::conv2d_prepack"), 1); Symbol::fromQualString("mkldnn_prepacked::conv2d_prepack"), 1);
// skip input value // skip input value
for (auto i = 1; i < n->inputs().size(); i++) { for (const auto i : c10::irange(1, n->inputs().size())) {
Value* v = n->input(i); Value* v = n->input(i);
prepack_node->addInput(v); prepack_node->addInput(v);
} }

View File

@ -1255,8 +1255,8 @@ class ShapePropagator : public PropertyPropBase {
} else if (upcast_integer && !at::isFloatingType(*type->scalarType())) { } else if (upcast_integer && !at::isFloatingType(*type->scalarType())) {
type = type->withScalarType(at::kLong); type = type->withScalarType(at::kLong);
} }
// NOLINTNEXTLINE(clang-diagnostic-sign-compare) if (static_cast<int64_t>(*type->dim()) >= num_reduced_dim &&
if (*type->dim() >= num_reduced_dim && num_reduced_dim > 0) { num_reduced_dim > 0) {
return {type->withDim(*type->dim() - num_reduced_dim)}; return {type->withDim(*type->dim() - num_reduced_dim)};
} else { } else {
return {std::move(type)}; return {std::move(type)};

View File

@ -156,7 +156,7 @@ std::ostream& operator<<(std::ostream& os, const ShapeArguments& sa) {
} }
os << "("; os << "(";
for (size_t i = 0; i < sa.len(); i++) { for (const auto i : c10::irange(sa.len())) {
os << sa.at(i); os << sa.at(i);
} }
os << ")"; os << ")";
@ -422,8 +422,8 @@ struct SymbolicShapeOpAnalyzer {
TORCH_INTERNAL_ASSERT( TORCH_INTERNAL_ASSERT(
inputs_.size() >= shape_compute_graph_->inputs().size(), inputs_.size() >= shape_compute_graph_->inputs().size(),
"Missing Arg for Shape Graph"); "Missing Arg for Shape Graph");
for (int64_t index = 0; index < shape_compute_graph_->inputs().size(); for (const auto index :
index++) { c10::irange(shape_compute_graph_->inputs().size())) {
auto shape_arguments = c10::get_if<ShapeArguments>(&inputs_[index]); auto shape_arguments = c10::get_if<ShapeArguments>(&inputs_[index]);
if (!shape_arguments || !shape_arguments->has_dim()) { if (!shape_arguments || !shape_arguments->has_dim()) {
continue; continue;
@ -551,7 +551,7 @@ struct SymbolicShapeOpAnalyzer {
std::vector<c10::SymbolicShape> propagateShapesInGraph() { std::vector<c10::SymbolicShape> propagateShapesInGraph() {
bool made_change = true; bool made_change = true;
constexpr size_t MAX_ATTEMPTS = 8; constexpr size_t MAX_ATTEMPTS = 8;
for (int attempt_num = 0; made_change && attempt_num < MAX_ATTEMPTS; for (unsigned attempt_num = 0; made_change && attempt_num < MAX_ATTEMPTS;
attempt_num++) { attempt_num++) {
// symbolic shape concrete values are only used in final shape extraction // symbolic shape concrete values are only used in final shape extraction
GRAPH_DUMP("Before substitution: ", shape_compute_graph_); GRAPH_DUMP("Before substitution: ", shape_compute_graph_);
@ -793,7 +793,7 @@ c10::SymbolicShape combine_bounds(
return c10::SymbolicShape(); return c10::SymbolicShape();
} }
std::vector<c10::ShapeSymbol> merged_shapes; std::vector<c10::ShapeSymbol> merged_shapes;
for (int i = 0; i < lower_bound.rank(); i++) { for (const auto i : c10::irange(*lower_bound.rank())) {
// TODO: Merge equivalent expressions (not needed for current use case) // TODO: Merge equivalent expressions (not needed for current use case)
if (lower_bound[i] == upper_bound[i]) { if (lower_bound[i] == upper_bound[i]) {
merged_shapes.push_back(lower_bound[i]); merged_shapes.push_back(lower_bound[i]);
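`combine_bounds` is also one of the places where the loop bound is an optional rank, so the fixed loop dereferences it and lets `c10::irange` pick the index type. A simplified analogue with plain `std::optional` and `std::vector` in place of `c10::SymbolicShape` (`-1` stands in for an unknown symbol):

```cpp
// Simplified analogue, not the real SymbolicShape API: merge two per-dim
// bounds, keeping a dim only where they agree, after checking the optional
// rank and iterating with an index type that matches it.
#include <cstdint>
#include <optional>
#include <vector>

std::vector<int64_t> merge_shapes(const std::vector<int64_t>& lower,
                                  const std::vector<int64_t>& upper,
                                  std::optional<int64_t> rank) {
  std::vector<int64_t> merged;
  if (!rank) {
    return merged; // unranked: nothing to merge
  }
  for (int64_t i = 0; i < *rank; ++i) {
    merged.push_back(lower[i] == upper[i] ? lower[i] : -1);
  }
  return merged;
}

int main() {
  auto m = merge_shapes({2, 3, 4}, {2, 5, 4}, 3); // -> {2, -1, 4}
  return m[1] == -1 ? 0 : 1;
}
```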

View File

@ -522,8 +522,7 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
INST_NEXT; INST_NEXT;
case INST(TYPECHECK): { case INST(TYPECHECK): {
INST_GUARD; INST_GUARD;
int num_inputs = inst.N, i = 0; unsigned num_inputs = inst.N, i = 0;
// NOLINTNEXTLINE(clang-diagnostic-sign-compare)
TORCH_INTERNAL_ASSERT(stack.size() >= num_inputs && num_inputs > 0); TORCH_INTERNAL_ASSERT(stack.size() >= num_inputs && num_inputs > 0);
// Check every input's shape against profiled (expected) shape. // Check every input's shape against profiled (expected) shape.
for (i = 0; i < num_inputs; i++) { for (i = 0; i < num_inputs; i++) {

View File

@ -865,8 +865,8 @@ struct CodeImpl {
static_cast<int64_t>(schema.arguments().size()) + static_cast<int64_t>(schema.arguments().size()) +
static_cast<int64_t>(schema.returns().size()); static_cast<int64_t>(schema.returns().size());
TORCH_INTERNAL_ASSERT_DEBUG_ONLY( TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
expected_size == actual_size || schema.is_varret() || static_cast<size_t>(expected_size) == actual_size ||
schema.is_vararg(), schema.is_varret() || schema.is_vararg(),
"Expected to find ", "Expected to find ",
expected_size, expected_size,
" values on the stack, but found ", " values on the stack, but found ",

View File

@ -407,7 +407,7 @@ void StandardMemoryPlanner::allocateManagedTensors() {
auto* start = allocateBuffer(managed_bytes_); auto* start = allocateBuffer(managed_bytes_);
reused_tensors_ = 0; reused_tensors_ = 0;
auto group_idx = 0; size_t group_idx = 0;
for (const size_t storages_idx : c10::irange(storages_.size())) { for (const size_t storages_idx : c10::irange(storages_.size())) {
auto tensor_size = storages_nbytes_[storages_idx]; auto tensor_size = storages_nbytes_[storages_idx];
if (tensor_size == 0) { if (tensor_size == 0) {
@ -444,7 +444,7 @@ void StandardMemoryPlanner::deallocateManagedTensors() {
// We don't have any guarantee that the model doesn't change the // We don't have any guarantee that the model doesn't change the
// Storage for managed tensors out from under us during execution, // Storage for managed tensors out from under us during execution,
// so we have to check the Storages each time we deallocate. // so we have to check the Storages each time we deallocate.
auto group_idx = 0; unsigned group_idx = 0;
const bool first_time = storages_.empty(); const bool first_time = storages_.empty();
if (C10_UNLIKELY(first_time)) { if (C10_UNLIKELY(first_time)) {
if (storages_.is_allocated()) { if (storages_.is_allocated()) {

View File

@ -127,7 +127,9 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
return nullptr; return nullptr;
} }
return [](ProcessedNode* p_node) { return [](ProcessedNode* p_node) {
DCHECK(p_node->num_inputs() - 1 == p_node->outputs().size()); DCHECK(
static_cast<size_t>(p_node->num_inputs() - 1) ==
p_node->outputs().size());
auto dict = p_node->Input(0).toGenericDict(); auto dict = p_node->Input(0).toGenericDict();
const auto num_inputs = p_node->num_inputs(); const auto num_inputs = p_node->num_inputs();
for (size_t i = 1; i < num_inputs; ++i) { for (size_t i = 1; i < num_inputs; ++i) {
@ -1252,7 +1254,8 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
const auto num_elems = elems.size(); const auto num_elems = elems.size();
const auto idx = pnode->Input(1).toInt(); const auto idx = pnode->Input(1).toInt();
const auto norm_idx = normalizeIndex(idx, num_elems); const auto norm_idx = normalizeIndex(idx, num_elems);
if (norm_idx < 0 || norm_idx >= num_elems) { if (norm_idx < 0 ||
norm_idx >= static_cast<decltype(norm_idx)>(num_elems)) {
// Use std::runtime_error instead of c10::Error to be consistent with // Use std::runtime_error instead of c10::Error to be consistent with
// JIT // JIT
throw std::out_of_range("Tuple index out of range"); throw std::out_of_range("Tuple index out of range");

View File

@ -217,7 +217,7 @@ void einsum(Stack& stack, size_t num_inputs) {
void percentFormat(Stack& stack, size_t num_inputs) { void percentFormat(Stack& stack, size_t num_inputs) {
auto format_str = peek(stack, 0, num_inputs).toStringRef(); auto format_str = peek(stack, 0, num_inputs).toStringRef();
auto args = last(stack, num_inputs - 1)[0]; auto args = last(stack, num_inputs - 1)[0];
auto args_size = 1; // assumed size size_t args_size = 1; // assumed size
if (args.isTuple()) { if (args.isTuple()) {
args_size = args.toTupleRef().elements().size(); args_size = args.toTupleRef().elements().size();
} }
@ -239,7 +239,6 @@ void percentFormat(Stack& stack, size_t num_inputs) {
begin = percent_idx + 2; // skip the `%` and the format specifier begin = percent_idx + 2; // skip the `%` and the format specifier
continue; continue;
} }
// NOLINTNEXTLINE(clang-diagnostic-sign-compare)
TORCH_CHECK(used_args < args_size, "Too few arguments for format string"); TORCH_CHECK(used_args < args_size, "Too few arguments for format string");
char key = format_str.at(format_idx); char key = format_str.at(format_idx);
IValue arg; IValue arg;
@ -252,7 +251,6 @@ void percentFormat(Stack& stack, size_t num_inputs) {
begin = percent_idx + 2; begin = percent_idx + 2;
++used_args; ++used_args;
} }
// NOLINTNEXTLINE(clang-diagnostic-sign-compare)
TORCH_CHECK(used_args == args_size, "Too many arguments for format string"); TORCH_CHECK(used_args == args_size, "Too many arguments for format string");
drop(stack, num_inputs); drop(stack, num_inputs);
push(stack, ss.str()); push(stack, ss.str());
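Declaring `args_size` with the unsigned size type is what lets the two NOLINT suppressions above go away: both `TORCH_CHECK` comparisons become unsigned-vs-unsigned. A stripped-down, hypothetical formatter (only `%s`, no stack or IValues) with the same shape:

```cpp
// Hypothetical stripped-down formatter: the running argument count uses the
// same unsigned type as args.size(), so the "too few"/"too many" checks need
// no suppression.
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

std::string percent_format(const std::string& fmt,
                           const std::vector<std::string>& args) {
  std::string out;
  std::size_t used_args = 0; // unsigned, matching args.size()
  for (std::size_t i = 0; i < fmt.size(); ++i) {
    if (fmt[i] == '%' && i + 1 < fmt.size() && fmt[i + 1] == 's') {
      if (used_args >= args.size())
        throw std::runtime_error("Too few arguments for format string");
      out += args[used_args++];
      ++i; // skip the 's'
    } else {
      out += fmt[i];
    }
  }
  if (used_args != args.size())
    throw std::runtime_error("Too many arguments for format string");
  return out;
}

int main() {
  return percent_format("%s + %s", {"1", "2"}) == "1 + 2" ? 0 : 1;
}
```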

View File

@ -503,7 +503,8 @@ GraphEncoder::GraphEncoder(
model_proto_.set_producer_name("pytorch"); model_proto_.set_producer_name("pytorch");
TORCH_CHECK( TORCH_CHECK(
onnx_opset_version > 0 && onnx_opset_version > 0 &&
onnx_opset_version < kOpsetVersionToIRVersion.size() && static_cast<size_t>(onnx_opset_version) <
kOpsetVersionToIRVersion.size() &&
kOpsetVersionToIRVersion[onnx_opset_version] != kInvalidOpsetVersion, kOpsetVersionToIRVersion[onnx_opset_version] != kInvalidOpsetVersion,
"Unsupported onnx_opset_version: ", "Unsupported onnx_opset_version: ",
onnx_opset_version); onnx_opset_version);

View File

@ -269,7 +269,7 @@ IValue convertMobileFunctionToCodeTable(
std::vector<IValue> operators; std::vector<IValue> operators;
operators.reserve(code.op_names_.size()); operators.reserve(code.op_names_.size());
for (int i = 0; i < code.op_names_.size(); ++i) { for (unsigned i = 0; i < code.op_names_.size(); ++i) {
const auto& opname = code.op_names_[i]; const auto& opname = code.op_names_[i];
const int size = code.operator_input_sizes_[i]; const int size = code.operator_input_sizes_[i];
if (compilation_options.enable_default_value_for_unspecified_arg) { if (compilation_options.enable_default_value_for_unspecified_arg) {

View File

@ -176,7 +176,7 @@ std::pair<IValue, IValue> getFunctionTuple(
// operators // operators
std::vector<IValue> operators; std::vector<IValue> operators;
operators.reserve(mobile_code.op_names_.size()); operators.reserve(mobile_code.op_names_.size());
for (int i = 0; i < mobile_code.op_names_.size(); ++i) { for (const auto i : c10::irange(mobile_code.op_names_.size())) {
const auto& opname = mobile_code.op_names_[i]; const auto& opname = mobile_code.op_names_[i];
const int size = mobile_code.operator_input_sizes_[i]; const int size = mobile_code.operator_input_sizes_[i];
if (BytecodeEmitMode::is_default_value_for_unspecified_arg_enabled()) { if (BytecodeEmitMode::is_default_value_for_unspecified_arg_enabled()) {

View File

@ -235,7 +235,7 @@ flatbuffers::Offset<mobile::serialization::Function> FlatbufferSerializer::
std::vector<flatbuffers::Offset<mobile::serialization::Operator>> std::vector<flatbuffers::Offset<mobile::serialization::Operator>>
operator_vector; operator_vector;
operator_vector.reserve(code.op_names_.size()); operator_vector.reserve(code.op_names_.size());
for (int i = 0; i < code.op_names_.size(); ++i) { for (const auto i : c10::irange(code.op_names_.size())) {
const auto& opname = code.op_names_[i]; const auto& opname = code.op_names_[i];
const int op_size = code.operator_input_sizes_[i]; const int op_size = code.operator_input_sizes_[i];
operator_vector.push_back(CreateOperator( operator_vector.push_back(CreateOperator(

View File

@ -923,7 +923,7 @@ struct PythonPrintImpl {
if (val.isString()) { if (val.isString()) {
const auto maxASCII = 0x7fu; const auto maxASCII = 0x7fu;
for (auto c : val.toStringRef()) { for (auto c : val.toStringRef()) {
if (c > maxASCII) { if (static_cast<decltype(maxASCII)>(c) > maxASCII) {
hasNonASCII = true; hasNonASCII = true;
return true; return true;
} }
@ -1235,7 +1235,7 @@ struct PythonPrintImpl {
auto num_necessary = specified_args.first; auto num_necessary = specified_args.first;
auto num_out = specified_args.second; auto num_out = specified_args.second;
for (size_t i = 0; i < num_necessary; ++i) { for (const auto i : c10::irange(static_cast<size_t>(num_necessary))) {
if (i > 0) if (i > 0)
stmt << ", "; stmt << ", ";
auto v = useOf(node->inputs().at(i)); auto v = useOf(node->inputs().at(i));
@ -1745,7 +1745,7 @@ void jitModuleToPythonCodeAndConstants(
class_deps.add(class_type); class_deps.add(class_type);
} }
for (int i = 0; i < class_deps.size(); ++i) { for (const auto i : c10::irange(class_deps.size())) {
auto type = class_deps[i]; auto type = class_deps[i];
auto qualname = uniquer.getUniqueName(type); auto qualname = uniquer.getUniqueName(type);
std::string qualifier = qualname.prefix(); std::string qualifier = qualname.prefix();
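The `maxASCII` check above has to behave the same whether plain `char` is signed or unsigned on the target. A standalone sketch of why casting each byte to the unsigned limit's type keeps the detection correct either way (assumes the source file is saved as UTF-8):

```cpp
// Sketch: a UTF-8 continuation byte such as 0xC3 is either 195 (unsigned
// char) or a negative value (signed char). Casting to the unsigned type of
// the 0x7f limit makes the negative case wrap to a large value, so both
// variants compare greater than 0x7f and the byte is flagged as non-ASCII.
#include <iostream>
#include <string>

bool hasNonASCII(const std::string& s) {
  const auto maxASCII = 0x7fu;
  for (auto c : s) {
    if (static_cast<decltype(maxASCII)>(c) > maxASCII) {
      return true;
    }
  }
  return false;
}

int main() {
  std::cout << hasNonASCII("hello") << ' '   // 0
            << hasNonASCII("héllo") << '\n'; // 1 (é is a multi-byte sequence)
}
```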

View File

@ -119,9 +119,8 @@ void CudaAnalysis::visit(ForPtr v) {
throw std::runtime_error("support only 3D gpu_block_index"); throw std::runtime_error("support only 3D gpu_block_index");
} }
ExprPtr prev = nullptr; ExprPtr prev = nullptr;
// NOLINTNEXTLINE(clang-diagnostic-sign-compare)
// NOLINTNEXTLINE(bugprone-branch-clone) // NOLINTNEXTLINE(bugprone-branch-clone)
if (gpu_block_extents_.size() <= gpu_block_index) { if (gpu_block_extents_.size() <= static_cast<size_t>(gpu_block_index)) {
gpu_block_extents_.resize(gpu_block_index + 1); gpu_block_extents_.resize(gpu_block_index + 1);
} else { } else {
prev = gpu_block_extents_[gpu_block_index]; prev = gpu_block_extents_[gpu_block_index];
@ -149,9 +148,8 @@ void CudaAnalysis::visit(ForPtr v) {
throw std::runtime_error("support only 3D gpu_thread_index"); throw std::runtime_error("support only 3D gpu_thread_index");
} }
ExprPtr prev = nullptr; ExprPtr prev = nullptr;
// NOLINTNEXTLINE(clang-diagnostic-sign-compare)
// NOLINTNEXTLINE(bugprone-branch-clone) // NOLINTNEXTLINE(bugprone-branch-clone)
if (gpu_thread_extents_.size() <= gpu_thread_index) { if (gpu_thread_extents_.size() <= static_cast<size_t>(gpu_thread_index)) {
gpu_thread_extents_.resize(gpu_thread_index + 1); gpu_thread_extents_.resize(gpu_thread_index + 1);
} else { } else {
prev = gpu_thread_extents_[gpu_thread_index]; prev = gpu_thread_extents_[gpu_thread_index];
@ -503,8 +501,7 @@ class PrioritizeLoad : public IRMutator {
v->indices().size() == nested_store_->indices().size()) { v->indices().size() == nested_store_->indices().size()) {
// also check indices // also check indices
bool same = true; bool same = true;
// NOLINTNEXTLINE(clang-diagnostic-sign-compare) for (const auto i : c10::irange(v->indices().size())) {
for (int i = 0; i < v->indices().size(); ++i) {
if (!exprEquals(v->indices()[i], nested_store_->indices()[i])) { if (!exprEquals(v->indices()[i], nested_store_->indices()[i])) {
same = false; same = false;
break; break;
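The CUDA codegen fix follows the "validate, then cast the signed index" pattern: the axis index has already been range-checked, so casting it to `size_t` for the size comparison is safe before growing or reusing the extent slot. A small sketch of that grow-or-reuse logic (non-negative `axis_index` assumed, as the real code throws for out-of-range values):

```cpp
// Sketch: grow the extent table up to a signed, already-validated axis
// index; the cast keeps the size comparison unsigned-vs-unsigned.
#include <cstddef>
#include <vector>

int* slot_for_axis(std::vector<int>& extents, int axis_index) {
  if (extents.size() <= static_cast<std::size_t>(axis_index)) {
    extents.resize(axis_index + 1); // grow with zero-initialized entries
  }
  return &extents[axis_index];
}

int main() {
  std::vector<int> extents;
  *slot_for_axis(extents, 2) = 128; // table grows to size 3
  return extents.size() == 3 ? 0 : 1;
}
```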

View File

@ -688,7 +688,7 @@ class SimpleIREvaluatorImpl : public IRVisitor {
throw malformed_input( throw malformed_input(
"Number of dimensions did not match number of strides", buf); "Number of dimensions did not match number of strides", buf);
} }
size_t buf_size = 1; int64_t buf_size = 1;
if (!dims.empty()) { if (!dims.empty()) {
ExprHandle buf_size_expr = ExprHandle(immLike(dims[0], 1)); ExprHandle buf_size_expr = ExprHandle(immLike(dims[0], 1));
ExprHandle negative_one = ExprHandle(immLike(dims[0], -1)); ExprHandle negative_one = ExprHandle(immLike(dims[0], -1));

View File

@ -427,7 +427,7 @@ bool trimGraphOnce(const std::shared_ptr<Graph>& graph) {
std::unordered_set<Value*> outputs( std::unordered_set<Value*> outputs(
graph->outputs().begin(), graph->outputs().end()); graph->outputs().begin(), graph->outputs().end());
bool changed = false; bool changed = false;
for (int idx = 0; idx < ret->inputs().size(); idx++) { for (size_t idx = 0; idx < ret->inputs().size(); idx++) {
auto v = ret->inputs()[idx]; auto v = ret->inputs()[idx];
if (graph_inputs.count(v)) { if (graph_inputs.count(v)) {
continue; continue;

View File

@ -1089,7 +1089,7 @@ std::vector<ExprHandle> TensorExprKernel::getInputStrides(
generated_strides++; generated_strides++;
} }
} }
for (int i = 0; i < rank; i++) { for (int i = 0; i < static_cast<int>(rank); i++) {
if (stride_input[i] == torch::jit::StrideInput::S_TRAN_CONT && if (stride_input[i] == torch::jit::StrideInput::S_TRAN_CONT &&
stride_set[i - 1]) { stride_set[i - 1]) {
inputTensorStrides[i] = inputTensorStrides[i] =
@ -1500,7 +1500,7 @@ BlockPtr TensorExprKernel::bindAllInputs() {
// //
// TODO: Check if the tensors with symbolic shapes are contiguous. // TODO: Check if the tensors with symbolic shapes are contiguous.
TORCH_CHECK( TORCH_CHECK(
nInputs_ > symbolic_shape_inputs_.size(), nInputs_ > static_cast<int64_t>(symbolic_shape_inputs_.size()),
"Symbolic dims not provided as inputs to the graph"); "Symbolic dims not provided as inputs to the graph");
// First, process the symbolic input params and create a new variable for // First, process the symbolic input params and create a new variable for
@ -1510,7 +1510,9 @@ BlockPtr TensorExprKernel::bindAllInputs() {
// create for the symbolic input params. // create for the symbolic input params.
symbolic_shape_args.reserve(symbolic_shape_inputs_.size()); symbolic_shape_args.reserve(symbolic_shape_inputs_.size());
for (size_t i = symbolic_shape_inputs_start_pos; i < nInputs_; ++i) { for (size_t i = symbolic_shape_inputs_start_pos;
i < static_cast<size_t>(nInputs_);
++i) {
auto input = graph_->inputs()[i]; auto input = graph_->inputs()[i];
if (input->type()->kind() != TypeKind::IntType) { if (input->type()->kind() != TypeKind::IntType) {
throw std::runtime_error( throw std::runtime_error(
@ -2104,7 +2106,8 @@ void TensorExprKernel::runWithAllocatedOutputs(Stack& stack) const {
args.emplace_back(&stride_values[idx]); args.emplace_back(&stride_values[idx]);
} }
TORCH_INTERNAL_ASSERT(nOutputs_ == bufOutputs_.size()); TORCH_INTERNAL_ASSERT(
nOutputs_ == static_cast<int64_t>(bufOutputs_.size()));
for (size_t i = 0, e = bufOutputs_.size(); i < e; ++i) { for (size_t i = 0, e = bufOutputs_.size(); i < e; ++i) {
auto& out = stack_outputs[i].toTensor(); auto& out = stack_outputs[i].toTensor();
// This has only been tested on CPUs. // This has only been tested on CPUs.

View File

@ -68,7 +68,7 @@ std::tuple<std::vector<T>, std::vector<int>> select_n_randomly(
std::vector<T> selected_objects; std::vector<T> selected_objects;
std::vector<int> selected_indices; std::vector<int> selected_indices;
if (indices.size() < n) { if (static_cast<int>(indices.size()) < n) {
return std::make_tuple(selected_objects, selected_indices); return std::make_tuple(selected_objects, selected_indices);
} }
for (int i = 0; i < n; i++) { for (int i = 0; i < n; i++) {
@ -393,8 +393,8 @@ void loopnestRandomization(int64_t seed, LoopNest& l) {
// Find pairs of axes that can be reordered // Find pairs of axes that can be reordered
std::vector<std::pair<ForPtr, ForPtr>> valid_pairs; std::vector<std::pair<ForPtr, ForPtr>> valid_pairs;
for (int i = 0; i < loops.size(); i++) { for (const auto i : c10::irange(loops.size())) {
for (int j = i + 1; j < loops.size(); j++) { for (const auto j : c10::irange(i + 1, loops.size())) {
if (LoopNest::findOuterFor(loops[i], loops[j])) { if (LoopNest::findOuterFor(loops[i], loops[j])) {
valid_pairs.emplace_back(loops[i], loops[j]); valid_pairs.emplace_back(loops[i], loops[j]);
} }
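The loop-randomization code uses the two-argument `c10::irange(i + 1, loops.size())` form to enumerate ordered index pairs. The equivalent plain nested loop over unsigned indices, as a standalone sketch:

```cpp
// Sketch: enumerate all i < j index pairs with unsigned indices, matching
// the irange(i + 1, n) inner loop.
#include <cstddef>
#include <utility>
#include <vector>

std::vector<std::pair<std::size_t, std::size_t>> index_pairs(std::size_t n) {
  std::vector<std::pair<std::size_t, std::size_t>> pairs;
  for (std::size_t i = 0; i < n; ++i) {
    for (std::size_t j = i + 1; j < n; ++j) {
      pairs.emplace_back(i, j);
    }
  }
  return pairs;
}

int main() {
  return index_pairs(3).size() == 3 ? 0 : 1; // (0,1), (0,2), (1,2)
}
```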

View File

@ -338,7 +338,7 @@ Tensor computeChunk(
size_t step = buf_info->dims[norm_dim] / chunks; size_t step = buf_info->dims[norm_dim] / chunks;
std::vector<ExprHandle> new_indices; std::vector<ExprHandle> new_indices;
for (int64_t i = 0; i < indices.size(); ++i) { for (int64_t i = 0; i < static_cast<int64_t>(indices.size()); ++i) {
if (i == norm_dim) { if (i == norm_dim) {
new_indices.push_back( new_indices.push_back(
indices[i] + ExprHandle(immLike(indices[i], chunkIdx * step))); indices[i] + ExprHandle(immLike(indices[i], chunkIdx * step)));
@ -574,7 +574,7 @@ Tensor computeCatWoConditionals(
std::vector<VarPtr> for_vars(dims.size()); std::vector<VarPtr> for_vars(dims.size());
std::vector<ExprPtr> load_indices(dims.size()); std::vector<ExprPtr> load_indices(dims.size());
std::vector<ExprPtr> store_indices(dims.size()); std::vector<ExprPtr> store_indices(dims.size());
for (int64_t i = 0; i < dims.size(); ++i) { for (int64_t i = 0; i < static_cast<int64_t>(dims.size()); ++i) {
for_vars[i] = alloc<Var>( for_vars[i] = alloc<Var>(
"i" + c10::to_string(inp_pos) + "_" + c10::to_string(i), "i" + c10::to_string(inp_pos) + "_" + c10::to_string(i),
dims[i].dtype()); dims[i].dtype());

View File

@ -126,8 +126,7 @@ Tensor computeMean(
extra_args = c10::fmap<ExprHandle>(*mean_dims); extra_args = c10::fmap<ExprHandle>(*mean_dims);
} else { } else {
// When dims argument is not specified, reduce over all dimensions // When dims argument is not specified, reduce over all dimensions
// NOLINTNEXTLINE(clang-diagnostic-sign-compare) for (int64_t idx = 0; idx < static_cast<int64_t>(InputBuf.ndim()); ++idx) {
for (int64_t idx = 0; idx < InputBuf.ndim(); idx++) {
extra_args.emplace_back(idx); extra_args.emplace_back(idx);
} }
} }

View File

@ -15,7 +15,8 @@ std::vector<int64_t> DropDimensions(
std::vector<int64_t> new_dims; std::vector<int64_t> new_dims;
size_t drop_index = 0; size_t drop_index = 0;
for (const auto i : c10::irange(sizes.size())) { for (const auto i : c10::irange(sizes.size())) {
if (drop_index < drop_dims.size() && i == drop_dims[drop_index]) { if (drop_index < drop_dims.size() &&
static_cast<int64_t>(i) == drop_dims[drop_index]) {
++drop_index; ++drop_index;
} else { } else {
new_dims.push_back(sizes[i]); new_dims.push_back(sizes[i]);

View File

@ -668,7 +668,7 @@ std::vector<torch::lazy::BackendDataPtr> LazyGraphExecutor::SetTensorData(
const std::vector<BackendDataPtr>& tensor_data_vec) { const std::vector<BackendDataPtr>& tensor_data_vec) {
std::vector<BackendDataPtr> tensors_data; std::vector<BackendDataPtr> tensors_data;
tensors_data.reserve(indices.size()); tensors_data.reserve(indices.size());
for (int i = 0; i < indices.size(); i++) { for (const auto i : c10::irange(indices.size())) {
auto index = indices[i]; auto index = indices[i];
LazyTensorPtr& tensor = (*tensors)[index]; LazyTensorPtr& tensor = (*tensors)[index];
// If the config.force_ltc_data flag is true, the purpose of this tensor // If the config.force_ltc_data flag is true, the purpose of this tensor
@ -783,7 +783,8 @@ LazyGraphExecutor::CompilationResult LazyGraphExecutor::Compile(
// TODO(whc) should computation be allowed null here? (because it is in one // TODO(whc) should computation be allowed null here? (because it is in one
// case) // case)
TORCH_CHECK( TORCH_CHECK(
computation->parameters_size() == po_data->parameters_data.size()); computation->parameters_size() ==
static_cast<int>(po_data->parameters_data.size()));
} }
return { return {

View File

@ -321,7 +321,7 @@ std::string MetricFnValue(double value) {
std::string MetricFnBytes(double value) { std::string MetricFnBytes(double value) {
static const std::array<const char*, 6> kSizeSuffixes{ static const std::array<const char*, 6> kSizeSuffixes{
"B", "KB", "MB", "GB", "TB", "PB"}; "B", "KB", "MB", "GB", "TB", "PB"};
int sfix = 0; unsigned sfix = 0;
for (; (sfix + 1) < kSizeSuffixes.size() && value >= 1024.0; ++sfix) { for (; (sfix + 1) < kSizeSuffixes.size() && value >= 1024.0; ++sfix) {
value /= 1024.0; value /= 1024.0;
} }
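Making `sfix` unsigned is enough here because both sides of `(sfix + 1) < kSizeSuffixes.size()` are then unsigned. The helper in roughly standalone form:

```cpp
// Roughly standalone version of the bytes-to-suffix loop; the unsigned
// suffix index compares cleanly against the array's size().
#include <array>
#include <cstdio>

void print_bytes(double value) {
  static const std::array<const char*, 6> kSizeSuffixes{
      "B", "KB", "MB", "GB", "TB", "PB"};
  unsigned sfix = 0;
  for (; (sfix + 1) < kSizeSuffixes.size() && value >= 1024.0; ++sfix) {
    value /= 1024.0;
  }
  std::printf("%.2f%s\n", value, kSizeSuffixes[sfix]);
}

int main() {
  print_bytes(3.5 * 1024 * 1024); // prints 3.50MB
}
```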

View File

@ -82,7 +82,8 @@ std::vector<int64_t> BuildSqueezedDimensions(
std::vector<int64_t> output_dimensions; std::vector<int64_t> output_dimensions;
for (const auto i : c10::irange(dimensions.size())) { for (const auto i : c10::irange(dimensions.size())) {
int64_t dim = dimensions[i]; int64_t dim = dimensions[i];
if (dim != 1 || (i != squeeze_dim && squeeze_dim >= 0)) { if (dim != 1 ||
(static_cast<int64_t>(i) != squeeze_dim && squeeze_dim >= 0)) {
output_dimensions.push_back(dim); output_dimensions.push_back(dim);
} }
} }

View File

@ -76,7 +76,7 @@ c10::SymbolicShape get_symbolic_shape(at::Tensor& tensor) {
sizes.size() == is_symbolic->size(), sizes.size() == is_symbolic->size(),
"Dims of two values are not consistent"); "Dims of two values are not consistent");
std::vector<c10::optional<int64_t>> symbolic_dims; std::vector<c10::optional<int64_t>> symbolic_dims;
for (int64_t i = 0; i < sizes.size(); i++) { for (size_t i = 0; i < sizes.size(); i++) {
if (is_symbolic->at(i)) { if (is_symbolic->at(i)) {
symbolic_dims.emplace_back(c10::nullopt); symbolic_dims.emplace_back(c10::nullopt);
} else { } else {
@ -120,7 +120,7 @@ void applySymbolicShapesOnLT(
TORCH_INTERNAL_ASSERT( TORCH_INTERNAL_ASSERT(
res_symbolic->size() == result_shapes.size(), res_symbolic->size() == result_shapes.size(),
"Result shape size is not consistent"); "Result shape size is not consistent");
for (int64_t i = 0; i < res_symbolic->size(); i++) { for (size_t i = 0; i < res_symbolic->size(); i++) {
auto sym_dims = res_symbolic->at(i).symbolicDims(); auto sym_dims = res_symbolic->at(i).symbolicDims();
if (sym_dims.has_value()) { if (sym_dims.has_value()) {
result_shapes[i] = result_shapes[i].with_symbolic_dims(*sym_dims); result_shapes[i] = result_shapes[i].with_symbolic_dims(*sym_dims);

View File

@ -34,7 +34,7 @@
* *
* 3. How to figure out the shape/dtype * 3. How to figure out the shape/dtype
* ------------------------------------ * ------------------------------------
* Unfortunatley there isn't a one-stop-shop for learning the output shape * Unfortunately there isn't a one-stop-shop for learning the output shape
* formulae for all operators. This is partly because some operators are not * formulae for all operators. This is partly because some operators are not
* part of our 'public' API, including backward operators which users don't * part of our 'public' API, including backward operators which users don't
* directly invoke. * directly invoke.
@ -427,8 +427,8 @@ std::vector<Shape> compute_shape_expand(
const at::Tensor& self, const at::Tensor& self,
at::IntArrayRef size, at::IntArrayRef size,
bool implicit) { bool implicit) {
TORCH_CHECK_GE(size.size(), self.dim()); TORCH_CHECK_GE(static_cast<int64_t>(size.size()), self.dim());
int64_t num_new_dimensions = size.size() - self.dim(); size_t num_new_dimensions = size.size() - self.dim();
std::vector<int64_t> padded_self(num_new_dimensions, 0); std::vector<int64_t> padded_self(num_new_dimensions, 0);
padded_self.insert( padded_self.insert(
padded_self.end(), self.sizes().begin(), self.sizes().end()); padded_self.end(), self.sizes().begin(), self.sizes().end());
@ -443,9 +443,9 @@ std::vector<Shape> compute_shape_expand(
const at::Tensor& self, const at::Tensor& self,
c10::SymIntArrayRef size, c10::SymIntArrayRef size,
bool implicit) { bool implicit) {
TORCH_CHECK_GE(size.size(), self.dim()); TORCH_CHECK_GE(static_cast<int64_t>(size.size()), self.dim());
std::vector<c10::SymInt> _sizes = ToVector<c10::SymInt>(size); std::vector<c10::SymInt> _sizes = ToVector<c10::SymInt>(size);
int64_t num_new_dimensions = _sizes.size() - self.dim(); size_t num_new_dimensions = _sizes.size() - self.dim();
std::vector<int64_t> padded_self(num_new_dimensions, 0); std::vector<int64_t> padded_self(num_new_dimensions, 0);
padded_self.insert( padded_self.insert(
padded_self.end(), self.sizes().begin(), self.sizes().end()); padded_self.end(), self.sizes().begin(), self.sizes().end());
@ -1138,8 +1138,8 @@ std::vector<Shape> compute_shape_stack(at::TensorList tensors, int64_t dim) {
std::vector<Shape> compute_shape_repeat( std::vector<Shape> compute_shape_repeat(
const at::Tensor& self, const at::Tensor& self,
at::IntArrayRef repeats) { at::IntArrayRef repeats) {
TORCH_CHECK_GE(repeats.size(), self.dim()); TORCH_CHECK_GE(static_cast<int64_t>(repeats.size()), self.dim());
int64_t num_new_dimensions = repeats.size() - self.dim(); size_t num_new_dimensions = repeats.size() - self.dim();
std::vector<int64_t> padded_size(num_new_dimensions, 1); std::vector<int64_t> padded_size(num_new_dimensions, 1);
padded_size.insert( padded_size.insert(
padded_size.end(), self.sizes().begin(), self.sizes().end()); padded_size.end(), self.sizes().begin(), self.sizes().end());
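The shape-inference helpers above share one prologue: check that the requested size/repeat list has at least as many entries as the tensor has dims (casting the unsigned `size()` to `int64_t` for the check), then keep the dim-count difference unsigned. A standalone sketch of that padding step with plain vectors instead of `at::Tensor` (expand pads new leading dims with 0; the repeat variant pads with 1):

```cpp
// Sketch with plain vectors: cast the unsigned target size to int64_t for
// the dim-count check, keep the difference as an unsigned count, then
// prepend placeholder dims before the tensor's own sizes.
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

std::vector<int64_t> pad_sizes(const std::vector<int64_t>& self_sizes,
                               const std::vector<int64_t>& target) {
  const int64_t self_dim = static_cast<int64_t>(self_sizes.size());
  assert(static_cast<int64_t>(target.size()) >= self_dim);
  const std::size_t num_new_dimensions = target.size() - self_sizes.size();
  std::vector<int64_t> padded(num_new_dimensions, 0); // leading placeholders
  padded.insert(padded.end(), self_sizes.begin(), self_sizes.end());
  return padded;
}

int main() {
  auto p = pad_sizes({3, 4}, {2, 5, 3, 4}); // -> {0, 0, 3, 4}
  return p.size() == 4 ? 0 : 1;
}
```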

View File

@ -174,7 +174,7 @@ void LazyTensor::TryLimitGraphSize() {
FLAGS_torch_lazy_trim_graph_check_frequency == FLAGS_torch_lazy_trim_graph_check_frequency ==
0) { 0) {
size_t graph_size = Util::GetGraphSize({data()->ir_value.node.get()}); size_t graph_size = Util::GetGraphSize({data()->ir_value.node.get()});
if (graph_size > FLAGS_torch_lazy_trim_graph_size) { if (static_cast<int64_t>(graph_size) > FLAGS_torch_lazy_trim_graph_size) {
TORCH_LAZY_COUNTER("TrimIrGraph", 1); TORCH_LAZY_COUNTER("TrimIrGraph", 1);
ApplyPendingGraph(); ApplyPendingGraph();
} }

View File

@ -221,7 +221,7 @@ void ts_eager_fallback(
// Step 1: Convert all non-eager tensor inputs into eager tensors and put them // Step 1: Convert all non-eager tensor inputs into eager tensors and put them
// on the stack at the correct indices. // on the stack at the correct indices.
for (int64_t idx = 0; idx < arguments.size(); ++idx) { for (size_t idx = 0; idx < arguments.size(); ++idx) {
const auto& ivalue = arguments[idx]; const auto& ivalue = arguments[idx];
if (ivalue.isTensor()) { if (ivalue.isTensor()) {
tensor_args.push_back(ivalue.toTensor()); tensor_args.push_back(ivalue.toTensor());
@ -246,7 +246,7 @@ void ts_eager_fallback(
// CPU together. // CPU together.
auto eager_tensors = to_eager(tensor_args, device_type); auto eager_tensors = to_eager(tensor_args, device_type);
for (auto i = 0; i < tensor_args_indices.size(); ++i) { for (const auto i : c10::irange(tensor_args_indices.size())) {
auto idx = tensor_args_indices[i]; auto idx = tensor_args_indices[i];
(*stack)[arguments_begin + idx] = c10::IValue(eager_tensors[i]); (*stack)[arguments_begin + idx] = c10::IValue(eager_tensors[i]);
} }
@ -257,7 +257,7 @@ void ts_eager_fallback(
// Step 3: We need to take special care to handle mutable aliases properly: // Step 3: We need to take special care to handle mutable aliases properly:
// If any input tensors are mutable aliases, we need to directly copy the // If any input tensors are mutable aliases, we need to directly copy the
// updated data on the eager tensors back to the original inputs. // updated data on the eager tensors back to the original inputs.
for (int64_t i = 0; i < tensor_args_indices.size(); ++i) { for (const auto i : c10::irange(tensor_args_indices.size())) {
auto tensor_idx = tensor_args_indices[i]; auto tensor_idx = tensor_args_indices[i];
const auto alias_info = schema_args[tensor_idx].alias_info(); const auto alias_info = schema_args[tensor_idx].alias_info();
if (alias_info != nullptr && alias_info->isWrite()) { if (alias_info != nullptr && alias_info->isWrite()) {
@ -288,7 +288,7 @@ void ts_eager_fallback(
auto returns = torch::jit::last(stack, num_returns); auto returns = torch::jit::last(stack, num_returns);
const auto returns_begin = stack->size() - num_returns; const auto returns_begin = stack->size() - num_returns;
for (int64_t idx = 0; idx < returns.size(); ++idx) { for (const auto idx : c10::irange(returns.size())) {
if (returns[idx].isTensor()) { if (returns[idx].isTensor()) {
const auto& return_tens = returns[idx].toTensor(); const auto& return_tens = returns[idx].toTensor();
if (return_tens.defined()) { if (return_tens.defined()) {
@ -299,7 +299,7 @@ void ts_eager_fallback(
bool found_alias = false; bool found_alias = false;
// We could store some extra metadata on the function schema to avoid // We could store some extra metadata on the function schema to avoid
// the loop here if we need to improve perf. // the loop here if we need to improve perf.
for (int64_t i = 0; i < tensor_args_indices.size(); ++i) { for (const auto i : c10::irange(tensor_args_indices.size())) {
auto input_tensor_idx = tensor_args_indices[i]; auto input_tensor_idx = tensor_args_indices[i];
const auto& input_tensor = eager_tensors[i]; const auto& input_tensor = eager_tensors[i];
const auto input_alias_info = const auto input_alias_info =

View File

@ -183,7 +183,7 @@ class ExperimentalConfigWrapper {
configss << "ACTIVITIES_WARMUP_PERIOD_SECS=0\n" configss << "ACTIVITIES_WARMUP_PERIOD_SECS=0\n"
<< "CUPTI_PROFILER_METRICS="; << "CUPTI_PROFILER_METRICS=";
for (int i = 0; i < num_metrics; i++) { for (size_t i = 0; i < num_metrics; i++) {
configss << config_.profiler_metrics[i]; configss << config_.profiler_metrics[i];
if (num_metrics > 1 && i < (num_metrics - 1)) { if (num_metrics > 1 && i < (num_metrics - 1)) {
configss << ","; configss << ",";

View File

@ -165,7 +165,7 @@ void PerfProfiler::Enable() {
start_values_.emplace(events_.size(), 0); start_values_.emplace(events_.size(), 0);
auto& sv = start_values_.top(); auto& sv = start_values_.top();
for (int i = 0; i < events_.size(); ++i) { for (unsigned i = 0; i < events_.size(); ++i) {
sv[i] = events_[i].ReadCounter(); sv[i] = events_[i].ReadCounter();
} }
StartCounting(); StartCounting();
@ -182,7 +182,7 @@ void PerfProfiler::Disable(perf_counters_t& vals) {
/* Always connecting this disable event to the last enable event i.e. using /* Always connecting this disable event to the last enable event i.e. using
* whatever is on the top of the start counter value stack. */ * whatever is on the top of the start counter value stack. */
perf_counters_t& sv = start_values_.top(); perf_counters_t& sv = start_values_.top();
for (int i = 0; i < events_.size(); ++i) { for (unsigned i = 0; i < events_.size(); ++i) {
vals[i] = CalcDelta(sv[i], events_[i].ReadCounter()); vals[i] = CalcDelta(sv[i], events_[i].ReadCounter());
} }
start_values_.pop(); start_values_.pop();