Implement rotated generate_proposals_op without opencv dependency (CPU version)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18533

Reviewed By: ezyang

Differential Revision: D14648083

fbshipit-source-id: e53e8f537100862f8015c4efa4efe4d387cef551
This commit is contained in:
Jing Huang
2019-03-28 16:58:54 -07:00
committed by Facebook Github Bot
parent 1ae2c1950c
commit 11ac0cf276
4 changed files with 245 additions and 215 deletions

View File

@ -413,7 +413,6 @@ TEST(GenerateProposalsTest, TestRealDownSampled) {
1e-4);
}
#if defined(CV_MAJOR_VERSION) && (CV_MAJOR_VERSION >= 3)
TEST(GenerateProposalsTest, TestRealDownSampledRotatedAngle0) {
// Similar to TestRealDownSampled but for rotated boxes with angle info.
const float angle = 0;
@ -522,7 +521,7 @@ TEST(GenerateProposalsTest, TestRealDownSampledRotatedAngle0) {
ERMatXf rois_gt(rois_gt_xyxy.rows(), 6);
// Batch ID
rois_gt.block(0, 0, rois_gt.rows(), 1) =
rois_gt_xyxy.block(0, 0, rois_gt.rows(), 0);
rois_gt_xyxy.block(0, 0, rois_gt.rows(), 1);
// rois_gt in [x_ctr, y_ctr, w, h] format
rois_gt.block(0, 1, rois_gt.rows(), 4) = utils::bbox_xyxy_to_ctrwh(
rois_gt_xyxy.block(0, 1, rois_gt.rows(), 4).array());
@ -721,6 +720,5 @@ TEST(GenerateProposalsTest, TestRealDownSampledRotated) {
EXPECT_LE(std::abs(rois_data(i, 5) - expected_angle), 1e-4);
}
}
#endif // CV_MAJOR_VERSION >= 3
} // namespace caffe2

View File

