[sr] remove max_indices argument of embedding_bag when unnecessary (#75993)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75993

Strobelight shows copy_ in embedding_bag taking up a lot of time in adfinder_story_post_ad_session_exit_model 334827604_0
{F723683014}

More details in https://fb.quip.com/MKumAjz1YD4 (1f47a80e88)a#temp:C:FPD3 (ecd5567980)e5a0871ae5d481286b511ef7

The last 3 outputs of embedding_bag are unused in the graph: P495814049.
* max_indices output isn't necessary for the main output, so remove it when it's not used in the graph.
* offset2bag is used as an intermediate to calculate the main output, so we don't remove this output even though it's unused in the graph.
* bag_size is used as an intermediate to calculate the main output for MODE_MEAN, so we don't remove this for now.

Test Plan:
`./caffe2/caffe2/fb/predictor/scripts/run_disagg_model_benchmarks.sh 334827604 0 /data/users/ansha/tmp/ads_tail sr_only`

Inputs uploaded to `/mnt/persistent-public/ansha/ads_tail/334827604`

Before:
I0414 10:53:12.261133 1070948 PyTorchPredictorBenchLib.cpp:305] PyTorch run finished. Milliseconds per iter: 0.121318. Iters per second: 8242.78
        0.11156 ms.    99.0457%. aten::embedding_bag (52 nodes, out variant)

After:
I0418 13:05:10.837378 2354604 PyTorchPredictorBenchLib.cpp:305] PyTorch run finished. Milliseconds per iter: 0.0881273. Iters per second: 11347.2
      0.0789221 ms.    98.7096%. static_runtime::embedding_bag (52 nodes, out variant)

* Ads prod canary:
https://www.internalfb.com/intern/ads/canary/443002539593035806/
* 4M test: `servicelab create cogwheel_pyper_inference_fullsync_ads_inline_cvr_post_imp -a D35726594`
https://www.internalfb.com/intern/servicelab/602875732/
* 4M test: `servicelab create cogwheel_pyper_inference_fullsync_ads_10x_ctr_mbl_feed_non_mimo -a D35726594`
https://www.internalfb.com/intern/servicelab/1002874745/

Reviewed By: mikeiovine

Differential Revision: D35726594

fbshipit-source-id: 3b71a0822657bf7a23ce37ca899baef9997b011a
(cherry picked from commit fd5e3098c047a1e7d4348e1c97341eecb892536e)
This commit is contained in:
Ansha Yu
2022-04-22 08:27:11 -07:00
committed by PyTorch MergeBot
parent ecd5567980
commit ee636e2fd1
7 changed files with 258 additions and 77 deletions

View File

