#include "deep_wide_pt.h"
|
|
|
|
#include <torch/csrc/jit/serialization/import_source.h>
|
|
#include <torch/script.h>
|
|
|
|
namespace {
// No ReplaceNaN (this removes the constant in the model)
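// TorchScript for a small "deep and wide" model: the wide features are
// shifted by _mu, scaled by _sigma, and clamped to [0, 10]; the "deep" part
// is a batched dot product of the ad and user embeddings; the two are
// concatenated and fed through a single sigmoid-activated linear layer.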
const std::string deep_wide_pt = R"JIT(
class DeepAndWide(Module):
  __parameters__ = ["_mu", "_sigma", "_fc_w", "_fc_b", ]
  __buffers__ = []
  _mu : Tensor
  _sigma : Tensor
  _fc_w : Tensor
  _fc_b : Tensor
  training : bool
  def forward(self: __torch__.DeepAndWide,
      ad_emb_packed: Tensor,
      user_emb: Tensor,
      wide: Tensor) -> Tuple[Tensor]:
    _0 = self._fc_b
    _1 = self._fc_w
    _2 = self._sigma
    wide_offset = torch.add(wide, self._mu, alpha=1)
    wide_normalized = torch.mul(wide_offset, _2)
    wide_preproc = torch.clamp(wide_normalized, 0., 10.)
    user_emb_t = torch.transpose(user_emb, 1, 2)
    dp_unflatten = torch.bmm(ad_emb_packed, user_emb_t)
    dp = torch.flatten(dp_unflatten, 1, -1)
    input = torch.cat([dp, wide_preproc], 1)
    fc1 = torch.addmm(_0, input, torch.t(_1), beta=1, alpha=1)
    return (torch.sigmoid(fc1),)
)JIT";

const std::string trivial_model_1 = R"JIT(
  def forward(self, a, b, c):
    s = torch.tensor([[3, 3], [3, 3]])
    return a + b * c + s
)JIT";

const std::string leaky_relu_model_const = R"JIT(
  def forward(self, input):
    x = torch.leaky_relu(input, 0.1)
    x = torch.leaky_relu(x, 0.1)
    x = torch.leaky_relu(x, 0.1)
    x = torch.leaky_relu(x, 0.1)
    return torch.leaky_relu(x, 0.1)
)JIT";

const std::string leaky_relu_model = R"JIT(
  def forward(self, input, neg_slope):
    x = torch.leaky_relu(input, neg_slope)
    x = torch.leaky_relu(x, neg_slope)
    x = torch.leaky_relu(x, neg_slope)
    x = torch.leaky_relu(x, neg_slope)
    return torch.leaky_relu(x, neg_slope)
)JIT";

void import_libs(
    std::shared_ptr<at::CompilationUnit> cu,
    const std::string& class_name,
    const std::shared_ptr<torch::jit::Source>& src,
    const std::vector<at::IValue>& tensor_table) {
  torch::jit::SourceImporter si(
      cu,
      &tensor_table,
      [&](const std::string& /* unused */)
          -> std::shared_ptr<torch::jit::Source> { return src; },
      /*version=*/2);
  si.loadType(c10::QualifiedName(class_name));
}
} // namespace
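
// Builds the DeepAndWide module from the TorchScript source above and
// registers randomly initialized parameters. The linear layer takes
// num_features + 1 inputs: the wide features plus the embedding dot product.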

torch::jit::Module getDeepAndWideSciptModel(int num_features) {
  auto cu = std::make_shared<at::CompilationUnit>();
  std::vector<at::IValue> constantTable;
  import_libs(
      cu,
      "__torch__.DeepAndWide",
      std::make_shared<torch::jit::Source>(deep_wide_pt),
      constantTable);
  c10::QualifiedName base("__torch__");
  auto clstype = cu->get_class(c10::QualifiedName(base, "DeepAndWide"));

  torch::jit::Module mod(cu, clstype);

  mod.register_parameter("_mu", torch::randn({1, num_features}), false);
  mod.register_parameter("_sigma", torch::randn({1, num_features}), false);
  mod.register_parameter("_fc_w", torch::randn({1, num_features + 1}), false);
  mod.register_parameter("_fc_b", torch::randn({1}), false);

  // mod.dump(true, true, true);
  return mod;
}
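
// Example (hypothetical caller; batch_size and embedding_size are
// placeholders, and 50 matches the num_features passed in):
//   auto mod = getDeepAndWideSciptModel(50);
//   auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
//   auto user_emb = torch::randn({batch_size, 1, embedding_size});
//   auto wide = torch::randn({batch_size, 50});
//   // forward() returns a one-element tuple: (sigmoid(fc1),)
//   auto out = mod.forward({ad_emb_packed, user_emb, wide}).toTuple();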

torch::jit::Module getTrivialScriptModel() {
  torch::jit::Module module("m");
  module.define(trivial_model_1);
  return module;
}

torch::jit::Module getLeakyReLUScriptModel() {
  torch::jit::Module module("leaky_relu");
  module.define(leaky_relu_model);
  return module;
}

torch::jit::Module getLeakyReLUConstScriptModel() {
  torch::jit::Module module("leaky_relu_const");
  module.define(leaky_relu_model_const);
  return module;
}
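
// A deeper chain of relu(mul) ops, presumably to exercise graphs with
// several short-lived intermediate tensors.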

const std::string long_model = R"JIT(
  def forward(self, a, b, c):
    d = torch.relu(a * b)
    e = torch.relu(a * c)
    f = torch.relu(e * d)
    g = torch.relu(f * f)
    h = torch.relu(g * c)
    return h
)JIT";

torch::jit::Module getLongScriptModel() {
  torch::jit::Module module("m");
  module.define(long_model);
  return module;
}
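
// Computes sign(a) * log1p(|a|), a pattern in which one input feeds
// multiple ops.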

const std::string signed_log1p_model = R"JIT(
  def forward(self, a):
    b = torch.abs(a)
    c = torch.log1p(b)
    d = torch.sign(a)
    e = d * c
    return e
)JIT";

torch::jit::Module getSignedLog1pModel() {
  torch::jit::Module module("signed_log1p");
  module.define(signed_log1p_model);
  return module;
}
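
// Example (hypothetical): modules built via module.define() are invoked the
// same way as loaded ones:
//   auto m = getSignedLog1pModel();
//   auto out = m.forward({torch::randn({8})}).toTensor();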