@ -169,274 +169,296 @@ std::vector<int> soft_nms_cpu_upright(
return keep;
}
#if defined(CV_MAJOR_VERSION) && (CV_MAJOR_VERSION >= 3)
namespace {
// Result codes returned by rotated_rect_intersection_pts.
// No intersection point was found.
constexpr int INTERSECT_NONE = 0;
// At least one intersection point was found.
constexpr int INTERSECT_PARTIAL = 1;
// The two rectangles have the same vertices.
constexpr int INTERSECT_FULL = 2;
class RotatedRect {
public:
RotatedRect() {}
RotatedRect(
const Eigen::Vector2f& p_center,
const Eigen::Vector2f& p_size,
float p_angle)
: center(p_center), size(p_size), angle(p_angle) {}
void get_vertices(Eigen::Vector2f* pt) const {
// M_PI / 180. == 0.01745329251
double _angle = angle * 0.01745329251;
float b = (float)cos(_angle) * 0.5f;
float a = (float)sin(_angle) * 0.5f;
pt[0].x() = center.x() - a * size.y() - b * size.x();
pt[0].y() = center.y() + b * size.y() - a * size.x();
pt[1].x() = center.x() + a * size.y() - b * size.x();
pt[1].y() = center.y() - b * size.y() - a * size.x();
pt[2] = 2 * center - pt[0];
pt[3] = 2 * center - pt[1];
}
Eigen::Vector2f center;
Eigen::Vector2f size;
float angle;
};
template <class Derived>
cv::RotatedRect bbox_to_rotated_rect(const Eigen::ArrayBase<Derived>& box) {
RotatedRect bbox_to_rotated_rect(const Eigen::ArrayBase<Derived>& box) {
CAFFE_ENFORCE_EQ(box.size(), 5);
// cv::RotatedRect takes angle to mean clockwise rotation, but RRPN bbox
// representation means counter-clockwise rotation.
return cv::RotatedRect(
cv::Point2f(box[0], box[1]), cv::Size2f(box[2], box[3]), -box[4]);
return RotatedRect(
Eigen::Vector2f(box[0], box[1]),
Eigen::Vector2f(box[2], box[3]),
-box[4]);
}
// TODO: cvfix_rotatedRectangleIntersection is a replacement function for
// Scalar 2D cross product (z component of A x B). Eigen does not appear to
// offer one for 2-vectors, so it is spelled out here.
// Positive when B lies counter-clockwise from A, negative when clockwise.
float cross_2d(const Eigen::Vector2f& A, const Eigen::Vector2f& B) {
  return A.x() * B.y() - A.y() * B.x();
}
// rotated_rect_intersection_pts is a replacement function for
// cv::rotatedRectangleIntersection, which has a bug due to float underflow
// When OpenCV version is upgraded to be >= 4.0,
// we can remove this replacement function.
// For anyone interested, here're the PRs on OpenCV:
// https://github.com/opencv/opencv/issues/12221
// https://github.com/opencv/opencv/pull/12222
int cvfix_rotatedRectangleIntersection(
const cv::RotatedRect& rect1,
const cv::RotatedRect& rect2,
cv::OutputArray intersectingRegion) {
// Note that we do not check if the number of intersections is <= 8 in this case
int rotated_rect_intersection_pts(
const RotatedRect& rect1,
const RotatedRect& rect2,
Eigen::Vector2f* intersections,
int& num) {
// Used to test if two points are the same
const float samePointEps = 0.00001f;
const float EPS = 1e-14;
num = 0; // number of intersections
cv::Point2f vec1[4], vec2[4];
cv::Point2f pts1[4], pts2[4];
Eigen::Vector2f vec1[4], vec2[4], pts1[4], pts2[4];
std::vector<cv::Point2f> intersection;
rect1.points(pts1);
rect2.points(pts2);
int ret = cv::INTERSECT_FULL;
rect1.get_vertices(pts1);
rect2.get_vertices(pts2);
// Specical case of rect1 == rect2
{
bool same = true;
bool same = true;
for (int i = 0; i < 4; i++) {
if (fabs(pts1[i].x() - pts2[i].x()) > samePointEps ||
(fabs(pts1[i].y() - pts2[i].y()) > samePointEps)) {
same = false;
break;
}
}
if (same) {
for (int i = 0; i < 4; i++) {
if (fabs(pts1[i].x - pts2[i].x) > samePointEps ||
(fabs(pts1[i].y - pts2[i].y) > samePointEps)) {
same = false;
break;
}
}
if (same) {
intersection.resize(4);
for (int i = 0; i < 4; i++) {
intersection[i] = pts1[i];
}
cv::Mat(intersection).copyTo(intersectingRegion);
return cv::INTERSECT_FULL;
intersections[i] = pts1[i];
}
num = 4;
return INTERSECT_FULL;
}
// Line vector
// A line from p1 to p2 is: p1 + (p2-p1)*t, t=[0,1]
for (int i = 0; i < 4; i++) {
vec1[i].x = pts1[(i + 1) % 4].x - pts1[i].x;
vec1[i].y = pts1[(i + 1) % 4].y - pts1[i].y;
vec2[i].x = pts2[(i + 1) % 4].x - pts2[i].x;
vec2[i].y = pts2[(i + 1) % 4].y - pts2[i].y;
vec1[i] = pts1[(i + 1) % 4] - pts1[i];
vec2[i] = pts2[(i + 1) % 4] - pts2[i];
}
// Line test - test all line combos for intersection
for (int i = 0; i < 4; i++) {
for (int j = 0; j < 4; j++) {
// Solve for 2x2 Ax=b
float x21 = pts2[j].x - pts1[i].x;
float y21 = pts2[j].y - pts1[i].y;
const auto& l1 = vec1[i];
const auto& l2 = vec2[j];
// This takes care of parallel lines
float det = l2.x * l1.y - l1.x * l2.y;
float det = cross_2d(vec2[j], vec1[i]);
if (std::fabs(det) <= EPS) {
continue;
}
float t1 = (l2.x * y21 - l2.y * x21) / det;
float t2 = (l1.x * y21 - l1.y * x21) / det;
auto vec12 = pts2[j] - pts1[i];
float t1 = cross_2d(vec2[j], vec12) / det;
float t2 = cross_2d(vec1[i], vec12) / det;
if (t1 >= 0.0f && t1 <= 1.0f && t2 >= 0.0f && t2 <= 1.0f) {
float xi = pts1[i].x + vec1[i].x * t1;
float yi = pts1[i].y + vec1[i].y * t1;
intersection.push_back(cv::Point2f(xi, yi));
intersections[num++] = pts1[i] + t1 * vec1[i];
}
}
}
if (!intersection.empty()) {
ret = cv::INTERSECT_PARTIAL;
}
// Check for vertices from rect1 inside rect2
for (int i = 0; i < 4; i++) {
// We do a sign test to see which side the point lies.
// If the point all lie on the same sign for all 4 sides of the rect,
// then there's an intersection
int posSign = 0;
int negSign = 0;
{
const auto& AB = vec2[0];
const auto& DA = vec2[3];
auto ABdotAB = AB.squaredNorm();
auto ADdotAD = DA.squaredNorm();
for (int i = 0; i < 4; i++) {
// assume ABCD is the rectangle, and P is the point to be judged
// P is inside ABCD iff. P's projection on AB lies within AB
// and P's projection on AD lies within AD
float x = pts1[i].x;
float y = pts1[i].y;
auto AP = pts1[i] - pts2[0];
for (int j = 0; j < 4; j++) {
// line equation: Ax + By + C = 0
// see which side of the line this point is at
auto APdotAB = AP.dot(AB);
auto APdotAD = -AP.dot(DA);
// float causes underflow!
// Original version:
// float A = -vec2[j].y;
// float B = vec2[j].x;
// float C = -(A * pts2[j].x + B * pts2[j].y);
// float s = A * x + B * y + C;
double A = -vec2[j].y;
double B = vec2[j].x;
double C = -(A * pts2[j].x + B * pts2[j].y);
double s = A * x + B * y + C;
if (s >= 0) {
posSign++;
} else {
negSign++;
if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) &&
(APdotAD <= ADdotAD)) {
intersections[num++] = pts1[i];
}
}
if (posSign == 4 || negSign == 4) {
intersection.push_back(pts1[i]);
}
}
// Reverse the check - check for vertices from rect2 inside rect1
for (int i = 0; i < 4; i++) {
// We do a sign test to see which side the point lies.
// If the point all lie on the same sign for all 4 sides of the rect,
// then there's an intersection
int posSign = 0;
int negSign = 0;
{
const auto& AB = vec1[0];
const auto& DA = vec1[3];
auto ABdotAB = AB.squaredNorm();
auto ADdotAD = DA.squaredNorm();
for (int i = 0; i < 4; i++) {
auto AP = pts2[i] - pts1[0];
float x = pts2[i].x;
float y = pts2[i].y;
auto APdotAB = AP.dot(AB);
auto APdotAD = -AP.dot(DA);
for (int j = 0; j < 4; j++) {
// line equation: Ax + By + C = 0
// see which side of the line this point is at
// float causes underflow!
// Original version:
// float A = -vec1[j].y;
// float B = vec1[j].x;
// float C = -(A * pts1[j].x + B * pts1[j].y);
// float s = A*x + B*y + C;
double A = -vec1[j].y;
double B = vec1[j].x;
double C = -(A * pts1[j].x + B * pts1[j].y);
double s = A * x + B * y + C;
if (s >= 0) {
posSign++;
} else {
negSign++;
}
}
if (posSign == 4 || negSign == 4) {
intersection.push_back(pts2[i]);
}
}
// Get rid of dupes
for (int i = 0; i < (int)intersection.size() - 1; i++) {
for (size_t j = i + 1; j < intersection.size(); j++) {
float dx = intersection[i].x - intersection[j].x;
float dy = intersection[i].y - intersection[j].y;
// can be a really small number, need double here
double d2 = dx * dx + dy * dy;
if (d2 < samePointEps * samePointEps) {
// Found a dupe, remove it
std::swap(intersection[j], intersection.back());
intersection.pop_back();
j--; // restart check
if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) &&
(APdotAD <= ADdotAD)) {
intersections[num++] = pts2[i];
}
}
}
if (intersection.empty()) {
return cv::INTERSECT_NONE;
}
return num ? INTERSECT_PARTIAL : INTERSECT_NONE;
}
// If this check fails then it means we're getting dupes
// CV_Assert(intersection.size() <= 8);
// Compute convex hull using Graham scan algorithm
//
// p:             input points (num_in of them, num_in >= 2 is enforced).
// q:             output buffer; must have room for num_in points and must
//                not alias p (the starting point is read from p after q has
//                been written).
// shift_to_zero: when true, the hull is returned relative to the starting
//                point (which is enough for area/perimeter computations);
//                when false, the original coordinates are restored.
// Returns the number of hull points written to q.
int convex_hull_graham(
    const Eigen::Vector2f* p,
    const int& num_in,
    Eigen::Vector2f* q,
    bool shift_to_zero = false) {
  CAFFE_ENFORCE(num_in >= 2);

  // Step 1:
  // Find point with minimum y
  // if more than 1 points have the same minimum y,
  // pick the one with the minimum x.
  int t = 0;
  for (int i = 1; i < num_in; i++) {
    if (p[i].y() < p[t].y() || (p[i].y() == p[t].y() && p[i].x() < p[t].x())) {
      t = i;
    }
  }
  const Eigen::Vector2f& s = p[t]; // starting point

  // Step 2:
  // Subtract starting point from every points (for sorting in the next step)
  for (int i = 0; i < num_in; i++) {
    q[i] = p[i] - s;
  }

  // Swap the starting point to position 0
  std::swap(q[0], q[t]);

  // Step 3:
  // Sort point 1 ~ num_in according to their relative cross-product values
  // (essentially sorting according to angles)
  std::sort(
      q + 1,
      q + num_in,
      [](const Eigen::Vector2f& A, const Eigen::Vector2f& B) -> bool {
        float temp = cross_2d(A, B);
        if (fabs(temp) < 1e-6) {
          // (Nearly) collinear with the start point: closer point first
          return A.squaredNorm() < B.squaredNorm();
        } else {
          return temp > 0;
        }
      });

  // Step 4:
  // Make sure there are at least 2 points (that don't overlap with each other)
  // in the stack
  int k; // index of the non-overlapped second point
  for (k = 1; k < num_in; k++) {
    if (q[k].squaredNorm() > 1e-8)
      break;
  }
  if (k == num_in) {
    // We reach the end, which means the convex hull is just one point
    // (returned in original, unshifted coordinates)
    q[0] = p[t];
    return 1;
  }
  q[1] = q[k];
  int m = 2; // 2 elements in the stack

  // Step 5:
  // Finally we can start the scanning process.
  // If we find a non-convex relationship between the 3 points,
  // we pop the previous point from the stack until the stack only has two
  // points, or the 3-point relationship is convex again
  for (int i = k + 1; i < num_in; i++) {
    while (m > 1 && cross_2d(q[i] - q[m - 2], q[m - 1] - q[m - 2]) >= 0) {
      m--;
    }
    q[m++] = q[i];
  }

  // Step 6 (Optional):
  // In general sense we need the original coordinates, so we
  // need to shift the points back (reverting Step 2)
  // But if we're only interested in getting the area/perimeter of the shape
  // We can simply return.
  if (!shift_to_zero) {
    for (int i = 0; i < m; i++)
      q[i] += s;
  }

  return m;
}
// Area of the polygon whose m vertices are stored in q, computed as a fan of
// triangles anchored at q[0]. Fewer than 3 vertices yield 0.
double polygon_area(const Eigen::Vector2f* q, const int& m) {
  if (m < 3) {
    return 0;
  }

  // Accumulates twice the area: each |cross product| is the area of the
  // parallelogram spanned by two consecutive fan edges.
  double twice_area = 0;
  for (int next = 2; next < m; ++next) {
    const int cur = next - 1;
    twice_area += fabs(cross_2d(q[cur] - q[0], q[next] - q[0]));
  }
  return twice_area / 2.0;
}
/**
* Returns the intersection area of two rotated rectangles.
*/
double rotated_rect_intersection(
const cv::RotatedRect& rect1,
const cv::RotatedRect& rect2) {
std::vector<cv::Point2f> intersectPts, orderedPts;
const RotatedRect& rect1,
const RotatedRect& rect2) {
// There are up to 16 intersections returned from
// rotated_rect_intersection_pts
Eigen::Vector2f intersectPts[16], orderedPts[16];
int num = 0; // number of intersections
// Find points of intersection
// TODO: cvfix_rotatedRectangleIntersection is a replacement function for
// TODO: rotated_rect_intersection_pts is a replacement function for
// cv::rotatedRectangleIntersection, which has a bug due to float underflow
// When OpenCV version is upgraded to be >= 4.0,
// we can remove this replacement function and use the following instead:
// auto ret = cv::rotatedRectangleIntersection(rect1, rect2, intersectPts);
// For anyone interested, here're the PRs on OpenCV:
// https://github.com/opencv/opencv/issues/12221
// https://github.com/opencv/opencv/pull/12222
auto ret = cvfix_rotatedRectangleIntersection(rect1, rect2, intersectPts);
if (intersectPts.size() <= 2) {
// Note: it doesn't matter if #intersections is greater than 8 here
auto ret = rotated_rect_intersection_pts(rect1, rect2, intersectPts, num);
CAFFE_ENFORCE(num <= 16);
if (num <= 2)
return 0.0;
}
// If one rectangle is fully enclosed within another, return the area
// of the smaller one early.
if (ret == cv::INTERSECT_FULL) {
return std::min(rect1.size.area(), rect2.size.area());
if (ret == INTERSECT_FULL) {
return std::min(
rect1.size.x() * rect1.size.y(), rect2.size.x() * rect2.size.y());
}
// Convex Hull to order the intersection points in clockwise or
// counter-clockwise order and find the countour area.
cv::convexHull(intersectPts, orderedPts);
return cv::contourArea(orderedPts);
int num_convex = convex_hull_graham(intersectPts, num, orderedPts, true);
return polygon_area(orderedPts, num_convex);
}
} // namespace
@ -507,7 +529,7 @@ std::vector<int> nms_cpu_rotated(
auto heights = proposals.col(3);
EArrX areas = widths * heights;
std::vector<cv::RotatedRect> rotated_rects(proposals.rows());
std::vector<RotatedRect> rotated_rects(proposals.rows());
for (int i = 0; i < proposals.rows(); ++i) {
rotated_rects[i] = bbox_to_rotated_rect(proposals.row(i));
}
@ -568,7 +590,7 @@ std::vector<int> soft_nms_cpu_rotated(
auto heights = proposals.col(3);
EArrX areas = widths * heights;
std::vector<cv::RotatedRect> rotated_rects(proposals.rows());
std::vector<RotatedRect> rotated_rects(proposals.rows());
for (int i = 0; i < proposals.rows(); ++i) {
rotated_rects[i] = bbox_to_rotated_rect(proposals.row(i));
}
@ -627,7 +649,6 @@ std::vector<int> soft_nms_cpu_rotated(
return keep;
}
#endif // CV_MAJOR_VERSION >= 3
template <class Derived1, class Derived2>
std::vector<int> nms_cpu(
@ -636,7 +657,6 @@ std::vector<int> nms_cpu(
const std::vector<int>& sorted_indices,
float thresh,
int topN = -1) {
#if defined(CV_MAJOR_VERSION) && (CV_MAJOR_VERSION >= 3)
CAFFE_ENFORCE(proposals.cols() == 4 || proposals.cols() == 5);
if (proposals.cols() == 4) {
// Upright boxes
@ -645,9 +665,6 @@ std::vector<int> nms_cpu(
// Rotated boxes with angle info
return nms_cpu_rotated(proposals, scores, sorted_indices, thresh, topN);
}
#else
return nms_cpu_upright(proposals, scores, sorted_indices, thresh, topN);
#endif // CV_MAJOR_VERSION >= 3
}
// Greedy non-maximum suppression for proposed bounding boxes
@ -686,7 +703,6 @@ std::vector<int> soft_nms_cpu(
float score_thresh = 0.001,
unsigned int method = 1,
int topN = -1) {
#if defined(CV_MAJOR_VERSION) && (CV_MAJOR_VERSION >= 3)
CAFFE_ENFORCE(proposals.cols() == 4 || proposals.cols() == 5);
if (proposals.cols() == 4) {
// Upright boxes
@ -713,18 +729,6 @@ std::vector<int> soft_nms_cpu(
method,
topN);
}
#else
return soft_nms_cpu_upright(
out_scores,
proposals,
scores,
indices,
sigma,
overlap_thresh,
score_thresh,
method,
topN);
#endif // CV_MAJOR_VERSION >= 3
}
template <class Derived1, class Derived2, class Derived3>

View File

@ -379,13 +379,9 @@ TEST(UtilsNMSTest, TestNMSGPURotatedAngle0) {
return;
const int box_dim = 5;
// Same boxes in TestNMS with (x_ctr, y_ctr, w, h, angle) format
std::vector<float> boxes = {
30, 35, 41, 51, 0,
29.5, 36, 38, 49, 0,
24, 29.5, 33, 42, 0,
125, 120, 51, 41, 0,
127, 124.5, 57, 30, 0
};
std::vector<float> boxes = {30, 35, 41, 51, 0, 29.5, 36, 38, 49,
0, 24, 29.5, 33, 42, 0, 125, 120, 51,
41, 0, 127, 124.5, 57, 30, 0};
std::vector<float> scores = {0.5f, 0.7f, 0.6f, 0.9f, 0.8f};
@ -466,7 +462,6 @@ TEST(UtilsNMSTest, TestNMSGPURotatedAngle0) {
cuda_context.FinishDeviceComputation();
}
#if defined(CV_MAJOR_VERSION) && (CV_MAJOR_VERSION >= 3)
TEST(UtilsNMSTest, TestPerfRotatedNMS) {
if (!HasCudaGPU())
return;
@ -678,6 +673,5 @@ TEST(UtilsNMSTest, GPUEqualsCPURotatedCorrectnessTest) {
}
}
}
#endif // CV_MAJOR_VERSION >= 3
} // namespace caffe2

View File

@ -212,7 +212,6 @@ TEST(UtilsNMSTest, TestSoftNMS) {
}
}
#if defined(CV_MAJOR_VERSION) && (CV_MAJOR_VERSION >= 3)
TEST(UtilsNMSTest, TestNMSRotatedAngle0) {
// Same inputs as TestNMS, but in RRPN format with angle 0 for testing
// nms_cpu_rotated
@ -388,6 +387,42 @@ TEST(UtilsNMSTest, TestSoftNMSRotatedAngle0) {
}
TEST(UtilsNMSTest, RotatedBBoxOverlaps) {
{
// One box is fully within another box, the angle is irrelevant
int M = 2, N = 3;
Eigen::ArrayXXf boxes(M, 5);
for (int i = 0; i < M; i++) {
boxes.row(i) << 0, 0, 5, 6, (360.0 / M - 180.0);
}
Eigen::ArrayXXf query_boxes(N, 5);
for (int i = 0; i < N; i++) {
query_boxes.row(i) << 0, 0, 3, 3, (360.0 / M - 180.0);
}
Eigen::ArrayXXf expected(M, N);
// 0.3 == (3 * 3) / (5 * 6)
expected.fill(0.3);
auto actual = utils::bbox_overlaps_rotated(boxes, query_boxes);
EXPECT_TRUE(((expected - actual).abs() < 1e-6).all());
}
{
// Angle 0
Eigen::ArrayXXf boxes(1, 5);
boxes << 39.500000, 50.451096, 80.000000, 18.097809, -0.000000;
Eigen::ArrayXXf query_boxes(1, 5);
query_boxes << 39.120628, 41.014862, 79.241257, 36.427757, -0.000000;
Eigen::ArrayXXf expected(1, 1);
expected << 0.48346716237;
auto actual = utils::bbox_overlaps_rotated(boxes, query_boxes);
EXPECT_TRUE(((expected - actual).abs() < 1e-6).all());
}
{
// Simple case with angle 0 (upright boxes)
Eigen::ArrayXXf boxes(2, 5);
@ -436,6 +471,5 @@ TEST(UtilsNMSTest, RotatedBBoxOverlaps) {
EXPECT_TRUE(((expected - actual).abs() < 1e-6).all());
}
}
#endif // CV_MAJOR_VERSION >= 3
} // namespace caffe2