@ -947,7 +947,7 @@ static Tensor apply_bag_size_backward(
template <typename scalar_t>
void embedding_bag_cpu_max_out(
Tensor& max_indices,
Tensor* max_indices,
const Tensor& weight,
const Tensor& indices,
const Tensor& offset2bag,
@ -962,8 +962,12 @@ void embedding_bag_cpu_max_out(
auto* indices_data = indices.data_ptr<index_t>();
auto* offset2bag_data = offset2bag.data_ptr<index_t>();
auto* max_indices_data = max_indices.data_ptr<index_t>();
auto max_indices_stride = max_indices.strides()[0];
index_t* max_indices_data = nullptr;
int64_t max_indices_stride = 0;
if (max_indices) {
max_indices_data = max_indices->data_ptr<index_t>();
max_indices_stride = max_indices->strides()[0];
}
auto* weight_data = weight.data_ptr<scalar_t>();
auto* output_data = output.data_ptr<scalar_t>();
@ -990,7 +994,9 @@ void embedding_bag_cpu_max_out(
if (is_first_for_bag || (weight_item > current_item)) {
current_item = weight_item;
max_indices_data[max_indices_stride * bag + dim] = word_idx;
if (max_indices_data) {
max_indices_data[max_indices_stride * bag + dim] = word_idx;
}
}
}
if (is_first_for_bag) {
@ -1005,7 +1011,7 @@ void embedding_bag_cpu_max_out(
}
void _embedding_bag_cpu_impl_out(Tensor& output, Tensor& offset2bag,
Tensor& bag_size, Tensor& max_indices,
Tensor& bag_size, Tensor* max_indices,
const Tensor &weight, const Tensor &indices,
const Tensor &offsets, const int64_t mode,
const c10::optional<Tensor>& per_sample_weights,
@ -1029,7 +1035,9 @@ void _embedding_bag_cpu_impl_out(Tensor& output, Tensor& offset2bag,
// make bag_size output deterministic
at::native::zero_(bag_size);
}
max_indices.copy_(bag_size);
if (max_indices) {
max_indices->copy_(bag_size);
}
} else { // MODE_MAX
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
weight.scalar_type(), "embedding_bag_cpu_max_out", [&]() {
@ -1067,7 +1075,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _embedding_bag_cpu_impl(
Tensor max_indices = make_max_indices(weight, indices, offsets, bag_size, mode, include_last_offset);
_embedding_bag_cpu_impl_out(output, offset2bag,
bag_size, max_indices,
bag_size, &max_indices,
weight, indices, offsets,
mode, per_sample_weights,
include_last_offset, padding_idx);

View File

@ -39,7 +39,7 @@ void make_offset2bag_out(
const int64_t padding_idx = -1);
void _embedding_bag_cpu_impl_out(Tensor& output, Tensor& offset2bag,
Tensor& bag_size, Tensor& max_indices,
Tensor& bag_size, Tensor* max_indices,
const Tensor &weight, const Tensor &indices,
const Tensor &offsets, const int64_t mode = 0,
const c10::optional<Tensor>& per_sample_weights = c10::nullopt,

View File

@ -5,6 +5,8 @@
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/runtime/static/ProcessedNodeInputs.h>
#include <torch/csrc/jit/runtime/static/impl.h>
#include <torch/csrc/jit/runtime/static/passes.h>
#include <torch/csrc/jit/testing/file_check.h>
#include <stdexcept>
#include "deep_wide_pt.h"
@ -395,6 +397,99 @@ TEST(StaticRuntime, EmbeddingBagWithManagedOutput) {
testStaticRuntime(embedding_bag_managed_output, args, args2);
}
// Exercises RemoveUnnecessaryOutputs on aten::embedding_bag graphs: graphs that
// only consume the first output must be rewritten to
// static_runtime::embedding_bag, a graph consuming all four outputs must be
// left alone, and the rewritable graphs must still pass testStaticRuntime.
TEST(StaticRuntime, EmbeddingBagWithExtraneousOutput) {
  // mode = 0, include_last_offset = 0; only %y0 is consumed.
  const std::string embedding_bag_default_ir = R"IR(
graph(%weight, %indices, %offsets):
%scale_grad_by_freq : bool = prim::Constant[value=0]()
%mode : int = prim::Constant[value=0]()
%sparse : bool = prim::Constant[value=0]()
%per_sample_weights : NoneType = prim::Constant()
%include_last_offset : bool = prim::Constant[value=0]()
%y0 : Tensor, %y1 : Tensor, %y2 : Tensor, %y3 : Tensor = aten::embedding_bag(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset)
%none : NoneType = prim::Constant()
%res : Tensor = aten::clone(%y0, %none)
return (%res)
)IR";
  // mode = 1, include_last_offset = 0; only %y0 is consumed.
  const std::string embedding_bag_mean_ir = R"IR(
graph(%weight, %indices, %offsets):
%scale_grad_by_freq : bool = prim::Constant[value=0]()
%mode : int = prim::Constant[value=1]()
%sparse : bool = prim::Constant[value=0]()
%per_sample_weights : NoneType = prim::Constant()
%include_last_offset : bool = prim::Constant[value=0]()
%y0 : Tensor, %y1 : Tensor, %y2 : Tensor, %y3 : Tensor = aten::embedding_bag(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset)
%none : NoneType = prim::Constant()
%res : Tensor = aten::clone(%y0, %none)
return (%res)
)IR";
  // mode = 2, include_last_offset = 1; only %y0 is consumed.
  const std::string embedding_bag_max_last_offset_ir = R"IR(
graph(%weight, %indices, %offsets):
%scale_grad_by_freq : bool = prim::Constant[value=0]()
%mode : int = prim::Constant[value=2]()
%sparse : bool = prim::Constant[value=0]()
%per_sample_weights : NoneType = prim::Constant()
%include_last_offset : bool = prim::Constant[value=1]()
%y0 : Tensor, %y1 : Tensor, %y2 : Tensor, %y3 : Tensor = aten::embedding_bag(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset)
%none : NoneType = prim::Constant()
%res : Tensor = aten::clone(%y0, %none)
return (%res)
)IR";
  // mode = 0, but every output (%y0..%y3) is consumed by a clone.
  const std::string embedding_bag_normal_ir = R"IR(
graph(%weight, %indices, %offsets):
%scale_grad_by_freq : bool = prim::Constant[value=0]()
%mode : int = prim::Constant[value=0]()
%sparse : bool = prim::Constant[value=0]()
%per_sample_weights : NoneType = prim::Constant()
%include_last_offset : bool = prim::Constant[value=0]()
%y0 : Tensor, %y1 : Tensor, %y2 : Tensor, %y3 : Tensor = aten::embedding_bag(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset)
%none : NoneType = prim::Constant()
%res0 : Tensor = aten::clone(%y0, %none)
%res1 : Tensor = aten::clone(%y1, %none)
%res2 : Tensor = aten::clone(%y2, %none)
%res3 : Tensor = aten::clone(%y3, %none)
return (%res0, %res1, %res2, %res3)
)IR";

  // Parses `ir`, runs the pass, and checks whether the replacement op appears
  // in the resulting graph.
  const auto rewrite_and_check = [](const std::string& ir,
                                    bool expect_rewrite) {
    auto parsed = getGraphFromIR(ir);
    RemoveUnnecessaryOutputs(parsed);
    torch::jit::testing::FileCheck checker;
    if (expect_rewrite) {
      checker.check("static_runtime::embedding_bag");
    } else {
      checker.check_not("static_runtime::embedding_bag");
    }
    checker.run(*parsed);
  };

  rewrite_and_check(embedding_bag_default_ir, /*expect_rewrite=*/true);
  rewrite_and_check(embedding_bag_mean_ir, /*expect_rewrite=*/true);
  rewrite_and_check(embedding_bag_max_last_offset_ir, /*expect_rewrite=*/true);
  // All four outputs are used here, so the node must stay untouched.
  rewrite_and_check(embedding_bag_normal_ir, /*expect_rewrite=*/false);

  // Only the graphs that get rewritten are executed end to end.
  const auto runnable_irs = {
      &embedding_bag_default_ir,
      &embedding_bag_mean_ir,
      &embedding_bag_max_last_offset_ir};

  at::Tensor weight = torch::randn({3, 11}, at::ScalarType::Float);
  at::Tensor input = torch::tensor({0, 1, 0, 2});
  at::Tensor offset = torch::tensor({0, 2, 4});
  std::vector<IValue> args{weight, input, offset};
  for (const std::string* ir : runnable_irs) {
    testStaticRuntime(*ir, args);
  }

  // Second argument set with different sizes, passed as args2.
  at::Tensor weight2 = torch::randn({10, 11}, at::ScalarType::Float);
  at::Tensor input2 = torch::tensor({0, 1, 0, 2, 1});
  at::Tensor offset2 = torch::tensor({0, 1, 2, 3, 4, 5});
  std::vector<IValue> args2{weight2, input2, offset2};
  for (const std::string* ir : runnable_irs) {
    testStaticRuntime(*ir, args, args2);
  }
}
TEST(StaticRuntime, LayerNorm) {
const std::string layer_norm_with_weights = R"JIT(
def forward(self, input: Tensor, normalized_shape: List[int], weight: Tensor, bias: Tensor):

View File

@ -164,6 +164,7 @@ void OptimizeGraph(
ReplaceWithMaybeCopy(graph);
}
FuseListUnpack(graph);
RemoveUnnecessaryOutputs(graph);
#endif
}

View File

@ -1634,79 +1634,67 @@ REGISTER_OPERATOR_FUNCTOR(aten::sum, aten_sum, [](Node* n) -> SROperator {
return nullptr;
});
REGISTER_OPERATOR_FUNCTOR(aten::embedding_bag, aten_embedding_bag, [](Node* n) -> SROperator {
// TODO: Support only 9 args once the old signature has been removed.
if (!n->matches(torch::schema(
"aten::embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> (Tensor, Tensor, Tensor, Tensor)")) &&
!n->matches(torch::schema(
"aten::embedding_bag.padding_idx(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, bool include_last_offset, int? padding_idx) -> (Tensor, Tensor, Tensor, Tensor)"))) {
LogAndDumpSchema(n);
return nullptr;
}
return [](ProcessedNode* p_node) {
const auto& weight = p_node->Input(0).toTensor();
const auto& indices = p_node->Input(1).toTensor();
const auto& offsets = p_node->Input(2).toTensor();
auto scale_grad_by_freq = p_node->Input(3).toBool();
auto mode = p_node->Input(4).to<int64_t>();
auto sparse = p_node->Input(5).toBool();
auto per_sample_weights = p_node->Input(6).toOptional<at::Tensor>();
auto include_last_offset = p_node->Input(7).toBool();
c10::optional<int64_t> padding_idx;
if (p_node->num_inputs() == 9) {
if (p_node->Input(8).isNone()) {
padding_idx = c10::nullopt;
} else {
padding_idx = p_node->Input(8).toInt();
}
}
at::native::check_arguments(
weight,
indices,
offsets,
mode,
per_sample_weights,
include_last_offset);
std::ignore = scale_grad_by_freq;
std::ignore = sparse;
if (p_node->Output(0).isNone()) {
p_node->Output(0) = at::empty(
{include_last_offset ? offsets.sizes()[0] - 1 : offsets.sizes()[0],
weight.sizes()[1]},
weight.options());
namespace {
void prepare_embedding_bag(ProcessedNode* p_node, bool use_max_indices) {
const auto& weight = p_node->Input(0).toTensor();
const auto& indices = p_node->Input(1).toTensor();
const auto& offsets = p_node->Input(2).toTensor();
auto scale_grad_by_freq = p_node->Input(3).toBool();
auto mode = p_node->Input(4).to<int64_t>();
auto sparse = p_node->Input(5).toBool();
auto per_sample_weights = p_node->Input(6).toOptional<at::Tensor>();
auto include_last_offset = p_node->Input(7).toBool();
c10::optional<int64_t> padding_idx;
if (p_node->num_inputs() == 9) {
if (p_node->Input(8).isNone()) {
padding_idx = c10::nullopt;
} else {
at::native::resize_(
p_node->Output(0).toTensor(),
{include_last_offset ? offsets.sizes()[0] - 1 : offsets.sizes()[0],
weight.sizes()[1]},
c10::nullopt);
padding_idx = p_node->Input(8).toInt();
}
at::Tensor& output = p_node->Output(0).toTensor();
}
if (p_node->Output(1).isNone()) {
p_node->Output(1) = at::empty({0}, offsets.options());
}
at::Tensor& offset2bag = p_node->Output(1).toTensor();
at::native::make_offset2bag_out(
offset2bag,
output,
weight,
indices,
offsets,
mode,
per_sample_weights,
padding_idx.value_or(-1));
at::native::check_arguments(
weight, indices, offsets, mode, per_sample_weights, include_last_offset);
if (p_node->Output(2).isNone()) {
p_node->Output(2) = at::empty(offsets.sizes(), offsets.options());
}
at::Tensor& bag_size = p_node->Output(2).toTensor();
at::native::make_bag_size_out(
bag_size, offsets, indices, mode, include_last_offset, false);
std::ignore = scale_grad_by_freq;
std::ignore = sparse;
if (p_node->Output(0).isNone()) {
p_node->Output(0) = at::empty(
{include_last_offset ? offsets.sizes()[0] - 1 : offsets.sizes()[0],
weight.sizes()[1]},
weight.options());
} else {
at::native::resize_(
p_node->Output(0).toTensor(),
{include_last_offset ? offsets.sizes()[0] - 1 : offsets.sizes()[0],
weight.sizes()[1]},
c10::nullopt);
}
at::Tensor& output = p_node->Output(0).toTensor();
if (p_node->Output(1).isNone()) {
p_node->Output(1) = at::empty({0}, offsets.options());
}
at::Tensor& offset2bag = p_node->Output(1).toTensor();
at::native::make_offset2bag_out(
offset2bag,
output,
weight,
indices,
offsets,
mode,
per_sample_weights,
padding_idx.value_or(-1));
if (p_node->Output(2).isNone()) {
p_node->Output(2) = at::empty(offsets.sizes(), offsets.options());
}
at::Tensor& bag_size = p_node->Output(2).toTensor();
at::native::make_bag_size_out(
bag_size, offsets, indices, mode, include_last_offset, false);
if (use_max_indices) {
if (p_node->Output(3).isNone()) {
p_node->Output(3) = at::empty(bag_size.sizes(), offsets.options());
}
@ -1724,7 +1712,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::embedding_bag, aten_embedding_bag, [](Node* n) -
output,
offset2bag,
bag_size,
max_indices,
&max_indices,
weight,
indices,
offsets,
@ -1732,9 +1720,52 @@ REGISTER_OPERATOR_FUNCTOR(aten::embedding_bag, aten_embedding_bag, [](Node* n) -
per_sample_weights,
include_last_offset,
padding_idx.value_or(-1));
} else {
at::native::_embedding_bag_cpu_impl_out(
output,
offset2bag,
bag_size,
nullptr,
weight,
indices,
offsets,
mode,
per_sample_weights,
include_last_offset,
padding_idx.value_or(-1));
}
}
} // namespace
// Registers the handler for the four-output aten::embedding_bag. Both the
// default overload and the padding_idx overload are accepted; any other schema
// is logged and rejected by returning nullptr.
REGISTER_OPERATOR_FUNCTOR(aten::embedding_bag, aten_embedding_bag, [](Node* n) -> SROperator {
  const bool schema_supported =
      n->matches(torch::schema(
          "aten::embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> (Tensor, Tensor, Tensor, Tensor)")) ||
      n->matches(torch::schema(
          "aten::embedding_bag.padding_idx(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, bool include_last_offset, int? padding_idx) -> (Tensor, Tensor, Tensor, Tensor)"));
  if (!schema_supported) {
    LogAndDumpSchema(n);
    return nullptr;
  }
  // The aten op keeps its fourth output, so max_indices is requested.
  return [](ProcessedNode* p_node) {
    prepare_embedding_bag(p_node, /*use_max_indices=*/true);
  };
});
// Registers the handler for the three-output static_runtime::embedding_bag.
// Both the default overload and the padding_idx overload are accepted; any
// other schema is logged and rejected by returning nullptr.
REGISTER_OPERATOR_FUNCTOR(
    static_runtime::embedding_bag,
    static_runtime_embedding_bag,
    [](Node* n) -> SROperator {
      const bool schema_supported =
          n->matches(torch::schema(
              "static_runtime::embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> (Tensor, Tensor, Tensor)")) ||
          n->matches(torch::schema(
              "static_runtime::embedding_bag.padding_idx(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, bool include_last_offset, int? padding_idx) -> (Tensor, Tensor, Tensor)"));
      if (!schema_supported) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      // This op has no fourth output, so max_indices is not requested.
      return [](ProcessedNode* p_node) {
        prepare_embedding_bag(p_node, /*use_max_indices=*/false);
      };
    });
REGISTER_OPERATOR_FUNCTOR(aten::repeat, aten_repeat, [](Node* n) -> SROperator {
if (!n->matches(torch::schema(
"aten::repeat(Tensor self, int[] repeats) -> Tensor"))) {

View File

@ -420,6 +420,12 @@ TORCH_LIBRARY_FRAGMENT(static_runtime, m) {
m.def(torch::schema(
"static_runtime::create_owned_ref(...) -> ...",
c10::AliasAnalysisKind::CONSERVATIVE));
m.def(torch::schema(
"static_runtime::embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> (Tensor, Tensor, Tensor)",
c10::AliasAnalysisKind::PURE_FUNCTION));
m.def(torch::schema(
"static_runtime::embedding_bag.padding_idx(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, bool include_last_offset, int? padding_idx) -> (Tensor, Tensor, Tensor)",
c10::AliasAnalysisKind::PURE_FUNCTION));
}
void FuseSignLog1P(std::shared_ptr<torch::jit::Graph>& graph) {
@ -1155,5 +1161,36 @@ void UseSplitAndSqueeze(std::shared_ptr<Graph>& graph) {
}
}
// Entry point for passes that drop op outputs which the graph never uses.
// At present it only forwards to RemoveUnnecessaryEmbeddingBagOutputs; the
// graph is modified in place.
C10_UNUSED void RemoveUnnecessaryOutputs(
std::shared_ptr<torch::jit::Graph>& graph) {
RemoveUnnecessaryEmbeddingBagOutputs(graph);
}
// Rewrites aten::embedding_bag nodes into static_runtime::embedding_bag, which
// has the same inputs but only the first three outputs. Each pattern exposes
// only %y0, %y1 and %y2 as its outputs, so %y3 is the output being dropped.
// Both the 8-input overload and the 9-input padding_idx overload are handled.
// The graph is modified in place.
C10_UNUSED void RemoveUnnecessaryEmbeddingBagOutputs(
    std::shared_ptr<torch::jit::Graph>& graph) {
  // 8-input overload (no padding_idx).
  const std::string pattern = R"IR(
graph(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset):
%y0 : Tensor, %y1 : Tensor, %y2 : Tensor, %y3 : Tensor = aten::embedding_bag(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset)
return (%y2, %y1, %y0))IR";
  const std::string transformed_pattern = R"IR(
graph(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset):
%y0 : Tensor, %y1 : Tensor, %y2 : Tensor = static_runtime::embedding_bag(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset)
return (%y2, %y1, %y0))IR";

  // 9-input overload (with padding_idx).
  const std::string pattern2 = R"IR(
graph(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset, %padding_idx):
%y0 : Tensor, %y1 : Tensor, %y2 : Tensor, %y3 : Tensor = aten::embedding_bag(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset, %padding_idx)
return (%y2, %y1, %y0))IR";
  const std::string transformed_pattern2 = R"IR(
graph(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset, %padding_idx):
%y0 : Tensor, %y1 : Tensor, %y2 : Tensor = static_runtime::embedding_bag(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset, %padding_idx)
return (%y2, %y1, %y0))IR";

  // Register both rewrites before running. The two patterns match nodes with
  // different input counts, so a single run covers both, and the first
  // pattern is no longer matched a second time against an already-rewritten
  // graph.
  SubgraphRewriter fuse;
  fuse.RegisterRewritePattern(pattern, transformed_pattern);
  fuse.RegisterRewritePattern(pattern2, transformed_pattern2);
  fuse.runOnGraph(graph);
}
} // namespace jit
} // namespace torch

View File

@ -65,5 +65,14 @@ TORCH_API void EliminateExtraPermuteOps(std::shared_ptr<Graph>& graph);
TORCH_API void UseSplitAndSqueeze(std::shared_ptr<Graph>& graph);
// [Remove unnecessary outputs]
// Removes op outputs that are not used later in the graph, to avoid the work
// of computing them. Currently used to remove the max_indices output of
// embedding_bag, which isn't necessary to compute the main output.
TORCH_API void RemoveUnnecessaryOutputs(std::shared_ptr<Graph>& graph);

// Replaces aten::embedding_bag with the three-output
// static_runtime::embedding_bag when its fourth output (max_indices) is not
// used in the graph. Invoked by RemoveUnnecessaryOutputs.
TORCH_API void RemoveUnnecessaryEmbeddingBagOutputs(
std::shared_ptr<Graph>& graph);
} // namespace jit
} // namespace torch