mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[sr] remove max_indices argument of embedding_bag when unnecessary (#75993)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/75993

Strobelight shows copy_ in embedding_bag taking up a lot of time in adfinder_story_post_ad_session_exit_model 334827604_0 {F723683014}

More details in https://fb.quip.com/MKumAjz1YD4a#temp:C:FPD3e5a0871ae5d481286b511ef7

The last 3 outputs of embedding_bag are unused in the graph: P495814049.
* max_indices output isn't necessary for the main output, so remove it when it's not used in the graph.
* offset2bag is used as an intermediate to calculate the main output, so we don't remove this output even though it's unused in the graph.
* bag_size is used as an intermediate to calculate the main output for MODE_MEAN, so we don't remove this for now.

Test Plan: `./caffe2/caffe2/fb/predictor/scripts/run_disagg_model_benchmarks.sh 334827604 0 /data/users/ansha/tmp/ads_tail sr_only`

Inputs uploaded to `/mnt/persistent-public/ansha/ads_tail/334827604`

Before: I0414 10:53:12.261133 1070948 PyTorchPredictorBenchLib.cpp:305] PyTorch run finished. Milliseconds per iter: 0.121318. Iters per second: 8242.78
0.11156 ms. 99.0457%. aten::embedding_bag (52 nodes, out variant)

After: I0418 13:05:10.837378 2354604 PyTorchPredictorBenchLib.cpp:305] PyTorch run finished. Milliseconds per iter: 0.0881273. Iters per second: 11347.2
0.0789221 ms. 98.7096%. static_runtime::embedding_bag (52 nodes, out variant)

* Ads prod canary: https://www.internalfb.com/intern/ads/canary/443002539593035806/
* 4M test: `servicelab create cogwheel_pyper_inference_fullsync_ads_inline_cvr_post_imp -a D35726594` https://www.internalfb.com/intern/servicelab/602875732/
* 4M test: `servicelab create cogwheel_pyper_inference_fullsync_ads_10x_ctr_mbl_feed_non_mimo -a D35726594` https://www.internalfb.com/intern/servicelab/1002874745/

Reviewed By: mikeiovine

Differential Revision: D35726594

fbshipit-source-id: 3b71a0822657bf7a23ce37ca899baef9997b011a

