mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66742

Modified loops in files under fbsource/fbcode/caffe2/ from the format `for(TYPE var=x0;var<x_max;var++)` to the format `for(const auto var: irange(x_max))`.

This was achieved by running r-barnes's loop upgrader script (D28874212), with some modification to exclude all files under /torch/jit, plus a number of reversions and unused-variable warning suppressions added by hand.

Test Plan: Sandcastle

Reviewed By: malfet

Differential Revision: D31705366

fbshipit-source-id: be58222426c192406a7f93c21582c3f6f2082401
152 lines
4.1 KiB
C++
152 lines
4.1 KiB
C++
#pragma once
|
|
|
|
#include "caffe2/core/operator.h"
|
|
#include "c10/util/irange.h"
|
|
|
|
#include <cmath>
|
|
#include <limits>
|
|
|
|
namespace caffe2 {
|
|
|
|
template <typename Context>
|
|
class QuantileOp final : public Operator<Context> {
|
|
public:
|
|
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
|
QuantileOp(const OperatorDef& operator_def, Workspace* ws)
|
|
: Operator<Context>(operator_def, ws),
|
|
quantile_(this->template GetSingleArgument<float>("quantile", -1.0)),
|
|
abs_(this->template GetSingleArgument<bool>("abs", true)),
|
|
tol_(this->template GetSingleArgument<float>("tol", 1e-3)) {
|
|
CAFFE_ENFORCE_GE(
|
|
quantile_,
|
|
0,
|
|
"input quantile should be ",
|
|
"no less than 0, got ",
|
|
quantile_);
|
|
CAFFE_ENFORCE_GE(
|
|
1.0f,
|
|
quantile_,
|
|
"input quantile should be ",
|
|
"no larger than 1, got ",
|
|
quantile_);
|
|
CAFFE_ENFORCE_GT(
|
|
tol_, 0, "tolerance should be ", "no less than 0, got ", tol_);
|
|
}
|
|
|
|
bool RunOnDevice() override {
|
|
return DispatchHelper<TensorTypes<float, double>>::call(this, Input(0));
|
|
}
|
|
|
|
template <typename T>
|
|
bool DoRunWithType() {
|
|
Output(QUANTILE_VAL)->Resize(1);
|
|
auto* quantile_val = Output(QUANTILE_VAL)->template mutable_data<T>();
|
|
|
|
auto& input_zero = Input(0);
|
|
int64_t numel = input_zero.numel();
|
|
for (const auto i : c10::irange(1, InputSize())) {
|
|
CAFFE_ENFORCE_EQ(
|
|
Input(i).dtype(),
|
|
input_zero.dtype(),
|
|
"All inputs must have the same type, expected: ",
|
|
input_zero.dtype().name(),
|
|
" but got: ",
|
|
Input(i).dtype().name(),
|
|
" for input: ",
|
|
i);
|
|
numel += Input(i).numel();
|
|
}
|
|
CAFFE_ENFORCE_GT(
|
|
numel,
|
|
0,
|
|
"number of total element in input tensor should be ",
|
|
"larger than 0, got ",
|
|
numel);
|
|
|
|
// the expected number of elements lessEq to the target value
|
|
const int64_t target_cnt =
|
|
static_cast<int64_t>(std::ceil(numel * quantile_));
|
|
|
|
T hi = 0.0;
|
|
T lo = 0.0;
|
|
GetRangeFromInputs(&lo, &hi);
|
|
if (target_cnt == 0) {
|
|
// lowest possible value
|
|
quantile_val[0] = lo;
|
|
return true;
|
|
}
|
|
if (target_cnt == numel) {
|
|
// highest possible value
|
|
quantile_val[0] = hi;
|
|
return true;
|
|
}
|
|
int64_t lo_cnt = CountLowerEq(lo);
|
|
if (lo_cnt >= target_cnt) {
|
|
// the target is one of the lowest value
|
|
quantile_val[0] = lo;
|
|
return true;
|
|
}
|
|
while (std::abs(hi - lo) > tol_ * (std::abs(hi) + std::abs(lo))) {
|
|
// keep hi_cnt > target_idx and lo_cnt <= target_idx
|
|
const T mid = lo + (hi - lo) / 2.0;
|
|
const int64_t mid_cnt = CountLowerEq(mid);
|
|
if (mid_cnt > target_cnt) {
|
|
CAFFE_ENFORCE_NE(
|
|
hi, mid, "numeric precision at limit, unable to continue bisect");
|
|
hi = mid;
|
|
} else if (mid_cnt < target_cnt) {
|
|
CAFFE_ENFORCE_NE(
|
|
lo, mid, "numeric precision at limit, unable to continue bisect");
|
|
lo = mid;
|
|
} else {
|
|
// mid_cnt == target_cnt
|
|
quantile_val[0] = mid;
|
|
return true;
|
|
}
|
|
}
|
|
quantile_val[0] = hi;
|
|
return true;
|
|
}
|
|
|
|
protected:
|
|
float quantile_;
|
|
bool abs_;
|
|
float tol_;
|
|
OUTPUT_TAGS(QUANTILE_VAL);
|
|
|
|
template <typename T>
|
|
void GetRangeFromInputs(T* lo, T* hi) {
|
|
*hi = std::numeric_limits<T>::lowest();
|
|
*lo = std::numeric_limits<T>::max();
|
|
for (const auto i : c10::irange(InputSize())) {
|
|
const auto* input = Input(i).template data<T>();
|
|
for (const auto j : c10::irange(Input(i).numel())) {
|
|
const T val = abs_ ? std::abs(input[j]) : input[j];
|
|
if (*hi < val) {
|
|
*hi = val;
|
|
}
|
|
if (*lo > val) {
|
|
*lo = val;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
template <typename T>
|
|
int64_t CountLowerEq(const T& thd) {
|
|
int64_t count = 0;
|
|
for (const auto i : c10::irange(InputSize())) {
|
|
const auto* input = Input(i).template data<T>();
|
|
for (const auto j : c10::irange(Input(i).numel())) {
|
|
const T val = abs_ ? std::abs(input[j]) : input[j];
|
|
if (val <= thd) {
|
|
count++;
|
|
}
|
|
}
|
|
}
|
|
return count;
|
|
}
|
|
};
|
|
|
|
} // namespace caffe2
|