update legacy plus one for mpscnn

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/20554

Reviewed By: jerryzh168

Differential Revision: D15362378

fbshipit-source-id: 070cd8314257386036dca89167c738c6602b3f33
This commit is contained in:
Yanghan Wang
2019-05-16 18:02:56 -07:00
committed by Facebook Github Bot
parent 8bdbd59d0c
commit 3c86d597c4

View File

@ -2072,7 +2072,9 @@ class MPSCNNGenerateProposalsCPPOp final : public Operator<CPUContext> {
OperatorBase::GetSingleArgument<int>("post_nms_topN", 300)),
rpn_nms_thresh_(
OperatorBase::GetSingleArgument<float>("nms_thresh", 0.7f)),
rpn_min_size_(OperatorBase::GetSingleArgument<float>("min_size", 16)) {}
rpn_min_size_(OperatorBase::GetSingleArgument<float>("min_size", 16)),
legacy_plus_one_(
this->template GetSingleArgument<bool>("legacy_plus_one", true)) {}
template <class Derived1, class Derived2>
std::vector<int> nms_metal(
@ -2207,14 +2209,21 @@ class MPSCNNGenerateProposalsCPPOp final : public Operator<CPUContext> {
Eigen::Map<ERMatXf>(scores.data(), H * W, A) =
Eigen::Map<const ERMatXf>(scores_tensor.data(), A, H * W).transpose();
// Transform anchors into proposals via bbox transformations
auto proposals = utils::bbox_transform(all_anchors.array(), bbox_deltas);
auto proposals = utils::bbox_transform(
all_anchors.array(),
bbox_deltas,
std::vector<float>{1.0, 1.0, 1.0, 1.0},
utils::BBOX_XFORM_CLIP_DEFAULT,
legacy_plus_one_);
// 2. clip proposals to image (may result in proposals with zero area
// that will be removed in the next step)
proposals = utils::clip_boxes(proposals, im_info[0], im_info[1]);
proposals = utils::clip_boxes(
proposals, im_info[0], im_info[1], 1.0, legacy_plus_one_);
// 3. remove predicted boxes with either height or width < min_size
auto keep = utils::filter_boxes(proposals, min_size, im_info);
auto keep =
utils::filter_boxes(proposals, min_size, im_info, legacy_plus_one_);
DCHECK_LE(keep.size(), scores.size());
@ -2334,6 +2343,8 @@ class MPSCNNGenerateProposalsCPPOp final : public Operator<CPUContext> {
float rpn_nms_thresh_{0.7};
// RPN_MIN_SIZE
float rpn_min_size_{16};
// The infamous "+ 1" for box width and height dating back to the DPM days
bool legacy_plus_one_{true};
// threads per thread group, used in nms
ushort maxThreadsPerThreadgroup{32};