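// Microbenchmarks comparing NNC (tensorexpr) LLVM codegen against ATen and
// Caffe2 implementations of pointwise ops: relu, log, logit, and tanh.
// Each NNC variant builds a Compute expression, applies a loop-nest schedule,
// JIT-compiles it with LLVMCodeGen, validates the result against an ATen
// reference, and then times repeated calls via Google Benchmark.
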
#include <benchmark/benchmark.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/llvm_codegen.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
#include <torch/torch.h>
#include "caffe2/operators/logit_op.h"
#include "caffe2/operators/tanh_op.h"

using namespace torch::jit;
using namespace torch::jit::tensorexpr;

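// Schedule helper: split the outermost loop of `target` by `width` and
// vectorize the resulting inner loop; remainder iterations go to the tail
// loop produced by splitWithTail.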
void vectorize(tensorexpr::LoopNest* ln, tensorexpr::Tensor target, int width) {
  auto loops = ln->getLoopStmtsFor(target);
  ForPtr inner, tail;
  ln->splitWithTail(loops[0], width, &inner, &tail);
  ln->vectorize(inner);
}

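// Schedule helper for pointwise kernels: split the loop by 16 * 8 = 128 and
// vectorize the 128-wide inner loop (LLVM legalizes the wide vector to the
// target's native width), then split the outer loop by 8 and fully unroll
// the resulting inner loop.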
void optimizePointwise(tensorexpr::LoopNest* ln, tensorexpr::Tensor target) {
  std::vector<ForPtr> loops = ln->getLoopStmtsFor(target);
  ForPtr inner, tail;
  ln->splitWithTail(loops[0], 16 * 8, &inner, &tail);
  ForPtr outer = loops[0];
  ln->vectorize(inner);
  ln->splitWithTail(outer, 8, &inner, &tail);
  StmtPtr unrolled;
  LoopNest::fullUnroll(inner, &unrolled);
}

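// ReLU written directly in the tensorexpr IR: B[i] = A[i] < 0 ? 0 : A[i],
// expressed with a CompareSelect. Checked against at::relu before timing.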
static void relu_nnc(benchmark::State& state) {
  auto N = VarHandle("N", kInt);
  BufHandle A("A", {N}, kFloat);
  auto clamp = 0;
  torch::jit::tensorexpr::Tensor B = Compute("B", {N}, [&](const VarHandle& i) {
    auto A_elem = [&]() {
      auto elem = A.load(i);
      auto min = FloatImm::make(clamp);
      return CompareSelect::make(elem, min, min, elem, kLT);
    }();
    return A_elem;
  });
  LoopNest ln({B});
  optimizePointwise(&ln, B);
  ln.prepareForCodegen();
  StmtPtr s = ln.root_stmt();
  s = torch::jit::tensorexpr::IRSimplifier::simplify(s);
  std::vector<CodeGen::BufferArg> args;
  args.emplace_back(B);
  args.emplace_back(A);
  args.emplace_back(N);
  LLVMCodeGen cg(s, args);
  at::Tensor A_t = torch::randn({state.range(0)});
  at::Tensor B_t = torch::randn({state.range(0)});
  auto B_ref = at::relu(A_t);
  cg.call({B_t.data_ptr<float>(), A_t.data_ptr<float>(), state.range(0)});
  TORCH_CHECK(at::allclose(B_t, B_ref));
  for (auto _ : state) {
    cg.call({B_t.data_ptr<float>(), A_t.data_ptr<float>(), state.range(0)});
  }
  state.counters["relu/s"] = benchmark::Counter(
      uint64_t(state.range(0) * state.iterations()),
      benchmark::Counter::kIsRate);
}

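// log() on the loaded element goes through NNC's default lowering, which the
// LLVM backend maps to a SLEEF library call (hence the _sleef suffix). Inputs
// are made positive with abs() so the at::log reference stays finite.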
static void log_nnc_sleef(benchmark::State& state) {
  auto N = VarHandle("N", kInt);
  BufHandle A("A", {N}, kFloat);
  torch::jit::tensorexpr::Tensor B =
      Compute("B", {N}, [&](const VarHandle& i) { return log(A.load(i)); });
  LoopNest ln({B});
  ln.prepareForCodegen();
  vectorize(&ln, B, 8);
  StmtPtr s = ln.root_stmt();
  s = torch::jit::tensorexpr::IRSimplifier::simplify(s);
  std::vector<CodeGen::BufferArg> args;
  args.emplace_back(B);
  args.emplace_back(A);
  args.emplace_back(N);
  LLVMCodeGen cg(s, args);
  at::Tensor A_t = torch::abs(torch::randn({state.range(0)}));
  at::Tensor B_t = torch::randn({state.range(0)});
  auto B_ref = at::log(A_t);
  cg.call({B_t.data_ptr<float>(), A_t.data_ptr<float>(), state.range(0)});
  TORCH_CHECK(at::allclose(B_t, B_ref));
  for (auto _ : state) {
    cg.call({B_t.data_ptr<float>(), A_t.data_ptr<float>(), state.range(0)});
  }
  state.counters["log/s"] = benchmark::Counter(
      uint64_t(state.range(0) * state.iterations()),
      benchmark::Counter::kIsRate);
}

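// fast_log is NNC's approximate log intrinsic (a polynomial approximation),
// trading some accuracy for speed relative to the SLEEF path.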
static void log_nnc_fast(benchmark::State& state) {
  auto N = VarHandle("N", kInt);
  BufHandle A("A", {N}, kFloat);
  torch::jit::tensorexpr::Tensor B = Compute(
      "B", {N}, [&](const VarHandle& i) { return fast_log(A.load(i)); });
  LoopNest ln({B});
  optimizePointwise(&ln, B);
  ln.prepareForCodegen();
  StmtPtr s = ln.root_stmt();
  s = torch::jit::tensorexpr::IRSimplifier::simplify(s);
  std::vector<CodeGen::BufferArg> args;
  args.emplace_back(B);
  args.emplace_back(A);
  args.emplace_back(N);
  LLVMCodeGen cg(s, args);
  at::Tensor A_t = torch::abs(torch::randn({state.range(0)}));
  at::Tensor B_t = torch::randn({state.range(0)});
  auto B_ref = at::log(A_t);
  cg.call({B_t.data_ptr<float>(), A_t.data_ptr<float>(), state.range(0)});
  TORCH_CHECK(at::allclose(B_t, B_ref));
  for (auto _ : state) {
    cg.call({B_t.data_ptr<float>(), A_t.data_ptr<float>(), state.range(0)});
  }
  state.counters["log/s"] = benchmark::Counter(
      uint64_t(state.range(0) * state.iterations()),
      benchmark::Counter::kIsRate);
}

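// log_vml selects NNC's VML-style log implementation; here it is scheduled
// with plain 8-wide vectorization instead of optimizePointwise.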
static void log_nnc_vml(benchmark::State& state) {
  auto N = VarHandle("N", kInt);
  BufHandle A("A", {N}, kFloat);
  torch::jit::tensorexpr::Tensor B =
      Compute("B", {N}, [&](const VarHandle& i) { return log_vml(A.load(i)); });
  LoopNest ln({B});
  vectorize(&ln, B, 8);
  ln.prepareForCodegen();
  StmtPtr s = ln.root_stmt();
  s = torch::jit::tensorexpr::IRSimplifier::simplify(s);
  std::vector<CodeGen::BufferArg> args;
  args.emplace_back(B);
  args.emplace_back(A);
  args.emplace_back(N);
  LLVMCodeGen cg(s, args);
  at::Tensor A_t = torch::abs(torch::randn({state.range(0)}));
  at::Tensor B_t = torch::randn({state.range(0)});
  auto B_ref = at::log(A_t);
  cg.call({B_t.data_ptr<float>(), A_t.data_ptr<float>(), state.range(0)});
  TORCH_CHECK(at::allclose(B_t, B_ref));
  for (auto _ : state) {
    cg.call({B_t.data_ptr<float>(), A_t.data_ptr<float>(), state.range(0)});
  }
  state.counters["log/s"] = benchmark::Counter(
      uint64_t(state.range(0) * state.iterations()),
      benchmark::Counter::kIsRate);
}

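// ATen baseline: at::log_out writing into a preallocated output tensor.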
static void log_aten(benchmark::State& state) {
  at::Tensor A_t = torch::abs(torch::randn({state.range(0)}));
  at::Tensor B_t = torch::randn({state.range(0)});
  for (auto _ : state) {
    at::log_out(B_t, A_t);
  }
  state.counters["log/s"] = benchmark::Counter(
      uint64_t(state.range(0) * state.iterations()),
      benchmark::Counter::kIsRate);
}

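// logit(x) = log(x / (1 - x)) with the input clamped to [eps, 1 - eps] by two
// CompareSelects, mirroring the eps argument of at::logit. nan_to_num guards
// the allclose check against non-finite values at the clamp boundaries.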
static void logit_nnc_sleef(benchmark::State& state) {
  auto N = VarHandle("N", kInt);
  BufHandle A("A", {N}, kFloat);
  auto clamp = 1e-6f;
  tensorexpr::Tensor B = Compute("B", {N}, [&](const VarHandle& i) {
    auto A_elem = [&]() {
      auto elem = A.load(i);
      auto min = FloatImm::make(clamp);
      auto max = FloatImm::make(1.0f - clamp);
      elem = CompareSelect::make(elem, min, min, elem, kLT);
      return CompareSelect::make(elem, max, max, elem, kGT);
    }();
    return log(A_elem / (FloatImm::make(1.0f) - A_elem));
  });
  LoopNest ln({B});
  ln.prepareForCodegen();
  optimizePointwise(&ln, B);
  StmtPtr s = ln.root_stmt();
  s = torch::jit::tensorexpr::IRSimplifier::simplify(s);
  std::vector<CodeGen::BufferArg> args;
  args.emplace_back(B);
  args.emplace_back(A);
  args.emplace_back(N);
  LLVMCodeGen cg(s, args);
  at::Tensor A_t = torch::abs(torch::randn({state.range(0)}));
  at::Tensor B_t = torch::randn({state.range(0)});
  auto B_ref = at::logit(A_t, clamp);
  cg.call({B_t.data_ptr<float>(), A_t.data_ptr<float>(), state.range(0)});
  TORCH_CHECK(at::allclose(at::nan_to_num(B_t), at::nan_to_num(B_ref)));
  for (auto _ : state) {
    cg.call({B_t.data_ptr<float>(), A_t.data_ptr<float>(), state.range(0)});
  }
  state.counters["logit/s"] = benchmark::Counter(
      uint64_t(state.range(0) * state.iterations()),
      benchmark::Counter::kIsRate);
}

static void logit_nnc_fast(benchmark::State& state) {
  auto N = VarHandle("N", kInt);
  BufHandle A("A", {N}, kFloat);
  auto clamp = 1e-6f;
  tensorexpr::Tensor B = Compute("B", {N}, [&](const VarHandle& i) {
    auto A_elem = [&]() {
      auto elem = A.load(i);
      auto min = FloatImm::make(clamp);
      auto max = FloatImm::make(1.0f - clamp);
      elem = CompareSelect::make(elem, min, min, elem, kLT);
      return CompareSelect::make(elem, max, max, elem, kGT);
    }();
    return fast_log(A_elem / (FloatImm::make(1.0f) - A_elem));
  });
  LoopNest ln({B});
  ln.prepareForCodegen();
  optimizePointwise(&ln, B);
  StmtPtr s = ln.root_stmt();
  s = torch::jit::tensorexpr::IRSimplifier::simplify(s);
  std::vector<CodeGen::BufferArg> args;
  args.emplace_back(B);
  args.emplace_back(A);
  args.emplace_back(N);
  LLVMCodeGen cg(s, args);
  at::Tensor A_t = torch::abs(torch::randn({state.range(0)}));
  at::Tensor B_t = torch::randn({state.range(0)});
  auto B_ref = at::logit(A_t, clamp);
  cg.call({B_t.data_ptr<float>(), A_t.data_ptr<float>(), state.range(0)});
  TORCH_CHECK(at::allclose(at::nan_to_num(B_t), at::nan_to_num(B_ref)));
  for (auto _ : state) {
    cg.call({B_t.data_ptr<float>(), A_t.data_ptr<float>(), state.range(0)});
  }
  state.counters["logit/s"] = benchmark::Counter(
      uint64_t(state.range(0) * state.iterations()),
      benchmark::Counter::kIsRate);
}

static void logit_nnc_vml(benchmark::State& state) {
  auto N = VarHandle("N", kInt);
  BufHandle A("A", {N}, kFloat);
  auto clamp = 1e-6f;
  tensorexpr::Tensor B = Compute("B", {N}, [&](const VarHandle& i) {
    auto A_elem = [&]() {
      auto elem = A.load(i);
      auto min = FloatImm::make(clamp);
      auto max = FloatImm::make(1.0f - clamp);
      elem = CompareSelect::make(elem, min, min, elem, kLT);
      return CompareSelect::make(elem, max, max, elem, kGT);
    }();
    return log_vml(A_elem / (FloatImm::make(1.0f) - A_elem));
  });
  LoopNest ln({B});
  ln.prepareForCodegen();
  vectorize(&ln, B, 16);
  StmtPtr s = ln.root_stmt();
  s = torch::jit::tensorexpr::IRSimplifier::simplify(s);
  std::vector<CodeGen::BufferArg> args;
  args.emplace_back(B);
  args.emplace_back(A);
  args.emplace_back(N);
  LLVMCodeGen cg(s, args);
  at::Tensor A_t = torch::abs(torch::randn({state.range(0)}));
  at::Tensor B_t = torch::randn({state.range(0)});
  auto B_ref = at::logit(A_t, clamp);
  cg.call({B_t.data_ptr<float>(), A_t.data_ptr<float>(), state.range(0)});
  TORCH_CHECK(at::allclose(at::nan_to_num(B_t), at::nan_to_num(B_ref)));
  for (auto _ : state) {
    cg.call({B_t.data_ptr<float>(), A_t.data_ptr<float>(), state.range(0)});
  }
  state.counters["logit/s"] = benchmark::Counter(
      uint64_t(state.range(0) * state.iterations()),
      benchmark::Counter::kIsRate);
}

static void logit_aten(benchmark::State& state) {
  at::Tensor A_t = torch::abs(torch::randn({state.range(0)}));
  at::Tensor B_t = torch::randn({state.range(0)});
  auto clamp = 1e-6f;
  for (auto _ : state) {
    at::native::logit_out(A_t, clamp, B_t);
  }
  state.counters["logit/s"] = benchmark::Counter(
      uint64_t(state.range(0) * state.iterations()),
      benchmark::Counter::kIsRate);
}

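// Eigen-based reference mirroring Caffe2's logit operator: clamp to
// [eps, 1 - eps], then compute log(y / (1 - y)).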
template <typename T>
void logit_caffe2_impl(int size, const T* X, T* Y, float eps_ = 1e-6f) {
  using namespace caffe2;
  ConstEigenVectorMap<T> X_vec(X, size);
  EigenVectorMap<T> Y_vec(Y, size);
  Y_vec = X_vec.array().min(static_cast<T>(1.0f - eps_));
  Y_vec = Y_vec.array().max(eps_);
  Y_vec = (Y_vec.array() / (T(1) - Y_vec.array())).log();
}

static void logit_caffe2(benchmark::State& state) {
  at::Tensor A_t = torch::abs(torch::randn({state.range(0)}));
  at::Tensor B_t = torch::randn({state.range(0)});
  at::Tensor B_ref = torch::randn({state.range(0)});
  auto N = state.range(0);
  auto X = A_t.data_ptr<float>();
  auto Y = B_t.data_ptr<float>();
  auto clamp = 1e-6f;
  at::native::logit_out(A_t, clamp, B_ref);
  logit_caffe2_impl(N, X, Y, clamp);
  TORCH_CHECK(at::allclose(at::nan_to_num(B_t), at::nan_to_num(B_ref)));

  for (auto _ : state) {
    logit_caffe2_impl(N, X, Y, clamp);
  }

  state.counters["logit/s"] = benchmark::Counter(
      uint64_t(state.range(0) * state.iterations()),
      benchmark::Counter::kIsRate);
}

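// fast_tanh is NNC's approximate tanh intrinsic; the correctness check uses
// looser tolerances (rtol = 1e-3, atol = 1e-6) to absorb approximation error.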
static void tanh_nnc_fast(benchmark::State& state) {
  auto N = VarHandle("N", kInt);
  BufHandle A("A", {N}, kFloat);
  torch::jit::tensorexpr::Tensor B = Compute(
      "B", {N}, [&](const VarHandle& i) { return fast_tanh(A.load(i)); });
  LoopNest ln({B});
  optimizePointwise(&ln, B);
  ln.prepareForCodegen();
  StmtPtr s = ln.root_stmt();
  s = torch::jit::tensorexpr::IRSimplifier::simplify(s);
  std::vector<CodeGen::BufferArg> args;
  args.emplace_back(B);
  args.emplace_back(A);
  args.emplace_back(N);
  LLVMCodeGen cg(s, args);
  at::Tensor A_t = torch::abs(torch::randn({state.range(0)}));
  at::Tensor B_t = torch::randn({state.range(0)});
  auto B_ref = at::tanh(A_t);
  cg.call({B_t.data_ptr<float>(), A_t.data_ptr<float>(), state.range(0)});
  TORCH_CHECK(at::allclose(B_t, B_ref, 1e-3f, 1e-6f));
  for (auto _ : state) {
    cg.call({B_t.data_ptr<float>(), A_t.data_ptr<float>(), state.range(0)});
  }
  state.counters["tanh/s"] = benchmark::Counter(
      uint64_t(state.range(0) * state.iterations()),
      benchmark::Counter::kIsRate);
}

static void tanh_aten(benchmark::State& state) {
  at::Tensor A_t = torch::abs(torch::randn({state.range(0)}));
  at::Tensor B_t = torch::randn({state.range(0)});
  for (auto _ : state) {
    at::tanh_out(B_t, A_t);
  }
  state.counters["tanh/s"] = benchmark::Counter(
      uint64_t(state.range(0) * state.iterations()),
      benchmark::Counter::kIsRate);
}

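// Caffe2 baseline: TanhFunctor running on a CPUContext, validated against
// at::tanh_out with the same loosened tolerances as the NNC variant.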
static void tanh_caffe2(benchmark::State& state) {
  at::Tensor A_t = torch::abs(torch::randn({state.range(0)}));
  at::Tensor B_t = torch::randn({state.range(0)});
  at::Tensor B_ref = torch::randn({state.range(0)});

  auto N = state.range(0);
  auto X = A_t.data_ptr<float>();
  auto Y = B_t.data_ptr<float>();
  caffe2::CPUContext c;
  auto tanh = caffe2::TanhFunctor<caffe2::CPUContext>();
  at::tanh_out(B_ref, A_t);
  tanh(N, X, Y, &c);
  TORCH_CHECK(at::native::allclose(B_t, B_ref, 1e-3f, 1e-6f));

  for (auto _ : state) {
    tanh(N, X, Y, &c);
  }
  state.counters["tanh/s"] = benchmark::Counter(
      uint64_t(state.range(0) * state.iterations()),
      benchmark::Counter::kIsRate);
}

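// Register each benchmark at sizes 2 << 5 through 2 << 14 (64 to 32768
// elements).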
BENCHMARK(relu_nnc)->Args({2 << 5})->Args({2 << 8})->Args({2 << 12})->Args(
    {2 << 14});
BENCHMARK(log_nnc_sleef)
    ->Args({2 << 5})
    ->Args({2 << 8})
    ->Args({2 << 12})
    ->Args({2 << 14});
BENCHMARK(log_nnc_fast)
    ->Args({2 << 5})
    ->Args({2 << 8})
    ->Args({2 << 12})
    ->Args({2 << 14});
BENCHMARK(log_nnc_vml)
    ->Args({2 << 5})
    ->Args({2 << 8})
    ->Args({2 << 12})
    ->Args({2 << 14});
BENCHMARK(log_aten)->Args({2 << 5})->Args({2 << 8})->Args({2 << 12})->Args(
    {2 << 14});
BENCHMARK(logit_nnc_sleef)
    ->Args({2 << 5})
    ->Args({2 << 8})
    ->Args({2 << 12})
    ->Args({2 << 14});
BENCHMARK(logit_nnc_fast)
    ->Args({2 << 5})
    ->Args({2 << 8})
    ->Args({2 << 12})
    ->Args({2 << 14});
BENCHMARK(logit_nnc_vml)
    ->Args({2 << 5})
    ->Args({2 << 8})
    ->Args({2 << 12})
    ->Args({2 << 14});
BENCHMARK(logit_aten)
    ->Args({2 << 5})
    ->Args({2 << 8})
    ->Args({2 << 12})
    ->Args({2 << 14});
BENCHMARK(logit_caffe2)
    ->Args({2 << 5})
    ->Args({2 << 8})
    ->Args({2 << 12})
    ->Args({2 << 14});
BENCHMARK(tanh_nnc_fast)
    ->Args({2 << 5})
    ->Args({2 << 8})
    ->Args({2 << 12})
    ->Args({2 << 14});
BENCHMARK(tanh_aten)->Args({2 << 5})->Args({2 << 8})->Args({2 << 12})->Args(
    {2 << 14});
BENCHMARK(tanh_caffe2)
    ->Args({2 << 5})
    ->Args({2 << 8})
    ->Args({2 << 12})
    ->Args({2 << 14});