(cherry picked from commit fd5e3098c047a1e7d4348e1c97341eecb892536e)
This commit is contained in:
committed by
PyTorch MergeBot
parent
ecd5567980
commit
ee636e2fd1
@ -947,7 +947,7 @@ static Tensor apply_bag_size_backward(
|
||||
|
||||
template <typename scalar_t>
|
||||
void embedding_bag_cpu_max_out(
|
||||
Tensor& max_indices,
|
||||
Tensor* max_indices,
|
||||
const Tensor& weight,
|
||||
const Tensor& indices,
|
||||
const Tensor& offset2bag,
|
||||
@ -962,8 +962,12 @@ void embedding_bag_cpu_max_out(
|
||||
auto* indices_data = indices.data_ptr<index_t>();
|
||||
auto* offset2bag_data = offset2bag.data_ptr<index_t>();
|
||||
|
||||
auto* max_indices_data = max_indices.data_ptr<index_t>();
|
||||
auto max_indices_stride = max_indices.strides()[0];
|
||||
index_t* max_indices_data = nullptr;
|
||||
int64_t max_indices_stride = 0;
|
||||
if (max_indices) {
|
||||
max_indices_data = max_indices->data_ptr<index_t>();
|
||||
max_indices_stride = max_indices->strides()[0];
|
||||
}
|
||||
|
||||
auto* weight_data = weight.data_ptr<scalar_t>();
|
||||
auto* output_data = output.data_ptr<scalar_t>();
|
||||
@ -990,7 +994,9 @@ void embedding_bag_cpu_max_out(
|
||||
|
||||
if (is_first_for_bag || (weight_item > current_item)) {
|
||||
current_item = weight_item;
|
||||
max_indices_data[max_indices_stride * bag + dim] = word_idx;
|
||||
if (max_indices_data) {
|
||||
max_indices_data[max_indices_stride * bag + dim] = word_idx;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (is_first_for_bag) {
|
||||
@ -1005,7 +1011,7 @@ void embedding_bag_cpu_max_out(
|
||||
}
|
||||
|
||||
void _embedding_bag_cpu_impl_out(Tensor& output, Tensor& offset2bag,
|
||||
Tensor& bag_size, Tensor& max_indices,
|
||||
Tensor& bag_size, Tensor* max_indices,
|
||||
const Tensor &weight, const Tensor &indices,
|
||||
const Tensor &offsets, const int64_t mode,
|
||||
const c10::optional<Tensor>& per_sample_weights,
|
||||
@ -1029,7 +1035,9 @@ void _embedding_bag_cpu_impl_out(Tensor& output, Tensor& offset2bag,
|
||||
// make bag_size output deterministic
|
||||
at::native::zero_(bag_size);
|
||||
}
|
||||
max_indices.copy_(bag_size);
|
||||
if (max_indices) {
|
||||
max_indices->copy_(bag_size);
|
||||
}
|
||||
} else { // MODE_MAX
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
weight.scalar_type(), "embedding_bag_cpu_max_out", [&]() {
|
||||
@ -1067,7 +1075,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _embedding_bag_cpu_impl(
|
||||
Tensor max_indices = make_max_indices(weight, indices, offsets, bag_size, mode, include_last_offset);
|
||||
|
||||
_embedding_bag_cpu_impl_out(output, offset2bag,
|
||||
bag_size, max_indices,
|
||||
bag_size, &max_indices,
|
||||
weight, indices, offsets,
|
||||
mode, per_sample_weights,
|
||||
include_last_offset, padding_idx);
|
||||
|
@ -39,7 +39,7 @@ void make_offset2bag_out(
|
||||
const int64_t padding_idx = -1);
|
||||
|
||||
void _embedding_bag_cpu_impl_out(Tensor& output, Tensor& offset2bag,
|
||||
Tensor& bag_size, Tensor& max_indices,
|
||||
Tensor& bag_size, Tensor* max_indices,
|
||||
const Tensor &weight, const Tensor &indices,
|
||||
const Tensor &offsets, const int64_t mode = 0,
|
||||
const c10::optional<Tensor>& per_sample_weights = c10::nullopt,
|
||||
|
@ -5,6 +5,8 @@
|
||||
#include <torch/csrc/jit/ir/irparser.h>
|
||||
#include <torch/csrc/jit/runtime/static/ProcessedNodeInputs.h>
|
||||
#include <torch/csrc/jit/runtime/static/impl.h>
|
||||
#include <torch/csrc/jit/runtime/static/passes.h>
|
||||
#include <torch/csrc/jit/testing/file_check.h>
|
||||
#include <stdexcept>
|
||||
|
||||
#include "deep_wide_pt.h"
|
||||
@ -395,6 +397,99 @@ TEST(StaticRuntime, EmbeddingBagWithManagedOutput) {
|
||||
testStaticRuntime(embedding_bag_managed_output, args, args2);
|
||||
}
|
||||
|
||||
// Exercises RemoveUnnecessaryOutputs on aten::embedding_bag graphs: graphs
// that consume only the first output are expected to contain
// static_runtime::embedding_bag afterwards, while a graph that consumes all
// four outputs must not. The single-output graphs are then run through
// testStaticRuntime.
TEST(StaticRuntime, EmbeddingBagWithExtraneousOutput) {
  // Applies the pass to a fresh graph parsed from `ir` and checks whether the
  // static_runtime op is present (or absent) in the result.
  const auto expect_rewrite = [](const std::string& ir, bool should_rewrite) {
    auto parsed = getGraphFromIR(ir);
    RemoveUnnecessaryOutputs(parsed);
    if (should_rewrite) {
      torch::jit::testing::FileCheck()
          .check("static_runtime::embedding_bag")
          ->run(*parsed);
    } else {
      torch::jit::testing::FileCheck()
          .check_not("static_runtime::embedding_bag")
          ->run(*parsed);
    }
  };

  // mode=0, include_last_offset=0; only %y0 is used.
  const std::string default_mode_ir = R"IR(
    graph(%weight, %indices, %offsets):
        %scale_grad_by_freq : bool = prim::Constant[value=0]()
        %mode : int = prim::Constant[value=0]()
        %sparse : bool = prim::Constant[value=0]()
        %per_sample_weights : NoneType = prim::Constant()
        %include_last_offset : bool = prim::Constant[value=0]()
        %y0 : Tensor, %y1 : Tensor, %y2 : Tensor, %y3 : Tensor = aten::embedding_bag(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset)
        %none : NoneType = prim::Constant()
        %res : Tensor = aten::clone(%y0, %none)
        return (%res)
  )IR";
  expect_rewrite(default_mode_ir, /*should_rewrite=*/true);

  // mode=1, include_last_offset=0; only %y0 is used.
  const std::string mean_mode_ir = R"IR(
    graph(%weight, %indices, %offsets):
        %scale_grad_by_freq : bool = prim::Constant[value=0]()
        %mode : int = prim::Constant[value=1]()
        %sparse : bool = prim::Constant[value=0]()
        %per_sample_weights : NoneType = prim::Constant()
        %include_last_offset : bool = prim::Constant[value=0]()
        %y0 : Tensor, %y1 : Tensor, %y2 : Tensor, %y3 : Tensor = aten::embedding_bag(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset)
        %none : NoneType = prim::Constant()
        %res : Tensor = aten::clone(%y0, %none)
        return (%res)
  )IR";
  expect_rewrite(mean_mode_ir, /*should_rewrite=*/true);

  // mode=2, include_last_offset=1; only %y0 is used.
  const std::string max_mode_last_offset_ir = R"IR(
    graph(%weight, %indices, %offsets):
        %scale_grad_by_freq : bool = prim::Constant[value=0]()
        %mode : int = prim::Constant[value=2]()
        %sparse : bool = prim::Constant[value=0]()
        %per_sample_weights : NoneType = prim::Constant()
        %include_last_offset : bool = prim::Constant[value=1]()
        %y0 : Tensor, %y1 : Tensor, %y2 : Tensor, %y3 : Tensor = aten::embedding_bag(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset)
        %none : NoneType = prim::Constant()
        %res : Tensor = aten::clone(%y0, %none)
        return (%res)
  )IR";
  expect_rewrite(max_mode_last_offset_ir, /*should_rewrite=*/true);

  // All four outputs are consumed, so the op must be left untouched.
  const std::string all_outputs_used_ir = R"IR(
    graph(%weight, %indices, %offsets):
        %scale_grad_by_freq : bool = prim::Constant[value=0]()
        %mode : int = prim::Constant[value=0]()
        %sparse : bool = prim::Constant[value=0]()
        %per_sample_weights : NoneType = prim::Constant()
        %include_last_offset : bool = prim::Constant[value=0]()
        %y0 : Tensor, %y1 : Tensor, %y2 : Tensor, %y3 : Tensor = aten::embedding_bag(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset)
        %none : NoneType = prim::Constant()
        %res0 : Tensor = aten::clone(%y0, %none)
        %res1 : Tensor = aten::clone(%y1, %none)
        %res2 : Tensor = aten::clone(%y2, %none)
        %res3 : Tensor = aten::clone(%y3, %none)
        return (%res0, %res1, %res2, %res3)
  )IR";
  expect_rewrite(all_outputs_used_ir, /*should_rewrite=*/false);

  // First input set: 3x11 weight, two bags of two indices each.
  at::Tensor first_weight = torch::randn({3, 11}, at::ScalarType::Float);
  at::Tensor first_indices = torch::tensor({0, 1, 0, 2});
  at::Tensor first_offsets = torch::tensor({0, 2, 4});
  std::vector<IValue> first_args{first_weight, first_indices, first_offsets};
  testStaticRuntime(default_mode_ir, first_args);
  testStaticRuntime(mean_mode_ir, first_args);
  testStaticRuntime(max_mode_last_offset_ir, first_args);

  // Second input set with different shapes, passed alongside the first one.
  at::Tensor second_weight = torch::randn({10, 11}, at::ScalarType::Float);
  at::Tensor second_indices = torch::tensor({0, 1, 0, 2, 1});
  at::Tensor second_offsets = torch::tensor({0, 1, 2, 3, 4, 5});
  std::vector<IValue> second_args{second_weight, second_indices, second_offsets};
  testStaticRuntime(default_mode_ir, first_args, second_args);
  testStaticRuntime(mean_mode_ir, first_args, second_args);
  testStaticRuntime(max_mode_last_offset_ir, first_args, second_args);
}
|
||||
|
||||
TEST(StaticRuntime, LayerNorm) {
|
||||
const std::string layer_norm_with_weights = R"JIT(
|
||||
def forward(self, input: Tensor, normalized_shape: List[int], weight: Tensor, bias: Tensor):
|
||||
|
@ -164,6 +164,7 @@ void OptimizeGraph(
|
||||
ReplaceWithMaybeCopy(graph);
|
||||
}
|
||||
FuseListUnpack(graph);
|
||||
RemoveUnnecessaryOutputs(graph);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
@ -1634,79 +1634,67 @@ REGISTER_OPERATOR_FUNCTOR(aten::sum, aten_sum, [](Node* n) -> SROperator {
|
||||
return nullptr;
|
||||
});
|
||||
|
||||
REGISTER_OPERATOR_FUNCTOR(aten::embedding_bag, aten_embedding_bag, [](Node* n) -> SROperator {
|
||||
// TODO: Support only 9 args once the old signature has been removed.
|
||||
if (!n->matches(torch::schema(
|
||||
"aten::embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> (Tensor, Tensor, Tensor, Tensor)")) &&
|
||||
!n->matches(torch::schema(
|
||||
"aten::embedding_bag.padding_idx(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, bool include_last_offset, int? padding_idx) -> (Tensor, Tensor, Tensor, Tensor)"))) {
|
||||
LogAndDumpSchema(n);
|
||||
return nullptr;
|
||||
}
|
||||
return [](ProcessedNode* p_node) {
|
||||
const auto& weight = p_node->Input(0).toTensor();
|
||||
const auto& indices = p_node->Input(1).toTensor();
|
||||
const auto& offsets = p_node->Input(2).toTensor();
|
||||
auto scale_grad_by_freq = p_node->Input(3).toBool();
|
||||
auto mode = p_node->Input(4).to<int64_t>();
|
||||
auto sparse = p_node->Input(5).toBool();
|
||||
auto per_sample_weights = p_node->Input(6).toOptional<at::Tensor>();
|
||||
auto include_last_offset = p_node->Input(7).toBool();
|
||||
c10::optional<int64_t> padding_idx;
|
||||
if (p_node->num_inputs() == 9) {
|
||||
if (p_node->Input(8).isNone()) {
|
||||
padding_idx = c10::nullopt;
|
||||
} else {
|
||||
padding_idx = p_node->Input(8).toInt();
|
||||
}
|
||||
}
|
||||
|
||||
at::native::check_arguments(
|
||||
weight,
|
||||
indices,
|
||||
offsets,
|
||||
mode,
|
||||
per_sample_weights,
|
||||
include_last_offset);
|
||||
|
||||
std::ignore = scale_grad_by_freq;
|
||||
std::ignore = sparse;
|
||||
|
||||
if (p_node->Output(0).isNone()) {
|
||||
p_node->Output(0) = at::empty(
|
||||
{include_last_offset ? offsets.sizes()[0] - 1 : offsets.sizes()[0],
|
||||
weight.sizes()[1]},
|
||||
weight.options());
|
||||
namespace {
|
||||
void prepare_embedding_bag(ProcessedNode* p_node, bool use_max_indices) {
|
||||
const auto& weight = p_node->Input(0).toTensor();
|
||||
const auto& indices = p_node->Input(1).toTensor();
|
||||
const auto& offsets = p_node->Input(2).toTensor();
|
||||
auto scale_grad_by_freq = p_node->Input(3).toBool();
|
||||
auto mode = p_node->Input(4).to<int64_t>();
|
||||
auto sparse = p_node->Input(5).toBool();
|
||||
auto per_sample_weights = p_node->Input(6).toOptional<at::Tensor>();
|
||||
auto include_last_offset = p_node->Input(7).toBool();
|
||||
c10::optional<int64_t> padding_idx;
|
||||
if (p_node->num_inputs() == 9) {
|
||||
if (p_node->Input(8).isNone()) {
|
||||
padding_idx = c10::nullopt;
|
||||
} else {
|
||||
at::native::resize_(
|
||||
p_node->Output(0).toTensor(),
|
||||
{include_last_offset ? offsets.sizes()[0] - 1 : offsets.sizes()[0],
|
||||
weight.sizes()[1]},
|
||||
c10::nullopt);
|
||||
padding_idx = p_node->Input(8).toInt();
|
||||
}
|
||||
at::Tensor& output = p_node->Output(0).toTensor();
|
||||
}
|
||||
|
||||
if (p_node->Output(1).isNone()) {
|
||||
p_node->Output(1) = at::empty({0}, offsets.options());
|
||||
}
|
||||
at::Tensor& offset2bag = p_node->Output(1).toTensor();
|
||||
at::native::make_offset2bag_out(
|
||||
offset2bag,
|
||||
output,
|
||||
weight,
|
||||
indices,
|
||||
offsets,
|
||||
mode,
|
||||
per_sample_weights,
|
||||
padding_idx.value_or(-1));
|
||||
at::native::check_arguments(
|
||||
weight, indices, offsets, mode, per_sample_weights, include_last_offset);
|
||||
|
||||
if (p_node->Output(2).isNone()) {
|
||||
p_node->Output(2) = at::empty(offsets.sizes(), offsets.options());
|
||||
}
|
||||
at::Tensor& bag_size = p_node->Output(2).toTensor();
|
||||
at::native::make_bag_size_out(
|
||||
bag_size, offsets, indices, mode, include_last_offset, false);
|
||||
std::ignore = scale_grad_by_freq;
|
||||
std::ignore = sparse;
|
||||
|
||||
if (p_node->Output(0).isNone()) {
|
||||
p_node->Output(0) = at::empty(
|
||||
{include_last_offset ? offsets.sizes()[0] - 1 : offsets.sizes()[0],
|
||||
weight.sizes()[1]},
|
||||
weight.options());
|
||||
} else {
|
||||
at::native::resize_(
|
||||
p_node->Output(0).toTensor(),
|
||||
{include_last_offset ? offsets.sizes()[0] - 1 : offsets.sizes()[0],
|
||||
weight.sizes()[1]},
|
||||
c10::nullopt);
|
||||
}
|
||||
at::Tensor& output = p_node->Output(0).toTensor();
|
||||
|
||||
if (p_node->Output(1).isNone()) {
|
||||
p_node->Output(1) = at::empty({0}, offsets.options());
|
||||
}
|
||||
at::Tensor& offset2bag = p_node->Output(1).toTensor();
|
||||
at::native::make_offset2bag_out(
|
||||
offset2bag,
|
||||
output,
|
||||
weight,
|
||||
indices,
|
||||
offsets,
|
||||
mode,
|
||||
per_sample_weights,
|
||||
padding_idx.value_or(-1));
|
||||
|
||||
if (p_node->Output(2).isNone()) {
|
||||
p_node->Output(2) = at::empty(offsets.sizes(), offsets.options());
|
||||
}
|
||||
at::Tensor& bag_size = p_node->Output(2).toTensor();
|
||||
at::native::make_bag_size_out(
|
||||
bag_size, offsets, indices, mode, include_last_offset, false);
|
||||
|
||||
if (use_max_indices) {
|
||||
if (p_node->Output(3).isNone()) {
|
||||
p_node->Output(3) = at::empty(bag_size.sizes(), offsets.options());
|
||||
}
|
||||
@ -1724,7 +1712,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::embedding_bag, aten_embedding_bag, [](Node* n) -
|
||||
output,
|
||||
offset2bag,
|
||||
bag_size,
|
||||
max_indices,
|
||||
&max_indices,
|
||||
weight,
|
||||
indices,
|
||||
offsets,
|
||||
@ -1732,9 +1720,52 @@ REGISTER_OPERATOR_FUNCTOR(aten::embedding_bag, aten_embedding_bag, [](Node* n) -
|
||||
per_sample_weights,
|
||||
include_last_offset,
|
||||
padding_idx.value_or(-1));
|
||||
} else {
|
||||
at::native::_embedding_bag_cpu_impl_out(
|
||||
output,
|
||||
offset2bag,
|
||||
bag_size,
|
||||
nullptr,
|
||||
weight,
|
||||
indices,
|
||||
offsets,
|
||||
mode,
|
||||
per_sample_weights,
|
||||
include_last_offset,
|
||||
padding_idx.value_or(-1));
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// Static runtime functor for aten::embedding_bag. Accepts the 8-argument
// schema and the 9-argument `.padding_idx` schema; any other schema is logged
// and rejected. The returned operator delegates to prepare_embedding_bag with
// max_indices enabled, i.e. all four outputs are populated.
REGISTER_OPERATOR_FUNCTOR(aten::embedding_bag, aten_embedding_bag, [](Node* n) -> SROperator {
  const bool schema_supported =
      n->matches(torch::schema(
          "aten::embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> (Tensor, Tensor, Tensor, Tensor)")) ||
      n->matches(torch::schema(
          "aten::embedding_bag.padding_idx(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, bool include_last_offset, int? padding_idx) -> (Tensor, Tensor, Tensor, Tensor)"));
  if (!schema_supported) {
    LogAndDumpSchema(n);
    return nullptr;
  }
  return [](ProcessedNode* p_node) {
    prepare_embedding_bag(p_node, /*use_max_indices*/ true);
  };
});
|
||||
|
||||
// Static runtime functor for static_runtime::embedding_bag, the three-output
// variant of embedding_bag. Accepts the 8-argument schema and the 9-argument
// `.padding_idx` schema; any other schema is logged and rejected. The returned
// operator delegates to prepare_embedding_bag with max_indices disabled.
REGISTER_OPERATOR_FUNCTOR(
    static_runtime::embedding_bag,
    static_runtime_embedding_bag,
    [](Node* n) -> SROperator {
      const bool schema_supported =
          n->matches(torch::schema(
              "static_runtime::embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> (Tensor, Tensor, Tensor)")) ||
          n->matches(torch::schema(
              "static_runtime::embedding_bag.padding_idx(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, bool include_last_offset, int? padding_idx) -> (Tensor, Tensor, Tensor)"));
      if (!schema_supported) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      return [](ProcessedNode* p_node) {
        prepare_embedding_bag(p_node, /*use_max_indices*/ false);
      };
    });
|
||||
|
||||
REGISTER_OPERATOR_FUNCTOR(aten::repeat, aten_repeat, [](Node* n) -> SROperator {
|
||||
if (!n->matches(torch::schema(
|
||||
"aten::repeat(Tensor self, int[] repeats) -> Tensor"))) {
|
||||
|
@ -420,6 +420,12 @@ TORCH_LIBRARY_FRAGMENT(static_runtime, m) {
|
||||
m.def(torch::schema(
|
||||
"static_runtime::create_owned_ref(...) -> ...",
|
||||
c10::AliasAnalysisKind::CONSERVATIVE));
|
||||
m.def(torch::schema(
|
||||
"static_runtime::embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> (Tensor, Tensor, Tensor)",
|
||||
c10::AliasAnalysisKind::PURE_FUNCTION));
|
||||
m.def(torch::schema(
|
||||
"static_runtime::embedding_bag.padding_idx(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, bool include_last_offset, int? padding_idx) -> (Tensor, Tensor, Tensor)",
|
||||
c10::AliasAnalysisKind::PURE_FUNCTION));
|
||||
}
|
||||
|
||||
void FuseSignLog1P(std::shared_ptr<torch::jit::Graph>& graph) {
|
||||
@ -1155,5 +1161,36 @@ void UseSplitAndSqueeze(std::shared_ptr<Graph>& graph) {
|
||||
}
|
||||
}
|
||||
|
||||
// Graph pass entry point for dropping op outputs that nothing consumes.
// At present it only forwards to the embedding_bag-specific rewrite.
C10_UNUSED void RemoveUnnecessaryOutputs(
    std::shared_ptr<torch::jit::Graph>& graph) {
  RemoveUnnecessaryEmbeddingBagOutputs(graph);
}
|
||||
|
||||
// Replaces aten::embedding_bag with the three-output
// static_runtime::embedding_bag op, which has no max_indices output.
//
// Each source pattern exposes only %y0..%y2 as subgraph outputs, so %y3 must
// not escape the matched subgraph; the accompanying unit test expects no
// rewrite when all four outputs are consumed. Both the 8-argument overload and
// the 9-argument `.padding_idx` overload are handled.
//
// @param graph  Graph to rewrite in place.
C10_UNUSED void RemoveUnnecessaryEmbeddingBagOutputs(
    std::shared_ptr<torch::jit::Graph>& graph) {
  // 8-argument overload.
  const std::string pattern = R"IR(
    graph(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset):
        %y0 : Tensor, %y1 : Tensor, %y2 : Tensor, %y3 : Tensor = aten::embedding_bag(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset)
        return (%y2, %y1, %y0))IR";
  const std::string transformed_pattern = R"IR(
    graph(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset):
        %y0 : Tensor, %y1 : Tensor, %y2 : Tensor = static_runtime::embedding_bag(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset)
        return (%y2, %y1, %y0))IR";

  // 9-argument overload with an explicit padding_idx.
  const std::string pattern2 = R"IR(
    graph(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset, %padding_idx):
        %y0 : Tensor, %y1 : Tensor, %y2 : Tensor, %y3 : Tensor = aten::embedding_bag(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset, %padding_idx)
        return (%y2, %y1, %y0))IR";
  const std::string transformed_pattern2 = R"IR(
    graph(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset, %padding_idx):
        %y0 : Tensor, %y1 : Tensor, %y2 : Tensor = static_runtime::embedding_bag(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset, %padding_idx)
        return (%y2, %y1, %y0))IR";

  // Register both rewrites up front and run the rewriter once. Patterns are
  // applied in registration order, so this matches the previous
  // register/run/register/run sequence without re-scanning the graph for the
  // first pattern a second time.
  SubgraphRewriter fuse;
  fuse.RegisterRewritePattern(pattern, transformed_pattern);
  fuse.RegisterRewritePattern(pattern2, transformed_pattern2);
  fuse.runOnGraph(graph);
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
@ -65,5 +65,14 @@ TORCH_API void EliminateExtraPermuteOps(std::shared_ptr<Graph>& graph);
|
||||
|
||||
TORCH_API void UseSplitAndSqueeze(std::shared_ptr<Graph>& graph);
|
||||
|
||||
// [Remove unnecessary outputs]
|
||||
// Removes outputs that are not used later in the graph, to reduce compute.
|
||||
// Currently used to remove the max_indices output of embedding_bag, which
|
||||
// isn't necessary to compute the main output.
|
||||
TORCH_API void RemoveUnnecessaryOutputs(std::shared_ptr<Graph>& graph);
|
||||
|
||||
// Rewrites aten::embedding_bag into static_runtime::embedding_bag, which has
// no max_indices output, when that output is not used in the graph.
TORCH_API void RemoveUnnecessaryEmbeddingBagOutputs(
|
||||
std::shared_ptr<Graph>& graph);
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
Reference in New Issue
Block a user