[1/n][caffe2] Add session based margin loss function in caffe2 operator

Summary: Add session based margin loss into caffe2 operator. This is the first diff make these 2 loss available to dper3

Test Plan:
unit test succeeds with gradient check for both new loss function
buck test //caffe2/caffe2/python/operator_test:softmax_l2r_operator_test
buck test //caffe2/caffe2/python/operator_test:margin_loss_l2r_operator_test

E2E test in bento notebook with model training in N1488923
margin loss model: f318207967 f318207399

Notice that the E2E test is run with dper change in D33532976 to change a full model

Reviewed By: devashisht

Differential Revision: D32902460

fbshipit-source-id: 8f21b9109f500583431156908b632e503ed90dbd
(cherry picked from commit 1592111aa4ed6cfdd7ca5f54de70396e9610757c)
This commit is contained in:
Ziheng Huang
2022-01-21 15:08:41 -08:00
committed by PyTorch MergeBot
parent 26d54b4076
commit ae285d837e
3 changed files with 306 additions and 0 deletions

View File

@ -0,0 +1,163 @@
#include "caffe2/operators/margin_loss_l2r_op.h"
#include <cmath>
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/eigen_utils.h"
namespace caffe2 {
namespace {
#define PAIRWISE_DIFF(vec, N) \
((vec.matrix() * Eigen::MatrixXf::Ones(1, N) - \
Eigen::MatrixXf::Ones(N, 1) * vec.matrix().transpose()) \
.array())
#define CWISE_GT(vec1, vec2) ((vec1) > (vec2))
#define CWISE_LT(vec1, vec2) ((vec1) < (vec2))
#define CWISE_SIGN(vec) \
CWISE_GT((vec), 0).cast<float>() - CWISE_LT((vec), 0).cast<float>()
} // namespace
template <>
float SessionMarginLossOp<float, CPUContext>::SessionMarginLoss(
int start_index,
int end_index,
const Tensor& pred,
const Tensor& label,
Tensor** dpred) {
CAFFE_ENFORCE_LE(0.0, start_index);
CAFFE_ENFORCE_GE(pred.numel(), start_index);
const auto* pred_data = pred.template data<float>();
const auto* label_data = label.template data<float>();
int N = end_index - start_index + 1;
ConstEigenVectorArrayMap<float> pred_vec(&pred_data[start_index], N);
ConstEigenVectorArrayMap<float> label_vec(&label_data[start_index], N);
auto* dpred_data = (*dpred)->template mutable_data<float>();
EigenVectorArrayMap<float> dpred_vec(&dpred_data[start_index], N);
dpred_vec = 0;
ReinitializeTensor(&margin_diff_, {N * N}, at::dtype<float>().device(CPU));
auto* margin_diff_data = margin_diff_.template mutable_data<float>();
EigenArrayMap<float> margin_diff_mat(margin_diff_data, N, N);
ReinitializeTensor(
&label_relation_sign_, {N * N}, at::dtype<float>().device(CPU));
auto* label_relation_sign_data =
label_relation_sign_.template mutable_data<float>();
EigenArrayMap<float> label_relation_sign_mat(label_relation_sign_data, N, N);
// in case that all docs in a session have zero ratings, no op
if (label_vec.abs().sum() < 1e-6) {
return 0;
}
if (N <= 0) {
return 0;
}
float weight = 1.0f / N;
// define label relation, return N * N MATRIX, element (i, j) will be sign(label_i - label_i)
label_relation_sign_mat = PAIRWISE_DIFF(label_vec, N).cwiseSign();
margin_diff_mat =
(margin_ - (label_relation_sign_mat * PAIRWISE_DIFF(pred_vec, N))) *
label_relation_sign_mat.abs();
float loss = 0.5f * weight *
(margin_diff_mat * CWISE_GT(margin_diff_mat, 0).cast<float>()).sum();
dpred_vec = -weight *
((CWISE_GT(margin_diff_mat, 0).cast<float>()) * label_relation_sign_mat)
.rowwise()
.sum();
return loss;
}
template <>
bool SessionMarginLossOp<float, CPUContext>::RunOnDevice() {
auto& pred = Input(PRED);
auto& label = Input(LABEL);
auto& sid = Input(SESSION_LENS);
auto* dpred = Output(DPRED);
const auto* session_lengths = sid.template data<int>();
CAFFE_ENFORCE(pred.dim() == 1);
CAFFE_ENFORCE(pred.numel() == label.numel());
dpred->Resize(pred.numel());
auto* loss = Output(LOSS, {sid.numel()}, at::dtype<float>());
auto loss_vec = loss->template mutable_data<float>();
int start_id = 0;
for (int i = 0; i < sid.numel(); i++) {
loss_vec[i] = SessionMarginLoss(
start_id, session_lengths[i] + start_id - 1, pred, label, &dpred);
start_id += session_lengths[i];
}
return true;
}
template <>
bool SessionMarginLossGradientOp<float, CPUContext>::RunOnDevice() {
auto& pred = Input(PRED);
auto& sids = Input(SESSION_LENS);
auto& precomputed_dpred = Input(PRECOMPUTED_DPRED);
auto& dLoss = Input(DLOSS);
CAFFE_ENFORCE(pred.dim() == 1);
CAFFE_ENFORCE(precomputed_dpred.dim() == 1);
CAFFE_ENFORCE(precomputed_dpred.numel() > 0);
CAFFE_ENFORCE(pred.numel() == precomputed_dpred.numel());
const auto* session_lengths = sids.template data<int>();
CAFFE_ENFORCE(dLoss.numel() == sids.numel());
ConstEigenVectorArrayMap<float> precomputed_dpred_vec(
precomputed_dpred.template data<float>(), precomputed_dpred.numel());
auto* dpred = Output(DPRED, {precomputed_dpred.numel()}, at::dtype<float>());
EigenVectorArrayMap<float> dpred_vec(
dpred->template mutable_data<float>(), dpred->numel());
auto multiplier = dLoss.template data<float>();
int count = 0;
for (int j = 0; j < sids.numel(); j++) {
dpred_vec.segment(count, session_lengths[j]) = multiplier[j] *
precomputed_dpred_vec.segment(count, session_lengths[j]);
count += session_lengths[j];
}
return true;
}
namespace {
REGISTER_CPU_OPERATOR(
SessionMarginLoss,
SessionMarginLossOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(
SessionMarginLossGradient,
SessionMarginLossGradientOp<float, CPUContext>);
OPERATOR_SCHEMA(SessionMarginLoss).NumInputs(3).NumOutputs(2).SetDoc(R"DOC(
This method optimizes the pairwise margin loss in a session with margin control.
If multiple sessions are in a batch, pairwise loss will only be computed in a session and the total loss will be the sum of pairwise loss from each session.
The exact loss function in a session is similar to https://pytorch.org/docs/stable/generated/torch.nn.MarginRankingLoss.html#torch.nn.MarginRankingLoss
)DOC");
OPERATOR_SCHEMA(SessionMarginLossGradient).NumInputs(4).NumOutputs(1);
class GetSessionMarginLossGradient : public GradientMakerBase {
using GradientMakerBase::GradientMakerBase;
vector<OperatorDef> GetGradientDefs() override {
return SingleGradientDef(
"SessionMarginLossGradient",
"",
vector<string>{I(0), I(2), O(1), GO(0)},
vector<string>{GI(0)});
}
};
REGISTER_GRADIENT(SessionMarginLoss, GetSessionMarginLossGradient);
} // namespace
} // namespace caffe2

View File

@ -0,0 +1,51 @@
// Copyright 2004-present Facebook. All Rights Reserved.
#pragma once
#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
namespace caffe2 {
template <typename T, class Context>
class SessionMarginLossOp final : public Operator<Context> {
public:
template <class... Args>
explicit SessionMarginLossOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
margin_(this->template GetSingleArgument<float>("margin", 1.0)) {}
USE_OPERATOR_CONTEXT_FUNCTIONS;
bool RunOnDevice() override;
private:
INPUT_TAGS(PRED, LABEL, SESSION_LENS);
OUTPUT_TAGS(LOSS, DPRED);
void ResizeInvLogITensor(int);
void ComputeDiscounts(int*, int);
float SessionMarginLoss(
int start_index,
int end_index,
const Tensor& pred,
const Tensor& label,
Tensor** dpred);
float margin_;
Tensor label_relation_sign_;
Tensor margin_diff_;
};
template <typename T, class Context>
class SessionMarginLossGradientOp final : public Operator<Context> {
public:
USE_SIMPLE_CTOR_DTOR(SessionMarginLossGradientOp)
USE_OPERATOR_CONTEXT_FUNCTIONS;
bool RunOnDevice() override;
private:
INPUT_TAGS(PRED, SESSION_LENS, PRECOMPUTED_DPRED, DLOSS);
OUTPUT_TAGS(DPRED);
};
} // namespace caffe2

View File

@ -0,0 +1,92 @@
import caffe2.python.hypothesis_test_util as hu
import hypothesis.strategies as st
import numpy as np
from caffe2.python import core, workspace
from hypothesis import given
class TestMarginLossL2rOps(hu.HypothesisTestCase):
def ref_margin_loss(self, y, r, margin):
n = len(y)
dy = np.zeros(n)
loss = 0
if np.sum(np.abs(r)) < 1e-6:
return loss, dy
for i in range(n):
for j in range(i + 1, n):
weight = 1.0 / n
diff = 1 if r[i] - r[j] > 0 else -1
if (margin > (y[i] - y[j]) * diff) and (r[i] != r[j]):
loss += weight * (margin - (y[i] - y[j]) * diff)
dy[i] += -diff * weight
dy[j] += diff * weight
return loss, dy
@given(
n=st.integers(10, 10),
k=st.integers(2, 5),
m=st.integers(1, 5),
**hu.gcs_cpu_only
)
def test_session_margin_loss(self, n, k, m, gc, dc):
y = np.random.rand(n * m).astype(np.float32)
r = np.random.randint(k, size=n * m).astype(np.float32)
# m sessions of length n
session_lengths = np.repeat(n, m).astype(np.int32)
ref_loss = np.empty(0)
ref_scale_loss = np.empty(0)
ref_dy = np.empty(0)
ref_scale_dy = np.empty(0)
for i in range(m):
r_loss, r_dy = self.ref_margin_loss(
y[(i) * n : (i + 1) * n], r[(i) * n : (i + 1) * n], 0.06
)
r_scale_loss, r_scale_dy = self.ref_margin_loss(
y[(i) * n : (i + 1) * n], r[(i) * n : (i + 1) * n], 0.04
)
ref_loss = np.append(ref_loss, r_loss)
ref_dy = np.append(ref_dy, r_dy)
ref_scale_loss = np.append(ref_scale_loss, r_scale_loss)
ref_scale_dy = np.append(ref_scale_dy, r_scale_dy)
dloss = np.random.random(m).astype(np.float32)
workspace.blobs["pred"] = y
workspace.blobs["label"] = r
workspace.blobs["session_lengths"] = session_lengths
workspace.blobs["dloss"] = dloss
# Test scale = 1
op = core.CreateOperator(
"SessionMarginLoss",
["pred", "label", "session_lengths"],
["loss", "dpred"],
margin=0.06,
)
workspace.RunOperatorOnce(op)
loss = workspace.blobs["loss"]
dy = workspace.blobs["dpred"]
np.testing.assert_allclose(loss, ref_loss, rtol=1e-5, atol=1e-6)
np.testing.assert_allclose(dy, ref_dy, rtol=1e-5, atol=1e-6)
name = op.output[0]
arr = workspace.FetchBlob(name)
self.assertGradientChecks(
gc, op, [y, r, session_lengths], 0, [0], stepsize=1e-3, threshold=2e-1
)
# Test scale > 1
op = core.CreateOperator(
"SessionMarginLoss",
["pred", "label", "session_lengths"],
["loss", "dpred"],
margin=0.04,
)
workspace.RunOperatorOnce(op)
loss = workspace.blobs["loss"]
dy = workspace.blobs["dpred"]
np.testing.assert_allclose(loss, ref_scale_loss, rtol=1e-5, atol=1e-6)
np.testing.assert_allclose(dy, ref_scale_dy, rtol=1e-5, atol=1e-6)
self.assertGradientChecks(
gc, op, [y, r, session_lengths], 0, [0], stepsize=1e-3, threshold=2e-1
)