mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
[BE] Enforce sign-compare (#96723)
A number of OSS PRs were reverted because of new signed-unsigned comparison warnings, which are treated as errors in some internal builds.
Not sure how those selective rules are applied, but this PR removes `-Wno-sign-compare` from the PyTorch codebase.
The only tricky part in this PR is making sure that non-ASCII character detection works for both signed and unsigned chars here:
6e3d51b08a/torch/csrc/jit/serialization/python_print.cpp (L926)
Exclude several files from sign-compare checks if flash attention is used, due to a violation in CUTLASS, to be fixed by https://github.com/NVIDIA/cutlass/pull/869
Do not try to fix sign-compare violations in the caffe2 codebase.
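For illustration only (this snippet is not part of the commit): a minimal sketch of the kind of comparison that `-Werror=sign-compare` rejects, and the fix pattern the diff applies throughout, typically by switching to an unsigned index, adding a `static_cast`, or using the `c10::irange` helper.

```cpp
// Compile with: g++ -std=c++17 -Wall -Werror=sign-compare example.cpp
#include <cstddef>
#include <vector>

int count_positive(const std::vector<int>& v) {
  int count = 0;
  // BAD: 'int i' (signed) compared against 'v.size()' (size_t, unsigned)
  // triggers -Wsign-compare, which this PR turns into a hard error:
  //   for (int i = 0; i < v.size(); ++i) { ... }

  // OK: give both sides the same signedness (an unsigned index here;
  // the PyTorch diff often uses c10::irange or a static_cast instead).
  for (std::size_t i = 0; i < v.size(); ++i) {
    if (v[i] > 0) {
      ++count;
    }
  }
  return count;
}
```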
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96723
Approved by: https://github.com/albanD
Committed by: PyTorch MergeBot
Parent: 96c745dfdc
Commit: a229e78544
@@ -39,7 +39,7 @@ namespace native {
 c10::SmallVector<std::string> get_extra_args_typenames(const c10::SmallVector<at::Scalar>& extra_args) {
 c10::SmallVector<std::string> args_typenames(extra_args.size());
-for (auto i = 0; i < extra_args.size(); ++i) {
+for (const auto i : c10::irange(extra_args.size())) {
 args_typenames[i] = at::cuda::jit::typeName(extra_args[i].type());
 }
 return args_typenames;

@@ -333,7 +333,7 @@ void conv_depthwise_shape_check(
 if (grad_output.defined()) {
 auto expected_output_size = conv_output_size(input.sizes(), weight.sizes(),
 padding, stride, dilation);
-TORCH_CHECK(grad_output.dim() == expected_output_size.size(),
+TORCH_CHECK(static_cast<size_t>(grad_output.dim()) == expected_output_size.size(),
 "Expect grad_output to be ",
 expected_output_size.size(), "D, got ",
 grad_output.dim(), "D.");

@@ -132,7 +132,7 @@ FOREACH_BINARY_OP_SCALARLIST(all_types_complex_half_bfloat16, pow, power_functor
 // In the case of subtraction, we dont allow scalar to be boolean following the torch.sub logic
 void foreach_tensor_sub_scalarlist_kernel_cuda_(TensorList tensors, at::ArrayRef<Scalar> scalars) {
 check_foreach_api_restrictions(tensors, scalars);
-for (int i = 0; i < tensors.size(); i++) {
+for (const auto i: c10::irange(tensors.size())) {
 sub_check(tensors[i], scalars[i]);
 }

@@ -147,7 +147,7 @@ void foreach_tensor_sub_scalarlist_kernel_cuda_(TensorList tensors, at::ArrayRef
 std::vector<Tensor> foreach_tensor_sub_scalarlist_kernel_cuda(TensorList tensors, at::ArrayRef<Scalar> scalars) {
 check_foreach_api_restrictions(tensors, scalars);
-for (int i = 0; i < tensors.size(); i++) {
+for (const auto i: c10::irange(tensors.size())) {
 sub_check(tensors[i], scalars[i]);
 }

@@ -53,7 +53,7 @@ static void launch_kernel(int64_t N, const func_t& f) {
 template <typename func_t>
 void gpu_index_kernel(TensorIteratorBase& iter, IntArrayRef index_size, IntArrayRef index_stride, const func_t& f) {
 int num_indices = index_size.size();
-AT_ASSERT(num_indices == index_stride.size());
+AT_ASSERT(static_cast<size_t>(num_indices) == index_stride.size());
 AT_ASSERT(num_indices == iter.ntensors() - 2);
 if (iter.numel() == 0) {
@@ -226,8 +226,8 @@ std::tuple<Tensor, Tensor> ctc_loss_gpu_template(const Tensor& log_probs, const
 int64_t batch_size = log_probs.size(1);
 int64_t num_labels = log_probs.size(2);
 TORCH_CHECK((0 <= BLANK) && (BLANK < num_labels), "blank must be in label range");
-TORCH_CHECK(input_lengths.size() == batch_size, "input_lengths must be of size batch_size");
-TORCH_CHECK(target_lengths.size() == batch_size, "target_lengths must be of size batch_size");
+TORCH_CHECK(input_lengths.size() == static_cast<size_t>(batch_size), "input_lengths must be of size batch_size");
+TORCH_CHECK(target_lengths.size() == static_cast<size_t>(batch_size), "target_lengths must be of size batch_size");
 int64_t tg_target_stride;

@@ -174,7 +174,7 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
 // Now we loop
 int batchCounter = 0;
 int64_t offset = 0;
-for (int i = 0; i < inputs.size() ; i += batch_size) {
+for (unsigned i = 0; i < inputs.size() ; i += batch_size) {
 for (batchCounter = 0;
 batchCounter < batch_size &&
 (i+batchCounter) < inputs.size();

@@ -44,7 +44,7 @@ struct HermitianSymmetryOffsetCalculator {
 }
 mirror_dim_ = 0;
-for (int64_t i = 0; i < dim.size(); ++i) {
+for (const auto i: c10::irange(dim.size())) {
 mirror_dim_ |= (uint32_t{1} << dim[i]);
 }
 }

@@ -258,7 +258,7 @@ bool CUDA_tensor_histogram(
 memType = CUDAHistogramMemoryType::SHARED;
 } else if (
 nbins < THRESH_NUMBER_BINS_FOR_GLOBAL_MEM &&
-multiBlockMem < (maxGlobalMem / 2)) {
+multiBlockMem < static_cast<size_t>(maxGlobalMem / 2)) {
 // check against half of free mem to be extra safe
 // due to cached allocator, we may anyway have slightly more free mem
 memType = CUDAHistogramMemoryType::MULTI_BLOCK;
@@ -141,7 +141,7 @@ void calculate_mode(
 // to calculate the mode for --> we do this by manually doing the stride
 // calculations to get an offset
 scalar_t* data = self.data_ptr<scalar_t>();
-for (int64_t i = 0; i < position.size(); i++) {
+for (int64_t i = 0; i < static_cast<int64_t>(position.size()); i++) {
 data += position[i] * ensure_nonempty_stride(self, i);
 }

@@ -159,7 +159,7 @@ void calculate_mode(
 scalar_t* values_data = values.data_ptr<scalar_t>();
 int64_t* indices_data = indices.data_ptr<int64_t>();
-for (int64_t i = 0; i < position.size(); i++) {
+for (int64_t i = 0; i < static_cast<int64_t>(position.size()); i++) {
 int64_t pos = position[i];
 values_data += ensure_nonempty_stride(values, i) * pos;
 indices_data += ensure_nonempty_stride(indices, i) * pos;

@@ -796,7 +796,7 @@ void LayerNormKernelImplInternal(
 constexpr int num_vec_elems = vec_size;
 constexpr int alignment = num_vec_elems * sizeof(T);
 if ((std::is_same<T, float>::value || std::is_same<T, at::Half>::value || std::is_same<T, at::BFloat16>::value) &&
-N <= 1ULL << std::numeric_limits<float>::digits && N % num_vec_elems == 0 &&
+N <= static_cast<int64_t>(1ULL << std::numeric_limits<float>::digits) && N % num_vec_elems == 0 &&
 can_vectorize(X_data, alignment) && can_vectorize(Y_data, alignment)) {
 launch_vectorized_layer_norm_kernel(static_cast<int>(N), M, eps, X_data, gamma_data, beta_data, Y_data, mean_data, rstd_data);
 } else {

@@ -1356,10 +1356,10 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_cuda(
 const size_t axis = input.dim() - normalized_shape.size();
 std::vector<int64_t> stat_shape;
-for (size_t idx = 0; idx < axis; ++idx) {
+for (const auto idx: c10::irange(axis)) {
 stat_shape.push_back(input_shape[idx]);
 }
-for (size_t idx = axis; idx < input.dim(); ++idx) {
+for (const auto C10_UNUSED idx: c10::irange(axis, input.dim())) {
 stat_shape.push_back(1);
 }
@@ -373,7 +373,7 @@ void generate_and_filter_plans(const cudnnHandle_t handle, cudnn_frontend::Opera
 if (remove_invalid) {
 cudnn_frontend::executionPlans_t new_valid_plans;
 for (auto &plan : valid_plans) {
-if (plan.getWorkspaceSize() <= max_workspace_size) {
+if (static_cast<size_t>(plan.getWorkspaceSize()) <= max_workspace_size) {
 new_valid_plans.emplace_back(std::move(plan));
 }
 }

@@ -35,7 +35,7 @@ size_t compute_strided_size(const at::Tensor& t) {
 }
 bool is_strided_contiguous(const at::Tensor& t) {
-return compute_strided_size(t) == t.numel();
+return compute_strided_size(t) == static_cast<size_t>(t.numel());
 }
 // Copy sourceBuffer into destBuffer, casting sourceBuffer to src.scalar_type().

@@ -156,11 +156,11 @@ static void validateInputData(const TensorIteratorBase& iter,
 bool accumulate) {
 using namespace mps;
-int64_t num_indices = index_size.size();
+const auto num_indices = index_size.size();
 TORCH_CHECK(num_indices <= 16, "Current limit allows up to 16 indices to be used in MPS indexing kernels");
 AT_ASSERT(num_indices == index_stride.size());
-AT_ASSERT(num_indices == iter.ntensors() - 2);
+AT_ASSERT(static_cast<int>(num_indices) == iter.ntensors() - 2);
 const Tensor& inputTensor = iter.tensor(1);
 if (accumulate) {

@@ -589,8 +589,8 @@ Tensor index_select_mps(const Tensor& self, int64_t dim, const Tensor& index) {
 std::vector<int64_t> shape_data(num_input_dims);
 // Calculate new shape
-for (auto i : c10::irange(num_input_dims)) {
-if (i == dim) {
+for (const auto i : c10::irange(num_input_dims)) {
+if (i == static_cast<decltype(i)>(dim)) {
 shape_data[i] = num_indices;
 } else {
 shape_data[i] = input_shape[i];
@@ -1000,21 +1000,21 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_backward_mps(const Tensor& grad_ou
 NSMutableArray<NSNumber*>* gamma_axes = [NSMutableArray<NSNumber*> arrayWithCapacity:num_channel_dims];
-for (int i = 0; i < num_channel_dims; i++)
-gamma_axes[i] = [NSNumber numberWithInt:i];
+for (const auto i : c10::irange(num_channel_dims))
+gamma_axes[i] = [NSNumber numberWithInt:static_cast<int>(i)];
 // Axes along which to reduce to get "batch norm" gradient
 // This will be applied on shape [1, M, -1]
 NSMutableArray<NSNumber*>* bn_axes = [NSMutableArray<NSNumber*> arrayWithCapacity:num_normalized_dims];
-for (int i = 0; i < num_normalized_dims; i++)
-bn_axes[i] = [NSNumber numberWithInt:(1 + 1 + i)];
+for (const auto i : c10::irange(num_normalized_dims))
+bn_axes[i] = [NSNumber numberWithInt:static_cast<int>(1 + 1 + i)];
 // Shape of input to do "batch norm" backward
 // This is [1, M, -1]
 NSMutableArray<NSNumber*>* bn_shape = [NSMutableArray<NSNumber*> arrayWithCapacity:(num_normalized_dims + 2)];
 bn_shape[0] = [NSNumber numberWithInt:1];
 bn_shape[1] = [NSNumber numberWithInt:M];
-for (int i = 0; i < num_normalized_dims; i++)
+for (const auto i : c10::irange(num_normalized_dims))
 bn_shape[i + 2] = input_shape[i + num_channel_dims];
 // Shape of mean to do "batch norm" backward

@@ -1023,7 +1023,7 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_backward_mps(const Tensor& grad_ou
 [NSMutableArray<NSNumber*> arrayWithCapacity:(num_normalized_dims + 2)];
 bn_mean_shape[0] = [NSNumber numberWithInt:1];
 bn_mean_shape[1] = [NSNumber numberWithInt:M];
-for (int i = 0; i < num_normalized_dims; i++)
+for (const auto i : c10::irange(num_normalized_dims))
 bn_mean_shape[i + 2] = [NSNumber numberWithInt:1];
 // Shape of gamma to multiply with "batch norm" backward

@@ -1032,7 +1032,7 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_backward_mps(const Tensor& grad_ou
 [NSMutableArray<NSNumber*> arrayWithCapacity:(num_normalized_dims + 2)];
 bn_gamma_shape[0] = [NSNumber numberWithInt:1];
 bn_gamma_shape[1] = [NSNumber numberWithInt:1];
-for (int i = 0; i < num_normalized_dims; i++)
+for (const auto i : c10::irange(num_normalized_dims))
 bn_gamma_shape[i + 2] = input_shape[i + num_channel_dims];
 string key = "layer_norm_backward_mps:" + std::to_string(has_weight) + ":" +
@@ -136,8 +136,9 @@ void reduction_out_mps(const Tensor& input_t,
 IntArrayRef dim = opt_dim.value();
 for (const auto dim_val : dim) {
 auto wrap_dim = maybe_wrap_dim(dim_val, input_shape.size());
-TORCH_CHECK(wrap_dim < (input_shape.size() == 0 ? input_t.numel() : input_shape.size()),
-func_name + ": reduction dim must be in the range of input shape")
+TORCH_CHECK(
+wrap_dim < static_cast<decltype(wrap_dim)>(input_shape.size() == 0 ? input_t.numel() : input_shape.size()),
+func_name + ": reduction dim must be in the range of input shape")
 }
 }

@@ -395,7 +396,8 @@ void impl_func_norm_mps(const Tensor& input_tensor,
 for (const auto dim_val : dim) {
 auto wrap_dim = maybe_wrap_dim(dim_val, input_shape.size());
-TORCH_CHECK(wrap_dim < input_shape.size(), "norm_out_mps: reduction dim must be in the range of input shape")
+TORCH_CHECK(wrap_dim < static_cast<decltype(wrap_dim)>(input_shape.size()),
+"norm_out_mps: reduction dim must be in the range of input shape")
 }
 auto cache_ = MPSGraphCache::getInstance();

@@ -663,8 +665,8 @@ Tensor std_var_common_impl_mps(const Tensor& input_t,
 string errMessage = (stdVarType == STANDARD_DEVIATION) ? "std_mps" : "var_mps";
 errMessage += ": reduction dim must be in the range of input shape";
 for (const auto dim : dim_value) {
-auto wrap_dim = maybe_wrap_dim(dim, input_shape.size());
-TORCH_CHECK(wrap_dim < input_shape.size(), errMessage.c_str())
+auto wrap_dim = maybe_wrap_dim(dim, num_input_dims);
+TORCH_CHECK(wrap_dim < static_cast<decltype(wrap_dim)>(input_shape.size()), errMessage.c_str())
 }
 }

@@ -207,7 +207,7 @@ void computeRepeatIndices(index_t* repeat_ptr,
 [computeEncoder setBytes:&size length:sizeof(size) atIndex:3];
 MTLSize gridSize = MTLSizeMake(size, 1, 1);
 NSUInteger threadsPerThreadgroup_ = pipelineState.maxTotalThreadsPerThreadgroup;
-if (threadsPerThreadgroup_ > size) {
+if (threadsPerThreadgroup_ > static_cast<NSUInteger>(size)) {
 threadsPerThreadgroup_ = size;
 }
 MTLSize threadsPerThreadgroup = MTLSizeMake(threadsPerThreadgroup_, 1, 1);
@@ -17,7 +17,7 @@ namespace at::native {
 std::vector<long long> getTensorShape(MPSGraphTensor* mpsTensor) {
 std::vector<long long> output_dimensions = {};
 auto dims = mpsTensor.shape;
-for (int i = 0; i < [dims count]; i++) {
+for (NSUInteger i = 0; i < [dims count]; i++) {
 output_dimensions.push_back([dims[i] intValue]);
 }
 return output_dimensions;

@@ -97,7 +97,7 @@ std::array<MPSGraphTensor*, 4> buildUniqueGraph(const Tensor& self,
 if (dimOpt.has_value() && [shape count] != 1) {
 NSMutableArray* axes = [[NSMutableArray alloc] initWithCapacity:[shape count] - 1];
 for (const auto axis : c10::irange([shape count])) {
-if (axis != dim) {
+if (static_cast<decltype(dim)>(axis) != dim) {
 [axes addObject:[NSNumber numberWithUnsignedInteger:axis]];
 }
 }

@@ -70,7 +70,7 @@ static Tensor& runViewGraph(ViewCachedGraph* cachedGraph, const at::Tensor& src,
 feeds[cachedGraph->storageOffsetTensor] = getMPSGraphTensorFromScalar(stream, storageOffsetScalar);
 std::vector<MPSScalar> strideScalars(sizes.size());
-for (int i = 0; i < sizes.size(); i++) {
+for (const auto i : c10::irange(sizes.size())) {
 strideScalars[i] = getMPSScalar(strides[i], ScalarType::Int);
 feeds[cachedGraph->strideTensors[i]] = getMPSGraphTensorFromScalar(stream, strideScalars[i]);
 }

@@ -133,7 +133,7 @@ NSDictionary* getStrideToDimLengthOffsetDict(MPSGraphTensor* tensor, NSUInteger
 // Detect only expand dims, allows for duplicate strides
 MPSGraphTensor* asStridedLayer_expandDimsPattern(MPSGraph* graph,
 MPSGraphTensor* inputTensor,
-int dstRank,
+size_t dstRank,
 const IntArrayRef& dstSizes,
 const IntArrayRef& dstStrides,
 int offset) {
@@ -185,7 +185,7 @@ MPSGraphTensor* asStridedLayer_expandDimsPattern(MPSGraph* graph,
 // Detect contiguous reshapes, no slicing
 MPSGraphTensor* asStridedLayer_reshapePattern(MPSGraph* graph,
 MPSGraphTensor* inputTensor,
-int dstRank,
+size_t dstRank,
 const IntArrayRef& dstSizes,
 const IntArrayRef& dstStrides,
 int offset) {

@@ -228,7 +228,7 @@ MPSGraphTensor* asStridedLayer_reshapePattern(MPSGraph* graph,
 MPSGraphTensor* asStridedLayer_genericPattern(MPSGraph* graph,
 MPSGraphTensor* inputTensor,
-int dstRank,
+size_t dstRank,
 const IntArrayRef& dstSizes,
 const IntArrayRef& dstStrides,
 int offset) {

@@ -236,7 +236,7 @@ MPSGraphTensor* asStridedLayer_genericPattern(MPSGraph* graph,
 {
 BOOL allUnique = YES;
 NSMutableSet* uniqueStrides = [[NSMutableSet alloc] init];
-for (NSInteger dstDim = 0; (dstDim < dstRank) && allUnique; dstDim++) {
+for (NSUInteger dstDim = 0; (dstDim < dstRank) && allUnique; dstDim++) {
 int stride = dstStrides[dstDim];
 NSNumber* strideObj = [NSNumber numberWithInt:stride];
 allUnique &= (stride == 0 || ![uniqueStrides containsObject:strideObj]);

@@ -247,7 +247,7 @@ MPSGraphTensor* asStridedLayer_genericPattern(MPSGraph* graph,
 return nil;
 // Skip for zero in dst shape
-for (NSInteger dstDim = 0; dstDim < dstRank; dstDim++)
+for (NSUInteger dstDim = 0; dstDim < dstRank; dstDim++)
 if (dstSizes[dstDim] == 0) {
 return nil;
 }

@@ -277,7 +277,7 @@ MPSGraphTensor* asStridedLayer_genericPattern(MPSGraph* graph,
 std::vector<int32_t> dstDimToSliceOffset(dstRank);
 bool needsBroadcast = false;
 {
-for (NSInteger dstDim = dstRank - 1; dstDim >= 0; dstDim--) {
+for (auto dstDim = dstRank - 1; dstDim >= 0; dstDim--) {
 if (dstStrides[dstDim] == 0) {
 // This dimension should be a broadcast
 needsBroadcast = true;
@@ -318,7 +318,7 @@ MPSGraphTensor* asStridedLayer_genericPattern(MPSGraph* graph,
 [missingSrcStrides addObject:[NSNumber numberWithInteger:stride]];
 stride *= [[flatInputTensor shape][srcDim] integerValue];
 }
-for (NSInteger dstDim = 0; dstDim < dstRank; dstDim++) {
+for (NSUInteger dstDim = 0; dstDim < dstRank; dstDim++) {
 [missingSrcStrides removeObject:[NSNumber numberWithInteger:dstStrides[dstDim]]];
 }
 }

@@ -344,7 +344,7 @@ MPSGraphTensor* asStridedLayer_genericPattern(MPSGraph* graph,
 // TODO: Use Transpose API
 BOOL needsTranspose = NO;
 for (NSUInteger dstDim = 0; dstDim < [dstDimOrder count] && !needsTranspose; dstDim++)
-needsTranspose |= ([dstDimOrder[dstDim] intValue] != dstDim);
+needsTranspose |= ([dstDimOrder[dstDim] intValue] != static_cast<int>(dstDim));
 if (needsTranspose)
 transposedTensor = permuteTensor(graph, transposedTensor, dstDimOrder);
 }

@@ -385,7 +385,7 @@ MPSGraphTensor* asStridedLayer_genericPattern(MPSGraph* graph,
 if (needsBroadcast) {
 NSMutableArray* broadcastShape = [[NSMutableArray alloc] init];
 NSMutableArray* expandAxes = [[NSMutableArray alloc] init];
-for (NSInteger dstDim = 0; dstDim < dstRank; dstDim++) {
+for (NSUInteger dstDim = 0; dstDim < dstRank; dstDim++) {
 [broadcastShape addObject:[NSNumber numberWithInt:dstSizes[dstDim]]];
 if (dstStrides[dstDim] == 0)
 [expandAxes addObject:[NSNumber numberWithInt:dstDim]];

@@ -408,7 +408,7 @@ MPSGraphTensor* asStridedLayer_genericPattern(MPSGraph* graph,
 MPSGraphTensor* asStridedLayer_pattern(MPSGraph* graph,
 MPSGraphTensor* inputTensor,
-int dstRank,
+size_t dstRank,
 const IntArrayRef& dstSizes,
 const IntArrayRef& dstStrides,
 int offset) {

@@ -503,7 +503,7 @@ MPSGraphTensorData* getMPSGraphTensorDataForView(const Tensor& src, MPSShape* mp
 MPSNDArrayDescriptor* srcTensorNDArrayDesc = nil;
 MPSNDArray* srcTensorNDArray = nil;
 id<MTLCommandBuffer> commandBuffer = getCurrentMPSStream()->commandBuffer();
-int64_t base_idx = 0;
+size_t base_idx = 0;
 std::vector<int64_t> src_base_shape_vec;
@@ -574,7 +574,7 @@ static MPSGraphTensor* chainViewOperation(ViewCachedGraph* cachedGraph,
 @autoreleasepool {
 std::vector<int32_t> sizeArray(shape_size);
 const int64_t int_max = std::numeric_limits<int32_t>::max();
-for (int i = 0; i < shape_size; i++) {
+for (const auto i : c10::irange(shape_size)) {
 TORCH_CHECK(size[i] <= int_max);
 sizeArray[i] = static_cast<int32_t>(size[i]);
 }

@@ -584,7 +584,7 @@ static MPSGraphTensor* chainViewOperation(ViewCachedGraph* cachedGraph,
 dataType:MPSDataTypeInt32];
 MPSGraphTensor* indicesTensor = nil;
 // create stride Tensors for each rank of the input tensor
-for (int i = 0; i < shape_size; i++) {
+for (int i = 0; i < static_cast<int>(shape_size); i++) {
 MPSGraphTensor* rangeTensor = [mpsGraph coordinateAlongAxis:(-i - 1) withShapeTensor:shapeTensor name:nil];
 MPSGraphTensor* strideTensor = cachedGraph->strideTensors[shape_size - i - 1];
 MPSGraphTensor* indexTensor = [mpsGraph multiplicationWithPrimaryTensor:rangeTensor

@@ -702,7 +702,7 @@ static ViewCachedGraph* createViewGraph(const Tensor& self,
 // Self is the input tensor we are creating view of
 newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, inputType, getMPSShape(base_shape));
 newCachedGraph->storageOffsetTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[ @1 ]);
-for (int i = 0; i < size.size(); i++) {
+for (const auto C10_UNUSED i : c10::irange(size.size())) {
 newCachedGraph->strideTensors.push_back(mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[ @1 ]));
 }
 if (needsScatter) {

@@ -837,7 +837,7 @@ Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst) {
 if (kernel_size == 0) {
 src_sizes[0] = src_strides[0] = 1;
 } else {
-for (int i = 0; i < kernel_size; i++) {
+for (const auto i : c10::irange(kernel_size)) {
 src_sizes[i] = (uint32_t)(src.sizes()[i]);
 src_strides[i] = (uint32_t)(src.strides()[i]);
 }
@@ -656,6 +656,15 @@ if(USE_CUDA)
 set_source_files_properties(${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/interface.cpp PROPERTIES COMPILE_FLAGS "-DUSE_CUDA=1")
 endif()
+if(USE_FLASH_ATTENTION AND NOT MSVC)
+# Cutlass contains a sign-compare violation in its codebase to be fixed by https://github.com/NVIDIA/cutlass/pull/869
+set_source_files_properties(
+"${PROJECT_SOURCE_DIR}/aten/src/ATen/native/nested/cuda/NestedTensorMatmul.cu"
+"${PROJECT_SOURCE_DIR}/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cu"
+"${PROJECT_SOURCE_DIR}/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_bwd_hdim128.cu"
+PROPERTIES COMPILE_FLAGS "-Wno-sign-compare")
+endif()
 if(BUILD_ONEDNN_GRAPH)
 list(APPEND Caffe2_CPU_SRCS
 ${TORCH_SRC_DIR}/csrc/jit/codegen/onednn/LlgaTensorImpl.cpp

@@ -811,9 +820,10 @@ if(HAVE_SOVERSION)
 VERSION ${TORCH_VERSION} SOVERSION ${TORCH_SOVERSION})
 endif()
 torch_compile_options(torch_cpu) # see cmake/public/utils.cmake
-if(HAS_WERROR_SIGN_COMPARE AND WERROR)
-# target_compile_options(torch_cpu PRIVATE "-Werror=sign-compare")
-set_property(SOURCE ${ATen_CORE_SRCS} ${ATen_CPU_SRCS} APPEND PROPERTY COMPILE_OPTIONS "-Werror=sign-compare")
+if(BUILD_CAFFE2 AND NOT MSVC)
+# Caffe2 has too many signed-unsigned violation, but the framework is dead
+# So no point in fixing those
+target_compile_options(torch_cpu PRIVATE "-Wno-sign-compare")
 endif()
 set_property(SOURCE ${ATen_CORE_SRCS} APPEND
@@ -17,8 +17,8 @@ void BoxCoxNaive(
 T* output_ptr) {
 constexpr T k_eps = static_cast<T>(1e-6);
-for (int64_t i = 0; i < N; i++) {
-for (int64_t j = 0; j < D; j++, data_ptr++, output_ptr++) {
+for (std::size_t i = 0; i < N; i++) {
+for (std::size_t j = 0; j < D; j++, data_ptr++, output_ptr++) {
 T lambda1_v = lambda1_ptr[j];
 T lambda2_v = lambda2_ptr[j];
 T tmp = std::max(*data_ptr + lambda2_v, k_eps);

@@ -36,8 +36,7 @@ void FloatToFused8BitRowwiseQuantized__base(
 output_row_scale_bias[0] = range / 255.0f;
 output_row_scale_bias[1] = minimum_element;
 const auto inverse_scale = 255.0f / (range + kEpsilon);
-// NOLINTNEXTLINE(clang-diagnostic-sign-compare)
-for (std::size_t col = 0; col < input_columns; ++col) {
+for (std::size_t col = 0; col < static_cast<size_t>(input_columns); ++col) {
 output_row[col] =
 std::lrintf((input_row[col] - minimum_element) * inverse_scale);
 }

@@ -58,8 +57,7 @@ void Fused8BitRowwiseQuantizedToFloat__base(
 reinterpret_cast<const float*>(input_row + output_columns);
 float* output_row = output + row * output_columns;
-// NOLINTNEXTLINE(clang-diagnostic-sign-compare)
-for (std::size_t col = 0; col < output_columns; ++col) {
+for (std::size_t col = 0; col < static_cast<std::size_t>(output_columns); ++col) {
 output_row[col] =
 // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
 input_row[col] * input_row_scale_bias[0] + input_row_scale_bias[1];

@@ -136,8 +134,7 @@ void FloatToFusedNBitRowwiseQuantizedSBHalf__base(
 output_row_scale_bias[0] = scale;
 output_row_scale_bias[1] = minimum_element;
-// NOLINTNEXTLINE(clang-diagnostic-sign-compare)
-for (std::size_t col = 0; col < input_columns; ++col) {
+for (std::size_t col = 0; col < static_cast<size_t>(input_columns); ++col) {
 float X = input_row[col];
 std::uint8_t quantized = std::max(
 0,

@@ -165,7 +162,7 @@ void FusedNBitRowwiseQuantizedSBHalfToFloat__base(
 // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
 (input_columns - 2 * sizeof(at::Half)) * num_elem_per_byte;
-for (std::size_t row = 0; row < input_rows; ++row) {
+for (std::size_t row = 0; row < static_cast<size_t>(input_rows); ++row) {
 const std::uint8_t* input_row = input + row * input_columns;
 const at::Half* input_row_scale_bias = reinterpret_cast<const at::Half*>(
 input_row +

@@ -174,8 +171,7 @@ void FusedNBitRowwiseQuantizedSBHalfToFloat__base(
 float bias = input_row_scale_bias[1];
 float* output_row = output + row * output_columns;
-// NOLINTNEXTLINE(clang-diagnostic-sign-compare)
-for (std::size_t col = 0; col < output_columns; ++col) {
+for (std::size_t col = 0; col < static_cast<std::size_t>(output_columns); ++col) {
 std::uint8_t quantized = input_row[col / num_elem_per_byte];
 quantized >>= (col % num_elem_per_byte) * bit_rate;
 quantized &= (1 << bit_rate) - 1;
@@ -150,8 +150,7 @@ void PyTorchStreamReader::init() {
 version,
 " as Long Long.");
 }
-// NOLINTNEXTLINE(clang-diagnostic-sign-compare)
-if (version_ < kMinSupportedFileFormatVersion) {
+if (version_ < static_cast<decltype(version_)>(kMinSupportedFileFormatVersion)) {
 CAFFE_THROW(
 "Attempted to read a PyTorch file with version ",
 c10::to_string(version_),

@@ -161,8 +160,7 @@ void PyTorchStreamReader::init() {
 " with latest version of PyTorch to mitigate this issue.");
 }
-// NOLINTNEXTLINE(clang-diagnostic-sign-compare)
-if (version_ > kMaxSupportedFileFormatVersion) {
+if (version_ > static_cast<decltype(version_)>(kMaxSupportedFileFormatVersion)) {
 CAFFE_THROW(
 "Attempted to read a PyTorch file with version ",
 version_,

@@ -443,7 +443,6 @@ function(torch_compile_options libname)
 -Wno-type-limits
 -Wno-array-bounds
 -Wno-unknown-pragmas
--Wno-sign-compare
 -Wno-strict-overflow
 -Wno-strict-aliasing
 -Wno-error=deprecated-declarations
@@ -718,7 +718,7 @@ FW_DERIVATIVE_SETTER_TENSOR_LIST = CodeTemplate(
 if (${out_arg}_new_fw_grad_opt.has_value()) {
 auto ${out_arg}_new_fw_grad = ${out_arg}_new_fw_grad_opt.value();
 TORCH_INTERNAL_ASSERT(${out_arg}.size() == ${out_arg}_new_fw_grad.size());
-for (auto i=0; i<${out_arg}.size(); ++i) {
+for (const auto i : c10::irange(${out_arg}.size())) {
 if (${out_arg}_new_fw_grad[i].defined() && ${out_arg}[i].defined()) {
 // The hardcoded 0 here will need to be updated once we support multiple levels.
 ${out_arg}[i]._set_fw_grad(${out_arg}_new_fw_grad[i], /* level */ 0, /* is_inplace_op */ ${is_inplace});

@@ -178,12 +178,10 @@ std::vector<int64_t> ConvTransposeNdImpl<D, Derived>::_output_padding(
 ret = at::IntArrayRef(this->options.output_padding()).vec();
 } else {
 auto k = input.dim() - 2;
-// NOLINTNEXTLINE(clang-diagnostic-sign-compare)
-if (output_size_.value().size() == k + 2) {
+if (output_size_.value().size() == static_cast<size_t>(k + 2)) {
 output_size_ = output_size_.value().slice(2);
 }
-// NOLINTNEXTLINE(clang-diagnostic-sign-compare)
-if (output_size_.value().size() != k) {
+if (output_size_.value().size() != static_cast<size_t>(k)) {
 TORCH_CHECK(
 false,
 "output_size must have ",
@@ -192,7 +192,7 @@ struct AddGenericMetadata : public MetadataBase {
 if (config_ && !config_->experimental_config.performance_events.empty()) {
 auto& event_names = config_->experimental_config.performance_events;
-for (auto i = 0; i < op_event.perf_event_counters_->size(); ++i) {
+for (const auto i : c10::irange(op_event.perf_event_counters_->size())) {
 addMetadata(
 event_names[i],
 std::to_string((*op_event.perf_event_counters_)[i]));

@@ -251,8 +251,7 @@ std::vector<at::Tensor>& scatter_out(
 out_tensors[i].device(),
 "'");
 auto out_sizes = out_tensors[i].sizes().vec();
-// NOLINTNEXTLINE(clang-diagnostic-sign-compare)
-bool same_ndim = out_sizes.size() == tensor.dim();
+bool same_ndim = out_sizes.size() == static_cast<size_t>(tensor.dim());
 if (same_ndim) {
 total_size += out_sizes[dim];
 chunk_sizes.emplace_back(out_sizes[dim]);

@@ -265,7 +265,7 @@ void check_inputs(
 int root,
 int input_multiplier,
 int output_multiplier) {
-size_t len = inputs.size();
+auto len = inputs.size();
 if (len <= 0) {
 throw std::runtime_error("input sequence can't be empty");

@@ -280,7 +280,8 @@ void check_inputs(
 check_tensor(
 input,
-i == root ? at::optional<at::Tensor>{output} : at::nullopt,
+i == static_cast<decltype(i)>(root) ? at::optional<at::Tensor>{output}
+: at::nullopt,
 input_multiplier,
 output_multiplier,
 numel,
@@ -482,7 +483,7 @@ void reduce(
 ncclComm_t comm = comms_ref[i];
 NCCL_CHECK(ncclReduce(
 inputs[i].data_ptr(),
-root == i ? output.data_ptr() : nullptr,
+static_cast<decltype(i)>(root) == i ? output.data_ptr() : nullptr,
 count,
 data_type,
 to_nccl_red_op(op),

@@ -124,8 +124,7 @@ std::unique_ptr<RpcWithProfilingResp> RpcWithProfilingResp::fromMessage(
 for (const auto i : c10::irange(
 kProfileEventsStartIdx,
 kProfileEventsStartIdx + profiledEventsSize)) {
-// NOLINTNEXTLINE(clang-diagnostic-sign-compare)
-TORCH_CHECK(i < tupleElements.size());
+TORCH_CHECK(static_cast<size_t>(i) < tupleElements.size());
 // Reconstruct remote event from the ivalues.
 torch::autograd::profiler::LegacyEvent fromIvalueEvent =
 torch::autograd::profiler::LegacyEvent::fromIValue(tupleElements[i]);

@@ -1985,7 +1985,7 @@ c10::intrusive_ptr<Work> ProcessGroupGloo::allgather_coalesced(
 invalidArgument("requires non-empty input tensor list");
 }
-if (output_lists.size() != getSize()) {
+if (output_lists.size() != static_cast<size_t>(getSize())) {
 invalidArgument("output lists should be equal to world size");
 }

@@ -2813,7 +2813,8 @@ void ProcessGroupGloo::monitoredBarrier(
 // some ranks have not responded.
 // Ensure all ranks from 1, ... WORLD_SIZE -1 have been successfully
 // processed.
-auto rankFailure = (processedRanks.size() != size_ - 1);
+auto rankFailure =
+(processedRanks.size() != static_cast<size_t>(size_ - 1));
 if (waitAllRanks && rankFailure) {
 std::vector<int> failedRanks;
 for (const auto i : c10::irange(1, size_)) {
@@ -774,10 +774,10 @@ c10::intrusive_ptr<Work> ProcessGroupMPI::alltoall(
 std::vector<at::Tensor>& inputTensors,
 const AllToAllOptions& opts) {
 TORCH_CHECK(
-inputTensors.size() == size_,
+inputTensors.size() == static_cast<size_t>(size_),
 "Number of input tensors are not equal to group size");
 TORCH_CHECK(
-outputTensors.size() == size_,
+outputTensors.size() == static_cast<size_t>(size_),
 "Number of output tensors are not equal to group size");
 std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
 [this](std::unique_ptr<WorkEntry>& entry) {

@@ -154,7 +154,7 @@ struct CollectiveFingerPrint {
 for (const auto i : c10::irange(output_tensors.size())) {
 const std::vector<at::Tensor> gathered_tensors = output_tensors[i];
 const at::Tensor reference_tensor = tensors_to_verify[i];
-for (int rank = 0; rank < gathered_tensors.size(); rank++) {
+for (const auto rank : c10::irange(gathered_tensors.size())) {
 const auto& rank_tensor = gathered_tensors[rank];
 if (!rank_tensor.equal(reference_tensor)) {
 CollectiveFingerPrint rank_fingerprint =

@@ -1035,7 +1035,7 @@ void TCPStore::waitForWorkers() {
 auto buf = reinterpret_cast<const char*>(value.data());
 auto len = value.size();
 int numWorkersCompleted = std::stoi(std::string(buf, len));
-if (numWorkersCompleted >= *numWorkers_) {
+if (numWorkersCompleted >= static_cast<int>(*numWorkers_)) {
 break;
 }
 const auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
@@ -2149,7 +2149,7 @@ void verify_params_across_processes(
 std::vector<std::vector<at::Tensor>> param_size_output_tensors;
 param_size_output_tensors.emplace_back();
 auto world_size = process_group->getSize();
-for (size_t i = 0; i < world_size; ++i) {
+for (C10_UNUSED const auto i : c10::irange(world_size)) {
 param_size_output_tensors.front().emplace_back(
 at::empty_like(param_size_tensor));
 }

@@ -2157,10 +2157,10 @@ void verify_params_across_processes(
 std::vector<at::Tensor> param_size_vec{param_size_tensor};
 process_group->allgather(param_size_output_tensors, param_size_vec)->wait();
 auto result_size_tensors = param_size_output_tensors.front();
-for (size_t i = 0; i < world_size; ++i) {
+for (const auto i : c10::irange(world_size)) {
 auto param_size_for_rank = result_size_tensors[i][0].item<int>();
 TORCH_CHECK(
-param_size_for_rank == params.size(),
+static_cast<size_t>(param_size_for_rank) == params.size(),
 c10::str(
 "DDP expects same model across all ranks, but Rank ",
 process_group->getRank(),

@@ -502,8 +502,8 @@ std::vector<at::IValue> readWrappedPayload(
 payload.resize(indexToRead);
 TORCH_INTERNAL_ASSERT(
-// NOLINTNEXTLINE(clang-diagnostic-sign-compare)
-payload.size() > additionalPayloadSize,
+static_cast<decltype(additionalPayloadSize)>(payload.size()) >
+additionalPayloadSize,
 "Wrong payload sizes: payload.size() is ",
 payload.size(),
 " but additional payload size is ",
@@ -84,7 +84,7 @@ ArgSpecs LlgaKernel::initializeInputSpecs(const TensorArgs& inputs) {
 GRAPH_DEBUG("Initializing graph input logical tensors");
 std::map<size_t, int64_t> tensorIdToOccurence =
 initializeTensorIdToOccurence();
-for (size_t i = 0; i < nGraphInputs_; i++) {
+for (const auto i : c10::irange(nGraphInputs_)) {
 auto spec = ArgSpec(graph_->inputs()[i]).supplementTensorInfo(inputs[i]);
 initializedInputIds_.insert(spec.tid());
 int64_t occurence = tensorIdToOccurence[spec.tid()];

@@ -95,7 +95,8 @@ ArgSpecs LlgaKernel::initializeInputSpecs(const TensorArgs& inputs) {
 initializeConstantInputs();
 TORCH_CHECK(
-inputSpecs.size() + constantValues_.size() == nPartitionInputs_,
+inputSpecs.size() + constantValues_.size() ==
+static_cast<size_t>(nPartitionInputs_),
 "Partition inputs are missing");
 GRAPH_DEBUG(
 "Concatenating constant input logical tensors to graph input "

@@ -111,7 +112,7 @@ ArgSpecs LlgaKernel::initializeInputSpecs(const TensorArgs& inputs) {
 ArgSpecs LlgaKernel::initializeOutputSpecs() const {
 ArgSpecs outputSpecs;
 outputSpecs.reserve(nOutputs_);
-for (size_t i = 0; i < nOutputs_; i++) {
+for (const auto i : c10::irange(nOutputs_)) {
 auto spec = ArgSpec(graph_->outputs()[i]);
 if (useOpaqueLayout(i)) {
 spec = spec.any();

@@ -126,7 +127,7 @@ std::tuple<RunArgs, RunArgs> LlgaKernel::prepareRunArgs(
 TensorArgs& outputs) const {
 RunArgs runInputs, runOutputs;
 auto numInputs = runArgsIdx_.size();
-for (size_t i = 0; i < numInputs; i++) {
+for (const auto i : c10::irange(numInputs)) {
 auto spec = inputSpecs_[i];
 auto input = inputs[runArgsIdx_[i]];
 runInputs.push_back(

@@ -143,7 +144,7 @@ std::tuple<RunArgs, RunArgs> LlgaKernel::prepareRunArgs(
 constantInputs_[i].data_ptr()});
 }
-for (size_t i = 0; i < nOutputs_; i++) {
+for (const auto i : c10::irange(nOutputs_)) {
 auto spec = outputSpecs_[i];
 auto opt = c10::TensorOptions(spec.aten_scalar_type()).device(device_);

@@ -215,7 +216,7 @@ compiled_partition LlgaKernel::compile(const partition& partition) {
 // Since layouts of opaque outputs would be known after compilation,
 // we need to query them out from compilation and update outputSpecs
-for (size_t i = 0; i < nOutputs_; i++) {
+for (const auto i : c10::irange(nOutputs_)) {
 auto tid = outputSpecs_[i].tid();
 outputSpecs_[i] = compilation.query_logical_tensor(tid);
 }
@@ -520,7 +520,7 @@ void IRParser::parseOperator(Block* b) {
 const FunctionSchema* schema = n->maybeSchema();
 // Register outputs.
-int idx = 0;
+unsigned idx = 0;
 for (const VarWithType& v : outs) {
 vmap[v.name] = n->outputs()[idx];
 if (schema && !schema->is_varret()) {

@@ -733,7 +733,7 @@ void FlatbufferLoader::extractJitSourceAndConstants(
 }
 }
 const auto* jit_constants = module_->jit_constants();
-for (auto i = 0; i < jit_constants->size(); ++i) {
+for (const auto i : c10::irange(jit_constants->size())) {
 constants->emplace_back(getIValue(jit_constants->Get(i)));
 }
 parseExtraFilesFromVector(module_->jit_sources(), jit_sources);

@@ -68,7 +68,7 @@ bool Function::initialize_operators(bool should_check_operators) {
 std::unordered_set<std::string> unsupported_op_names;
 code_.operators_.resize(code_.op_names_.size());
 bool all_ops_supported = true;
-for (int i = 0; i < code_.op_names_.size(); i++) {
+for (unsigned i = 0; i < code_.op_names_.size(); i++) {
 const auto& opname = code_.op_names_[i];
 int num_args = code_.operator_input_sizes_[i];
 c10::optional<int> num_specified_args =

@@ -212,7 +212,7 @@ c10::optional<std::function<void(Stack&)>> makeOperatorFunction(
 stack.pop_back();
 }
 TORCH_CHECK(
-num_specified_args.value() >= out_args.size(),
+static_cast<size_t>(num_specified_args.value()) >= out_args.size(),
 "The number of output arguments is: ",
 out_args.size(),
 ", which is more then the number of specified arguments: ",

@@ -196,8 +196,7 @@ bool InterpreterState::run(Stack& stack) {
 auto userObj = pop(stack).toObject();
 // Mobile only: since the number of slots is not known, resize the
 // numAttributes before setSlot.
-// NOLINTNEXTLINE(clang-diagnostic-sign-compare)
-while (userObj->type()->numAttributes() <= inst.X) {
+while (static_cast<int>(userObj->type()->numAttributes()) <= inst.X) {
 std::stringstream ss;
 ss << userObj->type()->numAttributes();
 userObj->type()->addAttribute(ss.str(), c10::NoneType::get());
@@ -63,7 +63,7 @@ bool Module::compareMethodSchemas(
 void Module::unsafeRemoveMethod(const std::string& basename) {
 int64_t i = 0;
-for (; i < cu_->methods().size(); ++i) {
+for (; i < static_cast<int64_t>(cu_->methods().size()); ++i) {
 if ((cu_->methods()[i])->name() == basename) {
 break;
 }

@@ -334,7 +334,7 @@ TORCH_API ModuleInfo get_module_info(const mobile::Module& module) {
 std::vector<std::string> type_name_list;
 for (const auto& func_ptr : module.compilation_unit().methods()) {
 const auto& function = *func_ptr;
-for (int i = 0; i < function.get_code().op_names_.size(); i++) {
+for (const auto i : c10::irange(function.get_code().op_names_.size())) {
 const auto& op = function.get_code().op_names_[i];
 minfo.opname_to_num_args[mobile::operator_str(op)] =
 function.get_code().operator_input_sizes_[i];

@@ -53,7 +53,7 @@ std::vector<mobile::nnc::InputSpec> toInputSpecs(
 auto num_inputs =
 g->inputs().size() - kernel->getSymbolicShapeInputs().size();
-for (int i = 0; i < num_inputs; i++) {
+for (const auto i : c10::irange(num_inputs)) {
 auto v = g->inputs()[i];
 const auto& t = v->type();
 mobile::nnc::InputSpec spec;

@@ -373,7 +373,7 @@ std::vector<c10::optional<at::Tensor>> generateExampleInputs(
 const std::vector<at::MemoryFormat>& inputMemoryFormats) {
 std::vector<c10::optional<at::Tensor>> example_inputs;
 example_inputs.reserve(inputShapes.size());
-for (int i = 0; i < inputShapes.size(); ++i) {
+for (const auto i : c10::irange(inputShapes.size())) {
 const auto dtype = at::dtype(inputTypes[i]);
 const auto memory_format = inputMemoryFormats[i];
 example_inputs.emplace_back(

@@ -45,7 +45,7 @@ bool InputSpec::validate(const at::Tensor& input) const {
 return false;
 }
 auto spec_sizes = sizes_;
-for (int i = 0; i < spec_sizes.size(); i++) {
+for (const auto i : c10::irange(spec_sizes.size())) {
 // InputSpec size 0 means that the dimension is dynamic
 if (spec_sizes[i] != 0 && spec_sizes[i] != input.sizes()[i]) {
 return false;
@@ -89,8 +89,8 @@ void applyUpgrader(mobile::Function* function, uint64_t operator_version) {
 // algorithm, because the number of upgrader per operator will be just a
 // few and tend to keep the code light-weight from binary size concern.
 for (const auto& upgrader : upgrader_list) {
-if (operator_version <= upgrader.max_version &&
-operator_version >= upgrader.min_version) {
+if (static_cast<int>(operator_version) <= upgrader.max_version &&
+static_cast<int>(operator_version) >= upgrader.min_version) {
 // If there exists a valid upgrader, change the instruction OP to
 // CALL, and the index will point to the according upgrader
 // function. All upgrader function are available in

@@ -102,7 +102,7 @@ void applyUpgrader(mobile::Function* function, uint64_t operator_version) {
 // new_inst.X = upgrader.index;
 // code->instructions_[i] = new_inst;
 TORCH_CHECK(
-upgrader.index < code.functions_.size(),
+upgrader.index < static_cast<int>(code.functions_.size()),
 "upgrader index is, ",
 upgrader.index,
 " and it's larger than the upgrader function list length ",

@@ -22,7 +22,7 @@ c10::optional<UpgraderEntry> findUpgrader(
 upgraders_for_schema.begin(),
 upgraders_for_schema.end(),
 [current_version](const UpgraderEntry& entry) {
-return entry.bumped_at_version > current_version;
+return entry.bumped_at_version > static_cast<int>(current_version);
 });
 if (pos != upgraders_for_schema.end()) {

@@ -36,7 +36,7 @@ bool isOpCurrentBasedOnUpgraderEntries(
 size_t current_version) {
 auto latest_update =
 upgraders_for_schema[upgraders_for_schema.size() - 1].bumped_at_version;
-if (latest_update > current_version) {
+if (latest_update > static_cast<int>(current_version)) {
 return false;
 }
 return true;
@@ -504,7 +504,7 @@ size_t determineUsageIdx(Value* value, Node* user) {
 const auto idx =
 std::find(user->inputs().begin(), user->inputs().end(), value) -
 user->inputs().begin();
-TORCH_CHECK(idx != user->inputs().size());
+TORCH_CHECK(idx != static_cast<decltype(idx)>(user->inputs().size()));
 return idx;
 }

@@ -80,7 +80,7 @@ bool isZerodimCPUTensor(std::shared_ptr<TensorType> tensor_type) {
 bool propWithNoDevice(Node* n) {
 // Propagate if we can verify that all input devices match,
 // except CPU zerodim, which any other type can overwrite
-int input_num = 0;
+size_t input_num = 0;
 for (; input_num < n->inputs().size(); input_num++) {
 if (n->inputs()[input_num]->type()->cast<TensorType>()) {

@@ -124,7 +124,7 @@ bool defaultDeviceProp(Node* n) {
 return false;
 }
 auto arguments = schema->arguments();
-for (int i = 0; i < arguments.size(); i++) {
+for (size_t i = 0; i < arguments.size(); i++) {
 Argument& argument = arguments[i];
 if (DeviceObjType::get()->isSubtypeOf(argument.type())) {
 // Optional args are filled in by torchscript with default val

@@ -52,7 +52,7 @@ void insertPrePackedConvOpForNode(Node* n) {
 Symbol::fromQualString("mkldnn_prepacked::conv2d_prepack"), 1);
 // skip input value
-for (auto i = 1; i < n->inputs().size(); i++) {
+for (const auto i : c10::irange(1, n->inputs().size())) {
 Value* v = n->input(i);
 prepack_node->addInput(v);
 }
@@ -1255,8 +1255,8 @@ class ShapePropagator : public PropertyPropBase {
 } else if (upcast_integer && !at::isFloatingType(*type->scalarType())) {
 type = type->withScalarType(at::kLong);
 }
-// NOLINTNEXTLINE(clang-diagnostic-sign-compare)
-if (*type->dim() >= num_reduced_dim && num_reduced_dim > 0) {
+if (static_cast<int64_t>(*type->dim()) >= num_reduced_dim &&
+num_reduced_dim > 0) {
 return {type->withDim(*type->dim() - num_reduced_dim)};
 } else {
 return {std::move(type)};

@@ -156,7 +156,7 @@ std::ostream& operator<<(std::ostream& os, const ShapeArguments& sa) {
 }
 os << "(";
-for (size_t i = 0; i < sa.len(); i++) {
+for (const auto i : c10::irange(sa.len())) {
 os << sa.at(i);
 }
 os << ")";

@@ -422,8 +422,8 @@ struct SymbolicShapeOpAnalyzer {
 TORCH_INTERNAL_ASSERT(
 inputs_.size() >= shape_compute_graph_->inputs().size(),
 "Missing Arg for Shape Graph");
-for (int64_t index = 0; index < shape_compute_graph_->inputs().size();
-index++) {
+for (const auto index :
+c10::irange(shape_compute_graph_->inputs().size())) {
 auto shape_arguments = c10::get_if<ShapeArguments>(&inputs_[index]);
 if (!shape_arguments || !shape_arguments->has_dim()) {
 continue;

@@ -551,7 +551,7 @@ struct SymbolicShapeOpAnalyzer {
 std::vector<c10::SymbolicShape> propagateShapesInGraph() {
 bool made_change = true;
 constexpr size_t MAX_ATTEMPTS = 8;
-for (int attempt_num = 0; made_change && attempt_num < MAX_ATTEMPTS;
+for (unsigned attempt_num = 0; made_change && attempt_num < MAX_ATTEMPTS;
 attempt_num++) {
 // symbolic shape concrete values are only used in final shape extraction
 GRAPH_DUMP("Before substitution: ", shape_compute_graph_);

@@ -793,7 +793,7 @@ c10::SymbolicShape combine_bounds(
 return c10::SymbolicShape();
 }
 std::vector<c10::ShapeSymbol> merged_shapes;
-for (int i = 0; i < lower_bound.rank(); i++) {
+for (const auto i : c10::irange(*lower_bound.rank())) {
 // TODO: Merge equivalent expressions (not needed for current use case)
 if (lower_bound[i] == upper_bound[i]) {
 merged_shapes.push_back(lower_bound[i]);
@@ -522,8 +522,7 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
 INST_NEXT;
 case INST(TYPECHECK): {
 INST_GUARD;
-int num_inputs = inst.N, i = 0;
-// NOLINTNEXTLINE(clang-diagnostic-sign-compare)
+unsigned num_inputs = inst.N, i = 0;
 TORCH_INTERNAL_ASSERT(stack.size() >= num_inputs && num_inputs > 0);
 // Check every input's shape against profiled (expected) shape.
 for (i = 0; i < num_inputs; i++) {

@@ -865,8 +865,8 @@ struct CodeImpl {
 static_cast<int64_t>(schema.arguments().size()) +
 static_cast<int64_t>(schema.returns().size());
 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
-expected_size == actual_size || schema.is_varret() ||
-schema.is_vararg(),
+static_cast<size_t>(expected_size) == actual_size ||
+schema.is_varret() || schema.is_vararg(),
 "Expected to find ",
 expected_size,
 " values on the stack, but found ",

@@ -407,7 +407,7 @@ void StandardMemoryPlanner::allocateManagedTensors() {
 auto* start = allocateBuffer(managed_bytes_);
 reused_tensors_ = 0;
-auto group_idx = 0;
+size_t group_idx = 0;
 for (const size_t storages_idx : c10::irange(storages_.size())) {
 auto tensor_size = storages_nbytes_[storages_idx];
 if (tensor_size == 0) {

@@ -444,7 +444,7 @@ void StandardMemoryPlanner::deallocateManagedTensors() {
 // We don't have any guarantee that the model doesn't change the
 // Storage for managed tensors out from under us during execution,
 // so we have to check the Storages each time we deallocate.
-auto group_idx = 0;
+unsigned group_idx = 0;
 const bool first_time = storages_.empty();
 if (C10_UNLIKELY(first_time)) {
 if (storages_.is_allocated()) {
@@ -127,7 +127,9 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
 return nullptr;
 }
 return [](ProcessedNode* p_node) {
-DCHECK(p_node->num_inputs() - 1 == p_node->outputs().size());
+DCHECK(
+static_cast<size_t>(p_node->num_inputs() - 1) ==
+p_node->outputs().size());
 auto dict = p_node->Input(0).toGenericDict();
 const auto num_inputs = p_node->num_inputs();
 for (size_t i = 1; i < num_inputs; ++i) {

@@ -1252,7 +1254,8 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
 const auto num_elems = elems.size();
 const auto idx = pnode->Input(1).toInt();
 const auto norm_idx = normalizeIndex(idx, num_elems);
-if (norm_idx < 0 || norm_idx >= num_elems) {
+if (norm_idx < 0 ||
+norm_idx >= static_cast<decltype(norm_idx)>(num_elems)) {
 // Use std::runtime_error instead of c10::Error to be consistent with
 // JIT
 throw std::out_of_range("Tuple index out of range");

@@ -217,7 +217,7 @@ void einsum(Stack& stack, size_t num_inputs) {
 void percentFormat(Stack& stack, size_t num_inputs) {
 auto format_str = peek(stack, 0, num_inputs).toStringRef();
 auto args = last(stack, num_inputs - 1)[0];
-auto args_size = 1; // assumed size
+size_t args_size = 1; // assumed size
 if (args.isTuple()) {
 args_size = args.toTupleRef().elements().size();
 }

@@ -239,7 +239,6 @@ void percentFormat(Stack& stack, size_t num_inputs) {
 begin = percent_idx + 2; // skip the `%` and the format specifier
 continue;
 }
-// NOLINTNEXTLINE(clang-diagnostic-sign-compare)
 TORCH_CHECK(used_args < args_size, "Too few arguments for format string");
 char key = format_str.at(format_idx);
 IValue arg;

@@ -252,7 +251,6 @@ void percentFormat(Stack& stack, size_t num_inputs) {
 begin = percent_idx + 2;
 ++used_args;
 }
-// NOLINTNEXTLINE(clang-diagnostic-sign-compare)
 TORCH_CHECK(used_args == args_size, "Too many arguments for format string");
 drop(stack, num_inputs);
 push(stack, ss.str());
@@ -503,7 +503,8 @@ GraphEncoder::GraphEncoder(
 model_proto_.set_producer_name("pytorch");
 TORCH_CHECK(
 onnx_opset_version > 0 &&
-onnx_opset_version < kOpsetVersionToIRVersion.size() &&
+static_cast<size_t>(onnx_opset_version) <
+kOpsetVersionToIRVersion.size() &&
 kOpsetVersionToIRVersion[onnx_opset_version] != kInvalidOpsetVersion,
 "Unsupported onnx_opset_version: ",
 onnx_opset_version);

@@ -269,7 +269,7 @@ IValue convertMobileFunctionToCodeTable(
 std::vector<IValue> operators;
 operators.reserve(code.op_names_.size());
-for (int i = 0; i < code.op_names_.size(); ++i) {
+for (unsigned i = 0; i < code.op_names_.size(); ++i) {
 const auto& opname = code.op_names_[i];
 const int size = code.operator_input_sizes_[i];
 if (compilation_options.enable_default_value_for_unspecified_arg) {

@@ -176,7 +176,7 @@ std::pair<IValue, IValue> getFunctionTuple(
 // operators
 std::vector<IValue> operators;
 operators.reserve(mobile_code.op_names_.size());
-for (int i = 0; i < mobile_code.op_names_.size(); ++i) {
+for (const auto i : c10::irange(mobile_code.op_names_.size())) {
 const auto& opname = mobile_code.op_names_[i];
 const int size = mobile_code.operator_input_sizes_[i];
 if (BytecodeEmitMode::is_default_value_for_unspecified_arg_enabled()) {

@@ -235,7 +235,7 @@ flatbuffers::Offset<mobile::serialization::Function> FlatbufferSerializer::
 std::vector<flatbuffers::Offset<mobile::serialization::Operator>>
 operator_vector;
 operator_vector.reserve(code.op_names_.size());
-for (int i = 0; i < code.op_names_.size(); ++i) {
+for (const auto i : c10::irange(code.op_names_.size())) {
 const auto& opname = code.op_names_[i];
 const int op_size = code.operator_input_sizes_[i];
 operator_vector.push_back(CreateOperator(

@@ -923,7 +923,7 @@ struct PythonPrintImpl {
 if (val.isString()) {
 const auto maxASCII = 0x7fu;
 for (auto c : val.toStringRef()) {
-if (c > maxASCII) {
+if (static_cast<decltype(maxASCII)>(c) > maxASCII) {
 hasNonASCII = true;
 return true;
 }
@@ -1235,7 +1235,7 @@ struct PythonPrintImpl {
 auto num_necessary = specified_args.first;
 auto num_out = specified_args.second;
-for (size_t i = 0; i < num_necessary; ++i) {
+for (const auto i : c10::irange(static_cast<size_t>(num_necessary))) {
 if (i > 0)
 stmt << ", ";
 auto v = useOf(node->inputs().at(i));

@@ -1745,7 +1745,7 @@ void jitModuleToPythonCodeAndConstants(
 class_deps.add(class_type);
 }
-for (int i = 0; i < class_deps.size(); ++i) {
+for (const auto i : c10::irange(class_deps.size())) {
 auto type = class_deps[i];
 auto qualname = uniquer.getUniqueName(type);
 std::string qualifier = qualname.prefix();

@@ -119,9 +119,8 @@ void CudaAnalysis::visit(ForPtr v) {
 throw std::runtime_error("support only 3D gpu_block_index");
 }
 ExprPtr prev = nullptr;
-// NOLINTNEXTLINE(clang-diagnostic-sign-compare)
 // NOLINTNEXTLINE(bugprone-branch-clone)
-if (gpu_block_extents_.size() <= gpu_block_index) {
+if (gpu_block_extents_.size() <= static_cast<size_t>(gpu_block_index)) {
 gpu_block_extents_.resize(gpu_block_index + 1);
 } else {
 prev = gpu_block_extents_[gpu_block_index];

@@ -149,9 +148,8 @@ void CudaAnalysis::visit(ForPtr v) {
 throw std::runtime_error("support only 3D gpu_thread_index");
 }
 ExprPtr prev = nullptr;
-// NOLINTNEXTLINE(clang-diagnostic-sign-compare)
 // NOLINTNEXTLINE(bugprone-branch-clone)
-if (gpu_thread_extents_.size() <= gpu_thread_index) {
+if (gpu_thread_extents_.size() <= static_cast<size_t>(gpu_thread_index)) {
 gpu_thread_extents_.resize(gpu_thread_index + 1);
 } else {
 prev = gpu_thread_extents_[gpu_thread_index];
@ -503,8 +501,7 @@ class PrioritizeLoad : public IRMutator {
|
||||
v->indices().size() == nested_store_->indices().size()) {
|
||||
// also check indices
|
||||
bool same = true;
|
||||
// NOLINTNEXTLINE(clang-diagnostic-sign-compare)
|
||||
for (int i = 0; i < v->indices().size(); ++i) {
|
||||
for (const auto i : c10::irange(v->indices().size())) {
|
||||
if (!exprEquals(v->indices()[i], nested_store_->indices()[i])) {
|
||||
same = false;
|
||||
break;
|
||||
|
@ -688,7 +688,7 @@ class SimpleIREvaluatorImpl : public IRVisitor {
throw malformed_input(
"Number of dimensions did not match number of strides", buf);
}
size_t buf_size = 1;
int64_t buf_size = 1;
if (!dims.empty()) {
ExprHandle buf_size_expr = ExprHandle(immLike(dims[0], 1));
ExprHandle negative_one = ExprHandle(immLike(dims[0], -1));
@ -427,7 +427,7 @@ bool trimGraphOnce(const std::shared_ptr<Graph>& graph) {
std::unordered_set<Value*> outputs(
graph->outputs().begin(), graph->outputs().end());
bool changed = false;
for (int idx = 0; idx < ret->inputs().size(); idx++) {
for (size_t idx = 0; idx < ret->inputs().size(); idx++) {
auto v = ret->inputs()[idx];
if (graph_inputs.count(v)) {
continue;
@ -1089,7 +1089,7 @@ std::vector<ExprHandle> TensorExprKernel::getInputStrides(
generated_strides++;
}
}
for (int i = 0; i < rank; i++) {
for (int i = 0; i < static_cast<int>(rank); i++) {
if (stride_input[i] == torch::jit::StrideInput::S_TRAN_CONT &&
stride_set[i - 1]) {
inputTensorStrides[i] =
@ -1500,7 +1500,7 @@ BlockPtr TensorExprKernel::bindAllInputs() {
//
// TODO: Check if the tensors with symbolic shapes are contiguous.
TORCH_CHECK(
nInputs_ > symbolic_shape_inputs_.size(),
nInputs_ > static_cast<int64_t>(symbolic_shape_inputs_.size()),
"Symbolic dims not provided as inputs to the graph");

// First, process the symbolic input params and create a new variable for
@ -1510,7 +1510,9 @@ BlockPtr TensorExprKernel::bindAllInputs() {
// create for the symbolic input params.
symbolic_shape_args.reserve(symbolic_shape_inputs_.size());

for (size_t i = symbolic_shape_inputs_start_pos; i < nInputs_; ++i) {
for (size_t i = symbolic_shape_inputs_start_pos;
i < static_cast<size_t>(nInputs_);
++i) {
auto input = graph_->inputs()[i];
if (input->type()->kind() != TypeKind::IntType) {
throw std::runtime_error(
@ -2104,7 +2106,8 @@ void TensorExprKernel::runWithAllocatedOutputs(Stack& stack) const {
args.emplace_back(&stride_values[idx]);
}

TORCH_INTERNAL_ASSERT(nOutputs_ == bufOutputs_.size());
TORCH_INTERNAL_ASSERT(
nOutputs_ == static_cast<int64_t>(bufOutputs_.size()));
for (size_t i = 0, e = bufOutputs_.size(); i < e; ++i) {
auto& out = stack_outputs[i].toTensor();
// This has only been tested on CPUs.
@ -68,7 +68,7 @@ std::tuple<std::vector<T>, std::vector<int>> select_n_randomly(

std::vector<T> selected_objects;
std::vector<int> selected_indices;
if (indices.size() < n) {
if (static_cast<int>(indices.size()) < n) {
return std::make_tuple(selected_objects, selected_indices);
}
for (int i = 0; i < n; i++) {
@ -393,8 +393,8 @@ void loopnestRandomization(int64_t seed, LoopNest& l) {

// Find pairs of axes that can be reordered
std::vector<std::pair<ForPtr, ForPtr>> valid_pairs;
for (int i = 0; i < loops.size(); i++) {
for (int j = i + 1; j < loops.size(); j++) {
for (const auto i : c10::irange(loops.size())) {
for (const auto j : c10::irange(i + 1, loops.size())) {
if (LoopNest::findOuterFor(loops[i], loops[j])) {
valid_pairs.emplace_back(loops[i], loops[j]);
}
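The loop-reordering hunk above uses the two-argument form of c10::irange so the inner loop can keep starting at i + 1 without reintroducing a signed counter. A small standalone sketch of that form, assuming the same c10/util/irange.h header (the function and container names are illustrative):

    #include <c10/util/irange.h>
    #include <cstdio>
    #include <vector>

    void list_reorderable_pairs(const std::vector<int>& loops) {
      for (const auto i : c10::irange(loops.size())) {
        // The two-argument overload iterates over the half-open range
        // [i + 1, loops.size()), with both indices typed as size_t.
        for (const auto j : c10::irange(i + 1, loops.size())) {
          std::printf("candidate pair (%zu, %zu)\n", i, j);
        }
      }
    }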
@ -338,7 +338,7 @@ Tensor computeChunk(
size_t step = buf_info->dims[norm_dim] / chunks;

std::vector<ExprHandle> new_indices;
for (int64_t i = 0; i < indices.size(); ++i) {
for (int64_t i = 0; i < static_cast<int64_t>(indices.size()); ++i) {
if (i == norm_dim) {
new_indices.push_back(
indices[i] + ExprHandle(immLike(indices[i], chunkIdx * step)));
@ -574,7 +574,7 @@ Tensor computeCatWoConditionals(
std::vector<VarPtr> for_vars(dims.size());
std::vector<ExprPtr> load_indices(dims.size());
std::vector<ExprPtr> store_indices(dims.size());
for (int64_t i = 0; i < dims.size(); ++i) {
for (int64_t i = 0; i < static_cast<int64_t>(dims.size()); ++i) {
for_vars[i] = alloc<Var>(
"i" + c10::to_string(inp_pos) + "_" + c10::to_string(i),
dims[i].dtype());
@ -126,8 +126,7 @@ Tensor computeMean(
extra_args = c10::fmap<ExprHandle>(*mean_dims);
} else {
// When dims argument is not specified, reduce over all dimensions
// NOLINTNEXTLINE(clang-diagnostic-sign-compare)
for (int64_t idx = 0; idx < InputBuf.ndim(); idx++) {
for (int64_t idx = 0; idx < static_cast<int64_t>(InputBuf.ndim()); ++idx) {
extra_args.emplace_back(idx);
}
}
@ -15,7 +15,8 @@ std::vector<int64_t> DropDimensions(
std::vector<int64_t> new_dims;
size_t drop_index = 0;
for (const auto i : c10::irange(sizes.size())) {
if (drop_index < drop_dims.size() && i == drop_dims[drop_index]) {
if (drop_index < drop_dims.size() &&
static_cast<int64_t>(i) == drop_dims[drop_index]) {
++drop_index;
} else {
new_dims.push_back(sizes[i]);
@ -668,7 +668,7 @@ std::vector<torch::lazy::BackendDataPtr> LazyGraphExecutor::SetTensorData(
const std::vector<BackendDataPtr>& tensor_data_vec) {
std::vector<BackendDataPtr> tensors_data;
tensors_data.reserve(indices.size());
for (int i = 0; i < indices.size(); i++) {
for (const auto i : c10::irange(indices.size())) {
auto index = indices[i];
LazyTensorPtr& tensor = (*tensors)[index];
// If the config.force_ltc_data flag is true, the purpose of this tensor
@ -783,7 +783,8 @@ LazyGraphExecutor::CompilationResult LazyGraphExecutor::Compile(
// TODO(whc) should computation be allowed null here? (because it is in one
// case)
TORCH_CHECK(
computation->parameters_size() == po_data->parameters_data.size());
computation->parameters_size() ==
static_cast<int>(po_data->parameters_data.size()));
}

return {
@ -321,7 +321,7 @@ std::string MetricFnValue(double value) {
std::string MetricFnBytes(double value) {
static const std::array<const char*, 6> kSizeSuffixes{
"B", "KB", "MB", "GB", "TB", "PB"};
int sfix = 0;
unsigned sfix = 0;
for (; (sfix + 1) < kSizeSuffixes.size() && value >= 1024.0; ++sfix) {
value /= 1024.0;
}
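Here no cast is needed at all: the counter only ever grows from zero, so declaring it unsigned makes the comparison with std::array::size() same-signed. A standalone sketch of the same unit-scaling loop (the printing at the end is illustrative, not part of the original function):

    #include <array>
    #include <cstdio>

    void print_bytes(double value) {
      static const std::array<const char*, 6> kSizeSuffixes{
          "B", "KB", "MB", "GB", "TB", "PB"};
      unsigned sfix = 0;  // unsigned index: comparing against size() no longer warns
      for (; (sfix + 1) < kSizeSuffixes.size() && value >= 1024.0; ++sfix) {
        value /= 1024.0;
      }
      std::printf("%.2f%s\n", value, kSizeSuffixes[sfix]);
    }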
@ -82,7 +82,8 @@ std::vector<int64_t> BuildSqueezedDimensions(
std::vector<int64_t> output_dimensions;
for (const auto i : c10::irange(dimensions.size())) {
int64_t dim = dimensions[i];
if (dim != 1 || (i != squeeze_dim && squeeze_dim >= 0)) {
if (dim != 1 ||
(static_cast<int64_t>(i) != squeeze_dim && squeeze_dim >= 0)) {
output_dimensions.push_back(dim);
}
}
@ -76,7 +76,7 @@ c10::SymbolicShape get_symbolic_shape(at::Tensor& tensor) {
sizes.size() == is_symbolic->size(),
"Dims of two values are not consistent");
std::vector<c10::optional<int64_t>> symbolic_dims;
for (int64_t i = 0; i < sizes.size(); i++) {
for (size_t i = 0; i < sizes.size(); i++) {
if (is_symbolic->at(i)) {
symbolic_dims.emplace_back(c10::nullopt);
} else {
@ -120,7 +120,7 @@ void applySymbolicShapesOnLT(
TORCH_INTERNAL_ASSERT(
res_symbolic->size() == result_shapes.size(),
"Result shape size is not consistent");
for (int64_t i = 0; i < res_symbolic->size(); i++) {
for (size_t i = 0; i < res_symbolic->size(); i++) {
auto sym_dims = res_symbolic->at(i).symbolicDims();
if (sym_dims.has_value()) {
result_shapes[i] = result_shapes[i].with_symbolic_dims(*sym_dims);
@ -34,7 +34,7 @@
*
* 3. How to figure out the shape/dtype
* ------------------------------------
* Unfortunatley there isn't a one-stop-shop for learning the output shape
* Unfortunately there isn't a one-stop-shop for learning the output shape
* formulae for all operators. This is partly because some operators are not
* part of our 'public' API, including backward operators which users don't
* directly invoke.
@ -427,8 +427,8 @@ std::vector<Shape> compute_shape_expand(
const at::Tensor& self,
at::IntArrayRef size,
bool implicit) {
TORCH_CHECK_GE(size.size(), self.dim());
int64_t num_new_dimensions = size.size() - self.dim();
TORCH_CHECK_GE(static_cast<int64_t>(size.size()), self.dim());
size_t num_new_dimensions = size.size() - self.dim();
std::vector<int64_t> padded_self(num_new_dimensions, 0);
padded_self.insert(
padded_self.end(), self.sizes().begin(), self.sizes().end());
@ -443,9 +443,9 @@ std::vector<Shape> compute_shape_expand(
const at::Tensor& self,
c10::SymIntArrayRef size,
bool implicit) {
TORCH_CHECK_GE(size.size(), self.dim());
TORCH_CHECK_GE(static_cast<int64_t>(size.size()), self.dim());
std::vector<c10::SymInt> _sizes = ToVector<c10::SymInt>(size);
int64_t num_new_dimensions = _sizes.size() - self.dim();
size_t num_new_dimensions = _sizes.size() - self.dim();
std::vector<int64_t> padded_self(num_new_dimensions, 0);
padded_self.insert(
padded_self.end(), self.sizes().begin(), self.sizes().end());
@ -1138,8 +1138,8 @@ std::vector<Shape> compute_shape_stack(at::TensorList tensors, int64_t dim) {
std::vector<Shape> compute_shape_repeat(
const at::Tensor& self,
at::IntArrayRef repeats) {
TORCH_CHECK_GE(repeats.size(), self.dim());
int64_t num_new_dimensions = repeats.size() - self.dim();
TORCH_CHECK_GE(static_cast<int64_t>(repeats.size()), self.dim());
size_t num_new_dimensions = repeats.size() - self.dim();
std::vector<int64_t> padded_size(num_new_dimensions, 1);
padded_size.insert(
padded_size.end(), self.sizes().begin(), self.sizes().end());
@ -174,7 +174,7 @@ void LazyTensor::TryLimitGraphSize() {
FLAGS_torch_lazy_trim_graph_check_frequency ==
0) {
size_t graph_size = Util::GetGraphSize({data()->ir_value.node.get()});
if (graph_size > FLAGS_torch_lazy_trim_graph_size) {
if (static_cast<int64_t>(graph_size) > FLAGS_torch_lazy_trim_graph_size) {
TORCH_LAZY_COUNTER("TrimIrGraph", 1);
ApplyPendingGraph();
}
@ -221,7 +221,7 @@ void ts_eager_fallback(

// Step 1: Convert all non-eager tensor inputs into eager tensors and put them
// on the stack at the correct indices.
for (int64_t idx = 0; idx < arguments.size(); ++idx) {
for (size_t idx = 0; idx < arguments.size(); ++idx) {
const auto& ivalue = arguments[idx];
if (ivalue.isTensor()) {
tensor_args.push_back(ivalue.toTensor());
@ -246,7 +246,7 @@ void ts_eager_fallback(
// CPU together.
auto eager_tensors = to_eager(tensor_args, device_type);

for (auto i = 0; i < tensor_args_indices.size(); ++i) {
for (const auto i : c10::irange(tensor_args_indices.size())) {
auto idx = tensor_args_indices[i];
(*stack)[arguments_begin + idx] = c10::IValue(eager_tensors[i]);
}
@ -257,7 +257,7 @@ void ts_eager_fallback(
// Step 3: We need to take special care to handle mutable aliases properly:
// If any input tensors are mutable aliases, we need to directly copy the
// updated data on the eager tensors back to the original inputs.
for (int64_t i = 0; i < tensor_args_indices.size(); ++i) {
for (const auto i : c10::irange(tensor_args_indices.size())) {
auto tensor_idx = tensor_args_indices[i];
const auto alias_info = schema_args[tensor_idx].alias_info();
if (alias_info != nullptr && alias_info->isWrite()) {
@ -288,7 +288,7 @@ void ts_eager_fallback(
auto returns = torch::jit::last(stack, num_returns);
const auto returns_begin = stack->size() - num_returns;

for (int64_t idx = 0; idx < returns.size(); ++idx) {
for (const auto idx : c10::irange(returns.size())) {
if (returns[idx].isTensor()) {
const auto& return_tens = returns[idx].toTensor();
if (return_tens.defined()) {
@ -299,7 +299,7 @@ void ts_eager_fallback(
bool found_alias = false;
// We could store some extra metadata on the function schema to avoid
// the loop here if we need to improve perf.
for (int64_t i = 0; i < tensor_args_indices.size(); ++i) {
for (const auto i : c10::irange(tensor_args_indices.size())) {
auto input_tensor_idx = tensor_args_indices[i];
const auto& input_tensor = eager_tensors[i];
const auto input_alias_info =
@ -183,7 +183,7 @@ class ExperimentalConfigWrapper {
configss << "ACTIVITIES_WARMUP_PERIOD_SECS=0\n"
<< "CUPTI_PROFILER_METRICS=";

for (int i = 0; i < num_metrics; i++) {
for (size_t i = 0; i < num_metrics; i++) {
configss << config_.profiler_metrics[i];
if (num_metrics > 1 && i < (num_metrics - 1)) {
configss << ",";
@ -165,7 +165,7 @@ void PerfProfiler::Enable() {
start_values_.emplace(events_.size(), 0);

auto& sv = start_values_.top();
for (int i = 0; i < events_.size(); ++i) {
for (unsigned i = 0; i < events_.size(); ++i) {
sv[i] = events_[i].ReadCounter();
}
StartCounting();
@ -182,7 +182,7 @@ void PerfProfiler::Disable(perf_counters_t& vals) {
/* Always connecting this disable event to the last enable event i.e. using
* whatever is on the top of the start counter value stack. */
perf_counters_t& sv = start_values_.top();
for (int i = 0; i < events_.size(); ++i) {
for (unsigned i = 0; i < events_.size(); ++i) {
vals[i] = CalcDelta(sv[i], events_[i].ReadCounter());
}
start_values_.pop();