This PR enables -Wmissing-prototypes in torch_cpu, except for some generated .cpp files, the mps, metal, and vulkan backends, and the caffe2 sources. Pull Request resolved: https://github.com/pytorch/pytorch/pull/104138 Approved by: https://github.com/albanD, https://github.com/malfet

#include <torch/csrc/jit/runtime/static/te_wrapper.h>

#include <ATen/CPUFunctions.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/runtime/static/impl.h>
#include <torch/csrc/jit/tensorexpr/expr.h>
#include <torch/csrc/jit/tensorexpr/operators/misc.h>
#include <torch/csrc/jit/tensorexpr/operators/operators.h>

namespace torch {
namespace jit {

using namespace torch::jit::tensorexpr;

// Use the width of an AVX-512 vector by default; this happens to work OK for
// AVX2 as well. Some ops benefit from using multiple AVX ports, in which case
// they are vectorized by twice this constant. An exception is logit, since it
// contains FP divide, which is single-ported.
static constexpr int kVectorWidth = 16;

#ifdef TORCH_ENABLE_LLVM

void TEWrapper::update(std::unique_ptr<LLVMCodeGen>&& cg_) {
  cg = std::move(cg_);
}

void TEWrapper::call(const std::vector<void*>& args) {
  cg->call_raw(args);
}
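
// Split the generated loop over `target` by `width` and vectorize the main
// portion; leftover iterations run in a scalar tail loop.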
static void optimizePointwise(LoopNest* ln, Tensor target, int width) {
  std::vector<ForPtr> loops = ln->getLoopStmtsFor(target);
  ForPtr inner, tail;
  TORCH_CHECK(loops.size() > 0, "No loops created for pointwise op");
  ln->splitWithTail(loops[0], width, &inner, &tail);
  ln->vectorize(inner);
}
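
// Build and JIT-compile a single-output pointwise kernel: schedule it with
// optimizePointwise, lower and simplify the IR, then hand it to LLVMCodeGen.
// The output buffer is prepended to `args`, so TEWrapper::call expects the
// output pointer first.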
static std::shared_ptr<TEWrapper> wrapTECompute(
    std::shared_ptr<TEWrapper> wrap,
    Tensor out,
    std::vector<CodeGen::BufferArg> args,
    int width = kVectorWidth) {
  LoopNest ln({out});
  optimizePointwise(&ln, out, width);
  ln.prepareForCodegen();
  StmtPtr s = ln.root_stmt();
  s = IRSimplifier::simplify(s);
  args.insert(args.begin(), out);
  auto cg = std::make_unique<LLVMCodeGen>(s, args);
  cg->cleanup_memory();
  wrap->update(std::move(cg));
  return wrap;
}
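
// Overload for kernels that schedule their own LoopNest (see
// createSignedLog1p); the nest is expected to be fully prepared already.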
static std::shared_ptr<TEWrapper> wrapTECompute(
    std::shared_ptr<TEWrapper> wrap,
    LoopNest* ln,
    std::vector<CodeGen::BufferArg> args) {
  auto cg = std::make_unique<LLVMCodeGen>(ln->root_stmt(), args);
  wrap->update(std::move(cg));
  return wrap;
}

#else
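
// Without LLVM these are inert stubs: graph preparation can still build and
// cache TEWrapper objects, but reaching call() at runtime is a bug.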
void TEWrapper::call(const std::vector<void*>& args) {
  DCHECK(0 && "Invalid call");
}

static std::shared_ptr<TEWrapper> wrapTECompute(
    std::shared_ptr<TEWrapper> wrap,
    Tensor out,
    std::vector<CodeGen::BufferArg> args,
    int width = kVectorWidth) {
  return wrap;
}

static std::shared_ptr<TEWrapper> wrapTECompute(
    std::shared_ptr<TEWrapper> wrap,
    LoopNest* ln,
    std::vector<CodeGen::BufferArg> args) {
  return wrap;
}

#endif
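
// Caller-side usage sketch (hypothetical, not part of this file): the static
// runtime looks up a cached wrapper and invokes it with raw pointers, output
// buffer first (see wrapTECompute), then inputs, then scalars, then dims:
//
//   auto te = createRelu();                          // compiled once, cached
//   int N = static_cast<int>(out_tensor.numel());    // `out_tensor` and
//   std::vector<void*> args = {                      // `in_tensor` are
//       out_tensor.data_ptr(),                       // assumed caller-side
//       in_tensor.data_ptr(),                        // float tensors
//       &N};
//   te->call(args);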

namespace {
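
// Process-wide cache of compiled kernels keyed by the node kind they
// implement, so each op is JIT-compiled at most once per process. The mutex
// guards concurrent lookups and insertions during graph preparation.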
std::mutex& getNNCCacheMutex() {
  static std::mutex nncCacheMutex;
  return nncCacheMutex;
}

c10::FastMap<NodeKind, std::shared_ptr<TEWrapper>>& getNNCCache() {
  static c10::FastMap<NodeKind, std::shared_ptr<TEWrapper>> nncCache;
  return nncCache;
}

std::shared_ptr<TEWrapper> lookupNNCCache(NodeKind kind) {
  std::lock_guard<std::mutex> lock(getNNCCacheMutex());
  auto it = getNNCCache().find(kind);
  if (it != getNNCCache().end()) {
    return it->second;
  }
  return nullptr;
}

void updateNNCCache(NodeKind kind, std::shared_ptr<TEWrapper> code) {
  std::lock_guard<std::mutex> lock(getNNCCacheMutex());
  getNNCCache()[kind] = code;
}

} // namespace
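
// aten::div with a runtime rounding-mode selector: mode 0 = true division
// (rounding_mode=None), 1 = trunc, 2 = floor, matching the IntImm constants
// below.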
std::shared_ptr<TEWrapper> createDiv() {
  auto wrap = lookupNNCCache(aten::div);
  if (wrap) {
    return wrap;
  }
  wrap = std::make_shared<TEWrapper>();

  auto dim = VarHandle("dim", kInt);
  auto mode = VarHandle("mode", kInt);
  BufHandle A("A", {dim}, kFloat);
  BufHandle B("B", {dim}, kFloat);

  using axis = const VarHandle&;
  Tensor C = Compute("C", {dim}, [&](axis x) {
    auto true_div_result = A.load(x) / B.load(x);

    auto mode_default = IntImm::make(0);
    auto mode_trunc = IntImm::make(1);
    auto mode_floor = IntImm::make(2);

    // Effectively a chained ternary:
    //   mode == 0 ? a / b : (mode == 1 ? trunc(a / b) : floor(a / b))
    return CompareSelect::make(
        mode,
        mode_default,
        true_div_result,
        CompareSelect::make(
            mode,
            mode_trunc,
            trunc(true_div_result),
            floor(true_div_result),
            kEQ),
        kEQ);
  });

  wrap = wrapTECompute(wrap, C, {A, B, mode, dim});

  updateNNCCache(aten::div, wrap);
  return wrap;
}
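
// aten::logit with an eps argument C: the input is clamped to [C, 1 - C]
// before computing log(x / (1 - x)), keeping the result finite at the
// boundaries.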
std::shared_ptr<TEWrapper> createLogit() {
  auto wrap = lookupNNCCache(aten::logit);
  if (wrap) {
    return wrap;
  }
  wrap = std::make_shared<TEWrapper>();
  auto N = VarHandle("N", kInt);
  auto C = VarHandle("C", kFloat);
  BufHandle A("A", {N}, kFloat);
  Tensor B = Compute("B", {N}, [&](const VarHandle& i) {
    auto A_elem = [&]() {
      auto elem = A.load(i);
      auto one = FloatImm::make(1.0f);
      const auto& min = C;
      auto max = one - C;
      elem = CompareSelect::make(elem, min, min, elem, kLT);
      return CompareSelect::make(elem, max, max, elem, kGT);
    }();
    return log_vml(A_elem / (FloatImm::make(1.0f) - A_elem));
  });
  wrap = wrapTECompute(wrap, B, {A, N, C});
  updateNNCCache(aten::logit, wrap);
  return wrap;
}
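
// aten::relu as a branch-free CompareSelect: (a < 0) ? 0 : a, i.e. max(a, 0).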
std::shared_ptr<TEWrapper> createRelu() {
  auto wrap = lookupNNCCache(aten::relu);
  if (wrap) {
    return wrap;
  }
  wrap = std::make_shared<TEWrapper>();
  auto N = VarHandle("N", kInt);
  BufHandle A("A", {N}, kFloat);
  Tensor B = Compute("B", {N}, [&](const VarHandle& i) {
    auto zero = FloatImm::make(0.f);
    auto a = A.load(i);
    return CompareSelect::make(a, zero, zero, a, kLT);
  });
  wrap = wrapTECompute(wrap, B, {A, N});
  updateNNCCache(aten::relu, wrap);
  return wrap;
}
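
// aten::tanh via tensorexpr's fast_tanh, a vectorizable approximation that
// trades a little accuracy for speed.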
std::shared_ptr<TEWrapper> createTanh() {
  auto wrap = lookupNNCCache(aten::tanh);
  if (wrap) {
    return wrap;
  }
  wrap = std::make_shared<TEWrapper>();
  auto N = VarHandle("N", kInt);
  BufHandle A("A", {N}, kFloat);
  Tensor B = Compute("B", {N}, [&](const VarHandle& i) {
    auto a = A.load(i);
    return fast_tanh(a);
  });
  wrap = wrapTECompute(wrap, B, {A, N});
  updateNNCCache(aten::tanh, wrap);
  return wrap;
}
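
// aten::sigmoid, likewise via the fast_sigmoid approximation.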
std::shared_ptr<TEWrapper> createSigmoid() {
  auto wrap = lookupNNCCache(aten::sigmoid);
  if (wrap) {
    return wrap;
  }
  wrap = std::make_shared<TEWrapper>();
  auto N = VarHandle("N", kInt);
  BufHandle A("A", {N}, kFloat);
  Tensor B = Compute(
      "B", {N}, [&](const VarHandle& i) { return fast_sigmoid(A.load(i)); });
  wrap = wrapTECompute(wrap, B, {A, N});
  updateNNCCache(aten::sigmoid, wrap);
  return wrap;
}
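
// aten::clamp with min/max passed as scalar kernel arguments, so the single
// compiled kernel serves all clamp bounds.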
std::shared_ptr<TEWrapper> createClamp() {
  static auto clamp_symbol = c10::Symbol::fromQualString("aten::clamp");
  auto wrap = lookupNNCCache(clamp_symbol);
  if (wrap) {
    return wrap;
  }
  wrap = std::make_shared<TEWrapper>();
  auto N = VarHandle("N", kInt);
  auto min_handle = VarHandle("min", kFloat);
  auto max_handle = VarHandle("max", kFloat);

  BufHandle A("A", {N}, kFloat);
  Tensor result = Compute("aten_clamp", {N}, [&](const VarHandle& i) {
    auto a = A.load(i);
    return tensorexpr::clamp(min_handle, max_handle, a);
  });
  wrap = wrapTECompute(wrap, result, {A, min_handle, max_handle, N});
  updateNNCCache(clamp_symbol, wrap);
  return wrap;
}
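
// Fused clamp + nan_to_num: clamp first (NaN propagates through the
// comparisons), then replace any NaN with `nan_replace_val` via isnan +
// CompareSelect.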
std::shared_ptr<TEWrapper> createClampNanToNum() {
  static auto symbol =
      c10::Symbol::fromQualString("static_runtime::clamp_nan_to_num");
  auto wrap = lookupNNCCache(symbol);
  if (wrap) {
    return wrap;
  }
  wrap = std::make_shared<TEWrapper>();
  auto N = VarHandle("N", kInt);
  auto min_handle = VarHandle("min", kFloat);
  auto max_handle = VarHandle("max", kFloat);
  auto nan_replace_val = VarHandle("nan_replace_val", kFloat);

  BufHandle A("A", {N}, kFloat);
  Tensor result = Compute("aten_clamp", {N}, [&](const VarHandle& i) {
    auto a = A.load(i);
    auto clamp = tensorexpr::clamp(min_handle, max_handle, a);
    auto is_nan = tensorexpr::isnan(clamp);
    auto nans_replaced =
        tensorexpr::CompareSelect::make(is_nan, 1, nan_replace_val, clamp, kEQ);
    return nans_replaced;
  });
  wrap = wrapTECompute(
      wrap, result, {A, min_handle, max_handle, nan_replace_val, N});
  updateNNCCache(symbol, wrap);
  return wrap;
}
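
// static_runtime::signed_log1p computes sign(x) * log1p(|x|). Unlike the ops
// above, it schedules a multi-stage LoopNest by hand: the intermediate
// buffers are inlined, then the inner loops are vectorized before codegen.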
std::shared_ptr<TEWrapper> createSignedLog1p() {
  static auto signed_log1p_symbol =
      c10::Symbol::fromQualString("static_runtime::signed_log1p");
  auto wrap = lookupNNCCache(signed_log1p_symbol);
  if (wrap) {
    return wrap;
  }
  wrap = std::make_shared<TEWrapper>();
  auto N = VarHandle("N", kInt);
  BufHandle A("A", {N}, kFloat);
  Tensor abs_result = Compute("aten_abs", {N}, [&](const VarHandle& i) {
    return tensorexpr::abs(A.load(i));
  });
  Tensor log1p_result = Compute("aten_log1p", {N}, [&](const VarHandle& i) {
    return log1p(abs_result.load(i));
  });
  Tensor sign = computeSign({A}, {N});
  Tensor output = Compute("aten_mul", {N}, [&](const VarHandle& i) {
    return sign.load(i) * log1p_result.load(i);
  });
  LoopNest ln({output}, {abs_result, log1p_result, sign, output});
  GRAPH_DEBUG("Original stmt: ", *ln.root_stmt());
  ln.inlineIntermediateBufs(true);
  ln.prepareForCodegen();
  ln.simplify();
  ln.vectorizeInnerLoops();
  ln.simplify();
  GRAPH_DEBUG("Final stmt: ", *ln.root_stmt());
  wrap = wrapTECompute(wrap, &ln, {output, A, N});
  updateNNCCache(signed_log1p_symbol, wrap);
  return wrap;
}

} // namespace jit
} // namespace torch