mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[1/n][caffe2] Add session based margin loss function in caffe2 operator
Summary: Add session based margin loss into caffe2 operator. This is the first diff make these 2 loss available to dper3 Test Plan: unit test succeeds with gradient check for both new loss function buck test //caffe2/caffe2/python/operator_test:softmax_l2r_operator_test buck test //caffe2/caffe2/python/operator_test:margin_loss_l2r_operator_test E2E test in bento notebook with model training in N1488923 margin loss model: f318207967 f318207399 Notice that the E2E test is run with dper change in D33532976 to change a full model Reviewed By: devashisht Differential Revision: D32902460 fbshipit-source-id: 8f21b9109f500583431156908b632e503ed90dbd (cherry picked from commit 1592111aa4ed6cfdd7ca5f54de70396e9610757c)
This commit is contained in:
committed by
PyTorch MergeBot
parent
26d54b4076
commit
ae285d837e
163
caffe2/operators/margin_loss_l2r_op.cc
Normal file
163
caffe2/operators/margin_loss_l2r_op.cc
Normal file
@ -0,0 +1,163 @@
|
||||
#include "caffe2/operators/margin_loss_l2r_op.h"
|
||||
#include <cmath>
|
||||
#include "caffe2/core/context.h"
|
||||
#include "caffe2/core/operator.h"
|
||||
#include "caffe2/utils/eigen_utils.h"
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
namespace {
|
||||
|
||||
#define PAIRWISE_DIFF(vec, N) \
|
||||
((vec.matrix() * Eigen::MatrixXf::Ones(1, N) - \
|
||||
Eigen::MatrixXf::Ones(N, 1) * vec.matrix().transpose()) \
|
||||
.array())
|
||||
#define CWISE_GT(vec1, vec2) ((vec1) > (vec2))
|
||||
|
||||
#define CWISE_LT(vec1, vec2) ((vec1) < (vec2))
|
||||
|
||||
#define CWISE_SIGN(vec) \
|
||||
CWISE_GT((vec), 0).cast<float>() - CWISE_LT((vec), 0).cast<float>()
|
||||
} // namespace
|
||||
|
||||
template <>
|
||||
float SessionMarginLossOp<float, CPUContext>::SessionMarginLoss(
|
||||
int start_index,
|
||||
int end_index,
|
||||
const Tensor& pred,
|
||||
const Tensor& label,
|
||||
Tensor** dpred) {
|
||||
CAFFE_ENFORCE_LE(0.0, start_index);
|
||||
CAFFE_ENFORCE_GE(pred.numel(), start_index);
|
||||
const auto* pred_data = pred.template data<float>();
|
||||
const auto* label_data = label.template data<float>();
|
||||
int N = end_index - start_index + 1;
|
||||
ConstEigenVectorArrayMap<float> pred_vec(&pred_data[start_index], N);
|
||||
ConstEigenVectorArrayMap<float> label_vec(&label_data[start_index], N);
|
||||
auto* dpred_data = (*dpred)->template mutable_data<float>();
|
||||
EigenVectorArrayMap<float> dpred_vec(&dpred_data[start_index], N);
|
||||
dpred_vec = 0;
|
||||
|
||||
ReinitializeTensor(&margin_diff_, {N * N}, at::dtype<float>().device(CPU));
|
||||
auto* margin_diff_data = margin_diff_.template mutable_data<float>();
|
||||
EigenArrayMap<float> margin_diff_mat(margin_diff_data, N, N);
|
||||
|
||||
ReinitializeTensor(
|
||||
&label_relation_sign_, {N * N}, at::dtype<float>().device(CPU));
|
||||
auto* label_relation_sign_data =
|
||||
label_relation_sign_.template mutable_data<float>();
|
||||
EigenArrayMap<float> label_relation_sign_mat(label_relation_sign_data, N, N);
|
||||
|
||||
// in case that all docs in a session have zero ratings, no op
|
||||
if (label_vec.abs().sum() < 1e-6) {
|
||||
return 0;
|
||||
}
|
||||
if (N <= 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
float weight = 1.0f / N;
|
||||
|
||||
// define label relation, return N * N MATRIX, element (i, j) will be sign(label_i - label_i)
|
||||
label_relation_sign_mat = PAIRWISE_DIFF(label_vec, N).cwiseSign();
|
||||
margin_diff_mat =
|
||||
(margin_ - (label_relation_sign_mat * PAIRWISE_DIFF(pred_vec, N))) *
|
||||
label_relation_sign_mat.abs();
|
||||
float loss = 0.5f * weight *
|
||||
(margin_diff_mat * CWISE_GT(margin_diff_mat, 0).cast<float>()).sum();
|
||||
dpred_vec = -weight *
|
||||
((CWISE_GT(margin_diff_mat, 0).cast<float>()) * label_relation_sign_mat)
|
||||
.rowwise()
|
||||
.sum();
|
||||
|
||||
return loss;
|
||||
}
|
||||
|
||||
template <>
|
||||
bool SessionMarginLossOp<float, CPUContext>::RunOnDevice() {
|
||||
auto& pred = Input(PRED);
|
||||
auto& label = Input(LABEL);
|
||||
auto& sid = Input(SESSION_LENS);
|
||||
|
||||
auto* dpred = Output(DPRED);
|
||||
|
||||
const auto* session_lengths = sid.template data<int>();
|
||||
CAFFE_ENFORCE(pred.dim() == 1);
|
||||
CAFFE_ENFORCE(pred.numel() == label.numel());
|
||||
dpred->Resize(pred.numel());
|
||||
auto* loss = Output(LOSS, {sid.numel()}, at::dtype<float>());
|
||||
auto loss_vec = loss->template mutable_data<float>();
|
||||
int start_id = 0;
|
||||
for (int i = 0; i < sid.numel(); i++) {
|
||||
loss_vec[i] = SessionMarginLoss(
|
||||
start_id, session_lengths[i] + start_id - 1, pred, label, &dpred);
|
||||
start_id += session_lengths[i];
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template <>
|
||||
bool SessionMarginLossGradientOp<float, CPUContext>::RunOnDevice() {
|
||||
auto& pred = Input(PRED);
|
||||
auto& sids = Input(SESSION_LENS);
|
||||
auto& precomputed_dpred = Input(PRECOMPUTED_DPRED);
|
||||
auto& dLoss = Input(DLOSS);
|
||||
|
||||
CAFFE_ENFORCE(pred.dim() == 1);
|
||||
CAFFE_ENFORCE(precomputed_dpred.dim() == 1);
|
||||
CAFFE_ENFORCE(precomputed_dpred.numel() > 0);
|
||||
CAFFE_ENFORCE(pred.numel() == precomputed_dpred.numel());
|
||||
|
||||
const auto* session_lengths = sids.template data<int>();
|
||||
CAFFE_ENFORCE(dLoss.numel() == sids.numel());
|
||||
|
||||
ConstEigenVectorArrayMap<float> precomputed_dpred_vec(
|
||||
precomputed_dpred.template data<float>(), precomputed_dpred.numel());
|
||||
auto* dpred = Output(DPRED, {precomputed_dpred.numel()}, at::dtype<float>());
|
||||
EigenVectorArrayMap<float> dpred_vec(
|
||||
dpred->template mutable_data<float>(), dpred->numel());
|
||||
auto multiplier = dLoss.template data<float>();
|
||||
int count = 0;
|
||||
for (int j = 0; j < sids.numel(); j++) {
|
||||
dpred_vec.segment(count, session_lengths[j]) = multiplier[j] *
|
||||
precomputed_dpred_vec.segment(count, session_lengths[j]);
|
||||
count += session_lengths[j];
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
REGISTER_CPU_OPERATOR(
|
||||
SessionMarginLoss,
|
||||
SessionMarginLossOp<float, CPUContext>);
|
||||
REGISTER_CPU_OPERATOR(
|
||||
SessionMarginLossGradient,
|
||||
SessionMarginLossGradientOp<float, CPUContext>);
|
||||
|
||||
OPERATOR_SCHEMA(SessionMarginLoss).NumInputs(3).NumOutputs(2).SetDoc(R"DOC(
|
||||
|
||||
This method optimizes the pairwise margin loss in a session with margin control.
|
||||
If multiple sessions are in a batch, pairwise loss will only be computed in a session and the total loss will be the sum of pairwise loss from each session.
|
||||
The exact loss function in a session is similar to https://pytorch.org/docs/stable/generated/torch.nn.MarginRankingLoss.html#torch.nn.MarginRankingLoss
|
||||
|
||||
)DOC");
|
||||
OPERATOR_SCHEMA(SessionMarginLossGradient).NumInputs(4).NumOutputs(1);
|
||||
|
||||
class GetSessionMarginLossGradient : public GradientMakerBase {
|
||||
using GradientMakerBase::GradientMakerBase;
|
||||
vector<OperatorDef> GetGradientDefs() override {
|
||||
return SingleGradientDef(
|
||||
"SessionMarginLossGradient",
|
||||
"",
|
||||
vector<string>{I(0), I(2), O(1), GO(0)},
|
||||
vector<string>{GI(0)});
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_GRADIENT(SessionMarginLoss, GetSessionMarginLossGradient);
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace caffe2
|
51
caffe2/operators/margin_loss_l2r_op.h
Normal file
51
caffe2/operators/margin_loss_l2r_op.h
Normal file
@ -0,0 +1,51 @@
|
||||
// Copyright 2004-present Facebook. All Rights Reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "caffe2/core/context.h"
|
||||
#include "caffe2/core/logging.h"
|
||||
#include "caffe2/core/operator.h"
|
||||
#include "caffe2/utils/math.h"
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
template <typename T, class Context>
|
||||
class SessionMarginLossOp final : public Operator<Context> {
|
||||
public:
|
||||
template <class... Args>
|
||||
explicit SessionMarginLossOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
margin_(this->template GetSingleArgument<float>("margin", 1.0)) {}
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
bool RunOnDevice() override;
|
||||
|
||||
private:
|
||||
INPUT_TAGS(PRED, LABEL, SESSION_LENS);
|
||||
OUTPUT_TAGS(LOSS, DPRED);
|
||||
|
||||
void ResizeInvLogITensor(int);
|
||||
void ComputeDiscounts(int*, int);
|
||||
float SessionMarginLoss(
|
||||
int start_index,
|
||||
int end_index,
|
||||
const Tensor& pred,
|
||||
const Tensor& label,
|
||||
Tensor** dpred);
|
||||
float margin_;
|
||||
Tensor label_relation_sign_;
|
||||
Tensor margin_diff_;
|
||||
};
|
||||
|
||||
template <typename T, class Context>
|
||||
class SessionMarginLossGradientOp final : public Operator<Context> {
|
||||
public:
|
||||
USE_SIMPLE_CTOR_DTOR(SessionMarginLossGradientOp)
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
bool RunOnDevice() override;
|
||||
|
||||
private:
|
||||
INPUT_TAGS(PRED, SESSION_LENS, PRECOMPUTED_DPRED, DLOSS);
|
||||
OUTPUT_TAGS(DPRED);
|
||||
};
|
||||
|
||||
} // namespace caffe2
|
92
caffe2/python/operator_test/margin_loss_l2r_operator_test.py
Normal file
92
caffe2/python/operator_test/margin_loss_l2r_operator_test.py
Normal file
@ -0,0 +1,92 @@
|
||||
import caffe2.python.hypothesis_test_util as hu
|
||||
import hypothesis.strategies as st
|
||||
import numpy as np
|
||||
from caffe2.python import core, workspace
|
||||
from hypothesis import given
|
||||
|
||||
|
||||
class TestMarginLossL2rOps(hu.HypothesisTestCase):
|
||||
def ref_margin_loss(self, y, r, margin):
|
||||
n = len(y)
|
||||
dy = np.zeros(n)
|
||||
loss = 0
|
||||
if np.sum(np.abs(r)) < 1e-6:
|
||||
return loss, dy
|
||||
|
||||
for i in range(n):
|
||||
for j in range(i + 1, n):
|
||||
weight = 1.0 / n
|
||||
diff = 1 if r[i] - r[j] > 0 else -1
|
||||
if (margin > (y[i] - y[j]) * diff) and (r[i] != r[j]):
|
||||
loss += weight * (margin - (y[i] - y[j]) * diff)
|
||||
dy[i] += -diff * weight
|
||||
dy[j] += diff * weight
|
||||
return loss, dy
|
||||
|
||||
@given(
|
||||
n=st.integers(10, 10),
|
||||
k=st.integers(2, 5),
|
||||
m=st.integers(1, 5),
|
||||
**hu.gcs_cpu_only
|
||||
)
|
||||
def test_session_margin_loss(self, n, k, m, gc, dc):
|
||||
y = np.random.rand(n * m).astype(np.float32)
|
||||
r = np.random.randint(k, size=n * m).astype(np.float32)
|
||||
# m sessions of length n
|
||||
session_lengths = np.repeat(n, m).astype(np.int32)
|
||||
ref_loss = np.empty(0)
|
||||
ref_scale_loss = np.empty(0)
|
||||
ref_dy = np.empty(0)
|
||||
ref_scale_dy = np.empty(0)
|
||||
for i in range(m):
|
||||
r_loss, r_dy = self.ref_margin_loss(
|
||||
y[(i) * n : (i + 1) * n], r[(i) * n : (i + 1) * n], 0.06
|
||||
)
|
||||
r_scale_loss, r_scale_dy = self.ref_margin_loss(
|
||||
y[(i) * n : (i + 1) * n], r[(i) * n : (i + 1) * n], 0.04
|
||||
)
|
||||
ref_loss = np.append(ref_loss, r_loss)
|
||||
ref_dy = np.append(ref_dy, r_dy)
|
||||
ref_scale_loss = np.append(ref_scale_loss, r_scale_loss)
|
||||
ref_scale_dy = np.append(ref_scale_dy, r_scale_dy)
|
||||
|
||||
dloss = np.random.random(m).astype(np.float32)
|
||||
|
||||
workspace.blobs["pred"] = y
|
||||
workspace.blobs["label"] = r
|
||||
workspace.blobs["session_lengths"] = session_lengths
|
||||
workspace.blobs["dloss"] = dloss
|
||||
|
||||
# Test scale = 1
|
||||
op = core.CreateOperator(
|
||||
"SessionMarginLoss",
|
||||
["pred", "label", "session_lengths"],
|
||||
["loss", "dpred"],
|
||||
margin=0.06,
|
||||
)
|
||||
workspace.RunOperatorOnce(op)
|
||||
loss = workspace.blobs["loss"]
|
||||
dy = workspace.blobs["dpred"]
|
||||
np.testing.assert_allclose(loss, ref_loss, rtol=1e-5, atol=1e-6)
|
||||
np.testing.assert_allclose(dy, ref_dy, rtol=1e-5, atol=1e-6)
|
||||
name = op.output[0]
|
||||
arr = workspace.FetchBlob(name)
|
||||
self.assertGradientChecks(
|
||||
gc, op, [y, r, session_lengths], 0, [0], stepsize=1e-3, threshold=2e-1
|
||||
)
|
||||
|
||||
# Test scale > 1
|
||||
op = core.CreateOperator(
|
||||
"SessionMarginLoss",
|
||||
["pred", "label", "session_lengths"],
|
||||
["loss", "dpred"],
|
||||
margin=0.04,
|
||||
)
|
||||
workspace.RunOperatorOnce(op)
|
||||
loss = workspace.blobs["loss"]
|
||||
dy = workspace.blobs["dpred"]
|
||||
np.testing.assert_allclose(loss, ref_scale_loss, rtol=1e-5, atol=1e-6)
|
||||
np.testing.assert_allclose(dy, ref_scale_dy, rtol=1e-5, atol=1e-6)
|
||||
self.assertGradientChecks(
|
||||
gc, op, [y, r, session_lengths], 0, [0], stepsize=1e-3, threshold=2e-1
|
||||
)
|
Reference in New Issue
Block a user