#pragma once

#include <ATen/CPUFunctions.h>
#include <ATen/NativeFunctions.h>
#include <torch/torch.h>
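
// Reference eager-mode implementation of a simple deep-and-wide model: the ad
// and user embeddings are combined with a batched matrix multiply, the result
// is concatenated with the clamped "wide" features, and the prediction comes
// from a single linear layer followed by a sigmoid.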
struct DeepAndWide : torch::nn::Module {
  DeepAndWide(int num_features = 50) {
    mu_ = register_parameter("mu_", torch::randn({1, num_features}));
    sigma_ = register_parameter("sigma_", torch::randn({1, num_features}));
    fc_w_ = register_parameter("fc_w_", torch::randn({1, num_features + 1}));
    fc_b_ = register_parameter("fc_b_", torch::randn({1}));
  }

  torch::Tensor forward(
      torch::Tensor ad_emb_packed,
      torch::Tensor user_emb,
      torch::Tensor wide) {
    // Preprocess the wide features: shift, scale, and clamp.
    auto wide_offset = wide + mu_;
    auto wide_normalized = wide_offset * sigma_;
    auto wide_noNaN = wide_normalized;
    // Placeholder for ReplaceNaN
    auto wide_preproc = torch::clamp(wide_noNaN, -10.0, 10.0);

    // Dot product of the ad and user embeddings, expressed as a batched
    // matrix multiply.
    auto user_emb_t = torch::transpose(user_emb, 1, 2);
    auto dp_unflatten = torch::bmm(ad_emb_packed, user_emb_t);
    auto dp = torch::flatten(dp_unflatten, 1);
    auto input = torch::cat({dp, wide_preproc}, 1);

    // Final prediction: a single linear layer followed by a sigmoid.
    auto fc1 = torch::nn::functional::linear(input, fc_w_, fc_b_);
    auto pred = torch::sigmoid(fc1);
    return pred;
  }
  torch::Tensor mu_, sigma_, fc_w_, fc_b_;
};
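
// Example usage (a minimal sketch; batch_size, embedding_size, and the wide
// feature count below are illustrative assumptions, not requirements of this
// header):
//   DeepAndWide model(/*num_features=*/50);
//   auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
//   auto user_emb = torch::randn({batch_size, 1, embedding_size});
//   auto wide = torch::randn({batch_size, 50});
//   auto pred = model.forward(ad_emb_packed, user_emb, wide); // {batch_size, 1}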
// Implementation using native functions and pre-allocated tensors.
// It could be used as a "speed of light" for static runtime.
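// On the first call, forward() runs the functional ops and caches their
// outputs; on subsequent calls it writes into those cached tensors through
// the corresponding _out variants, avoiding allocations in steady state.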
struct DeepAndWideFast : torch::nn::Module {
  DeepAndWideFast(int num_features = 50) {
    mu_ = register_parameter("mu_", torch::randn({1, num_features}));
    sigma_ = register_parameter("sigma_", torch::randn({1, num_features}));
    fc_w_ = register_parameter("fc_w_", torch::randn({1, num_features + 1}));
    fc_b_ = register_parameter("fc_b_", torch::randn({1}));
    allocated = false;
    prealloc_tensors = {};
  }

  torch::Tensor forward(
      torch::Tensor ad_emb_packed,
      torch::Tensor user_emb,
      torch::Tensor wide) {
    torch::NoGradGuard no_grad;
    if (!allocated) {
      auto wide_offset = at::add(wide, mu_);
      auto wide_normalized = at::mul(wide_offset, sigma_);
      // Placeholder for ReplaceNaN
      auto wide_preproc = at::cpu::clamp(wide_normalized, -10.0, 10.0);

      auto user_emb_t = at::native::transpose(user_emb, 1, 2);
      auto dp_unflatten = at::cpu::bmm(ad_emb_packed, user_emb_t);
      // auto dp = at::native::flatten(dp_unflatten, 1);
      auto dp = dp_unflatten.view({dp_unflatten.size(0), 1});
      auto input = at::native::_cat_cpu({dp, wide_preproc}, 1);

      // fc1 = torch::nn::functional::linear(input, fc_w_, fc_b_);
      fc_w_t_ = torch::t(fc_w_);
      auto fc1 = torch::addmm(fc_b_, input, fc_w_t_);

      auto pred = at::cpu::sigmoid(fc1);

      prealloc_tensors = {
          wide_offset,
          wide_normalized,
          wide_preproc,
          user_emb_t,
          dp_unflatten,
          dp,
          input,
          fc1,
          pred};
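      // Index map for the cached tensors reused in the steady-state branch
      // below: [0]=wide_offset, [1]=wide_normalized, [2]=wide_preproc,
      // [3]=user_emb_t, [4]=dp_unflatten, [5]=dp, [6]=input, [7]=fc1, [8]=pred.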
      allocated = true;

      return pred;
    } else {
      // Potential optimization: add and mul could be fused together (e.g. with
      // Eigen).
      at::add_out(prealloc_tensors[0], wide, mu_);
      at::mul_out(prealloc_tensors[1], prealloc_tensors[0], sigma_);

      at::native::clip_out(
          prealloc_tensors[1], -10.0, 10.0, prealloc_tensors[2]);

      // Potential optimization: original tensor could be pre-transposed.
      // prealloc_tensors[3] = at::native::transpose(user_emb, 1, 2);
      if (prealloc_tensors[3].data_ptr() != user_emb.data_ptr()) {
        auto sizes = user_emb.sizes();
        auto strides = user_emb.strides();
        prealloc_tensors[3].set_(
            user_emb.storage(),
            0,
            {sizes[0], sizes[2], sizes[1]},
            {strides[0], strides[2], strides[1]});
      }

      // Potential optimization: call MKLDNN directly.
      at::cpu::bmm_out(ad_emb_packed, prealloc_tensors[3], prealloc_tensors[4]);

      if (prealloc_tensors[5].data_ptr() != prealloc_tensors[4].data_ptr()) {
        // In the unlikely case that the input tensor changed, we need to
        // reinitialize the view.
        prealloc_tensors[5] =
            prealloc_tensors[4].view({prealloc_tensors[4].size(0), 1});
      }

      // Potential optimization: we can replace cat with carefully constructed
      // tensor views on the output that are passed to the _out ops above.
      at::native::_cat_out_cpu(
          {prealloc_tensors[5], prealloc_tensors[2]}, 1, prealloc_tensors[6]);
      at::cpu::addmm_out(
          prealloc_tensors[7], fc_b_, prealloc_tensors[6], fc_w_t_, 1, 1);
      at::cpu::sigmoid_out(prealloc_tensors[7], prealloc_tensors[8]);

      return prealloc_tensors[8];
    }
  }
  torch::Tensor mu_, sigma_, fc_w_, fc_b_, fc_w_t_;
  std::vector<torch::Tensor> prealloc_tensors;
  bool allocated = false;
};
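
// Factory functions for TorchScript test models (including a scripted version
// of the deep-and-wide model above); they are only declared here and defined
// elsewhere.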
torch::jit::Module getDeepAndWideSciptModel(int num_features = 50);

torch::jit::Module getTrivialScriptModel();

torch::jit::Module getLeakyReLUScriptModel();

torch::jit::Module getLeakyReLUConstScriptModel();

torch::jit::Module getLongScriptModel();

torch::jit::Module getSignedLog1pModel();