mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/66742 Modified loops in files under fbsource/fbcode/caffe2/ from the format `for(TYPE var=x0;var<x_max;var++)` to the format `for(const auto var: irange(x_max))`. This was achieved by running r-barnes's loop upgrader script (D28874212), modified to exclude all files under /torch/jit, with a number of reversions and unused-variable warning suppressions added by hand. Test Plan: Sandcastle Reviewed By: malfet Differential Revision: D31705366 fbshipit-source-id: be58222426c192406a7f93c21582c3f6f2082401
59 lines
1.6 KiB
C++
59 lines
1.6 KiB
C++
#ifndef CAFFE2_OPERATORS_FIND_DUPLICATE_ELEMENTS_OP_H
|
|
#define CAFFE2_OPERATORS_FIND_DUPLICATE_ELEMENTS_OP_H
|
|
|
|
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/tensor.h"
#include "c10/util/irange.h"

#include <unordered_map>
#include <unordered_set>
#include <vector>
|
|
|
|
namespace caffe2 {
|
|
|
|
template <class Context>
|
|
class FindDuplicateElementsOp final : public Operator<Context> {
|
|
public:
|
|
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
|
USE_SIMPLE_CTOR_DTOR(FindDuplicateElementsOp);
|
|
USE_DISPATCH_HELPER;
|
|
|
|
bool RunOnDevice() override {
|
|
return DispatchHelper<TensorTypes<float, double, int, long, std::string>>::
|
|
call(this, Input(0));
|
|
}
|
|
|
|
template <typename T>
|
|
bool DoRunWithType() {
|
|
const auto& data = Input(0);
|
|
CAFFE_ENFORCE(data.dim() == 1, "data should be 1-D.");
|
|
|
|
const auto* data_ptr = data.template data<T>();
|
|
std::unordered_map<T, int64_t> dict;
|
|
std::vector<int64_t> dupIndices;
|
|
// i is the index of unique elements, j is the index of all elements
|
|
for (int64_t i = 0, j = 0; j < data.sizes()[0]; ++i, ++j) {
|
|
bool retVal = dict.insert({data_ptr[j], i}).second;
|
|
if (!retVal) {
|
|
--i;
|
|
dupIndices.push_back(j);
|
|
}
|
|
}
|
|
|
|
const auto dupSize = dupIndices.size();
|
|
|
|
auto* output =
|
|
Output(0, {static_cast<int64_t>(dupSize)}, at::dtype<int64_t>());
|
|
auto* out_ptr = output->template mutable_data<int64_t>();
|
|
for (const auto i : c10::irange(dupSize)) {
|
|
out_ptr[i] = dupIndices[i];
|
|
}
|
|
|
|
return true;
|
|
}
|
|
};
|
|
|
|
} // namespace caffe2
|
|
|
|
#endif // CAFFE2_OPERATORS_FIND_DUPLICATE_ELEMENTS_OP_H
|