mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Summary: We also fix any existing issues. Note that we only do this for the CPU build because nvcc is considered a C++ toolchain but it does not have the same flag support. Adding flags to the GPU build will cause nvcc errors. Test Plan: Built locally, rely on CI to confirm. Reviewers: malfet Subscribers: Tasks: Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/79154 Approved by: https://github.com/seemethere, https://github.com/osalpekar, https://github.com/albanD
543 lines
22 KiB
C++
543 lines
22 KiB
C++
// *** Tensor Expressions ***
|
|
//
|
|
// This tutorial covers basics of NNC's tensor expressions, shows basic APIs to
|
|
// work with them, and outlines how they are used in the overall TorchScript
|
|
// compilation pipeline. This doc is permanently a "work in progress" since NNC
|
|
// is under active development and things change fast.
|
|
//
|
|
// This tutorial's code is compiled in the standard PyTorch build, and the
|
|
// executable can be found in `build/bin/tutorial_tensorexpr`.
|
|
//
|
|
// *** What is NNC ***
|
|
//
|
|
// NNC stands for Neural Net Compiler. It is a component of TorchScript JIT
|
|
// and it performs on-the-fly code generation for kernels, which are often a
|
|
// combination of multiple aten (torch) operators.
|
|
//
|
|
// When the JIT interpreter executes a torchscript model, it automatically
|
|
// extracts subgraphs from the torchscript IR graph for which specialized code
|
|
// can be JIT generated. This usually improves performance as the 'combined'
|
|
// kernel created from the subgraph could avoid unnecessary memory traffic that
|
|
// is unavoidable when the subgraph is interpreted as-is, operator by operator.
|
|
// This optimization is often referred to as 'fusion'. Relatedly, the process of
|
|
// finding and extracting subgraphs suitable for NNC code generation is done by
|
|
// a JIT pass called 'fuser'.
|
|
//
|
|
// *** What is TE ***
|
|
//
|
|
// TE stands for Tensor Expressions. TE is a commonly used approach for
|
|
// compiling kernels performing tensor (~matrix) computation. The idea behind it
|
|
// is that operators are represented as a mathematical formula describing what
|
|
// computation they do (as TEs) and then the TE engine can perform mathematical
|
|
// simplification and other optimizations using those formulas and eventually
|
|
// generate executable code that would produce the same results as the original
|
|
// sequence of operators, but more efficiently.
|
|
//
|
|
// NNC's design and implementation of TE was heavily inspired by Halide and TVM
|
|
// projects.
|
|
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

#include <c10/util/irange.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/tensorexpr/eval.h>
#include <torch/csrc/jit/tensorexpr/expr.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/stmt.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
#include <torch/torch.h>
|
|
|
|
using namespace torch::jit::tensorexpr;
|
|
|
|
#ifdef TORCH_ENABLE_LLVM
|
|
|
|
// Helper function to print a snippet from a big multi-line string
|
|
static void printLinesToFrom(const std::string& input_str, int from, int to);
|
|
|
|
#endif
|
|
|
|
int main(int argc, char* argv[]) {
|
|
std::cout << "*** Structure of tensor expressions and statements ***"
|
|
<< std::endl;
|
|
{
|
|
// A tensor expression is a tree of expressions. Each expression has a type,
|
|
// and that type defines what sub-expressions the current expression has.
|
|
// For instance, an expression of type 'Mul' would have a type 'kMul' and
|
|
// two subexpressions: LHS and RHS. Each of these two sub-expressions could
|
|
// also be a 'Mul' or some other expression.
|
|
//
|
|
// Let's construct a simple TE:
|
|
ExprPtr lhs = alloc<IntImm>(5);
|
|
ExprPtr rhs = alloc<Var>("x", kInt);
|
|
ExprPtr mul = alloc<Mul>(lhs, rhs);
|
|
std::cout << "Tensor expression: " << *mul << std::endl;
|
|
// Prints: Tensor expression: 5 * x
|
|
|
|
// Here we created an expression representing a 5*x computation, where x is
|
|
// an int variable.
|
|
|
|
// Another, probably a more convenient, way to construct tensor expressions
|
|
// is to use so called expression handles (as opposed to raw expressions
|
|
// like we did in the previous example). Expression handles overload common
|
|
// operations and allow us to express the same semantics in a more natural
|
|
// way:
|
|
ExprHandle l = 5;
|
|
ExprHandle r = Var::make("x", kInt);
|
|
ExprHandle m = l * r;
|
|
std::cout << "Tensor expression: " << *m.node() << std::endl;
|
|
// Prints: Tensor expression: 5 * x
|
|
|
|
// Converting from handles to raw expressions and back is easy:
|
|
ExprHandle handle = Var::make("x", kInt);
|
|
ExprPtr raw_expr_from_handle = handle.node();
|
|
ExprPtr raw_expr = alloc<Var>("x", kInt);
|
|
ExprHandle handle_from_raw_expr = ExprHandle(raw_expr);
|
|
|
|
// We could construct arbitrarily complex expressions using mathematical
|
|
// and logical operations, casts between various data types, and a bunch of
|
|
// intrinsics.
|
|
ExprHandle a = Var::make("a", kInt);
|
|
ExprHandle b = Var::make("b", kFloat);
|
|
ExprHandle c = Var::make("c", kFloat);
|
|
ExprHandle x = ExprHandle(5) * a + b / (sigmoid(c) - 3.0f);
|
|
std::cout << "Tensor expression: " << *x.node() << std::endl;
|
|
// Prints: Tensor expression: float(5 * a) + b / ((sigmoid(c)) - 3.f)
|
|
|
|
// An ultimate purpose of tensor expressions is to optimize tensor
|
|
// computations, and in order to represent accesses to tensors data, there
|
|
// is a special kind of expression - a load.
|
|
// To construct a load we need two pieces: the base and the indices. The
|
|
// base of a load is a Buf expression, which could be thought of as a
|
|
// placeholder similar to Var, but with dimensions info.
|
|
//
|
|
// Let's construct a simple load:
|
|
BufHandle A("A", {64, 32}, kInt);
|
|
VarPtr i_var = alloc<Var>("i", kInt), j_var = alloc<Var>("j", kInt);
|
|
ExprHandle i(i_var), j(j_var);
|
|
ExprHandle load = Load::make(A.dtype(), A, {i, j});
|
|
std::cout << "Tensor expression: " << *load.node() << std::endl;
|
|
// Prints: Tensor expression: A[i, j]
|
|
|
|
// Tensor Expressions constitute Tensor Statements, which are used to
|
|
// represent computation of a given operator or a group of operators from a
|
|
// fusion group.
|
|
//
|
|
// There are three main kinds of tensor statements:
|
|
// - block
|
|
// - store
|
|
// - loop
|
|
//
|
|
// A Store represents a store to a single element of a tensor (or to a
|
|
// group of elements if it's a vectorized store). Store statements,
|
|
// similarly to Load expressions, have a base and indices, but on top of
|
|
// that they also include a value - an expression representing what needs
|
|
// to be stored at the given memory location. Let's create a Store stmt:
|
|
StmtPtr store_a = Store::make(A, {i, j}, i + j);
|
|
std::cout << "Store statement: " << *store_a << std::endl;
|
|
// Prints: Store statement: A[i, j] = i + j;
|
|
|
|
// An operator fills the entire tensor, not just a single element, and to
|
|
// represent this we need to use For stmt: let's wrap our store stmt with
|
|
// two nested loops to represent that variables i and j need to iterate
|
|
// over some ranges.
|
|
ForPtr loop_j_a = For::make(VarHandle(j_var), 0, 32, store_a);
|
|
ForPtr loop_i_a = For::make(VarHandle(i_var), 0, 64, loop_j_a);
|
|
|
|
std::cout << "Nested for loops: " << std::endl << *loop_i_a << std::endl;
|
|
// Prints:
|
|
// Nested for loops:
|
|
// for (const auto i : c10::irange(64)) {
|
|
// for (const auto j : c10::irange(32)) {
|
|
// A[i, j] = i + j;
|
|
// }
|
|
// }
|
|
|
|
// A Block statement is used when we need a sequence of other statements.
|
|
// E.g. if a fusion group contains several operators, we initially define
|
|
// separate loopnest for each of them and put them all into a common block:
|
|
BufHandle B("B", {64, 32}, kInt);
|
|
StmtPtr store_b = Store::make(B, {i, j}, A.load(i, j));
|
|
ForPtr loop_j_b = For::make(VarHandle(j_var), 0, 32, store_b);
|
|
ForPtr loop_i_b = For::make(VarHandle(i_var), 0, 64, loop_j_b);
|
|
|
|
BlockPtr block = Block::make({loop_i_a, loop_i_b});
|
|
std::cout << "Compound Block statement: " << std::endl
|
|
<< *block << std::endl;
|
|
// Prints:
|
|
// Compound Block statement:
|
|
// {
|
|
// for (const auto i : c10::irange(64)) {
|
|
// for (const auto j : c10::irange(32)) {
|
|
// A[i, j] = i + j;
|
|
// }
|
|
// }
|
|
// for (const auto i : c10::irange(64)) {
|
|
// for (const auto j : c10::irange(32)) {
|
|
// B[i, j] = A[i, j];
|
|
// }
|
|
// }
|
|
// }
|
|
|
|
// Manually constructing nested loops and blocks to represent a computation
|
|
// might be laborious, and instead we can use a 'Compute' API. This API
|
|
// requires us to specify dimensions and a lambda to compute a single
|
|
// element of the resulting tensor and returns a `Tensor` structure. This
|
|
// structure is simply a pair of a buffer that was created to represent the
|
|
// result of the computation (BufPtr) and a statement representing the
|
|
// computation itself (StmtPtr).
|
|
Tensor C =
|
|
Compute("C", {64, 32}, [&](const VarHandle& i, const VarHandle& j) {
|
|
return i * j;
|
|
});
|
|
std::cout << "Stmt produced by 'Compute' API: " << std::endl
|
|
<< *C.stmt() << std::endl;
|
|
// Prints:
|
|
// Stmt produced by 'Compute' API:
|
|
// for (const auto i : c10::irange(64)) {
|
|
// for (const auto j : c10::irange(32)) {
|
|
// C[i, j] = i * j;
|
|
// }
|
|
// }
|
|
|
|
// To construct statements to represent computations with reductions, we
|
|
// can use a 'Reduce' API - it is similar to 'Compute' but takes a couple
|
|
// of extra arguments defining how to perform the reduction. Let's define a
|
|
// simple 2D sum of C using that:
|
|
Tensor D = Reduce(
|
|
"D",
|
|
{},
|
|
Sum(),
|
|
[&](const VarHandle& i, const VarHandle& j) { return C.load(i, j); },
|
|
{64, 32});
|
|
std::cout << "Stmt produced by 'Reduce' API: " << std::endl
|
|
<< *D.stmt() << std::endl;
|
|
}
|
|
|
|
std::cout << "*** Loopnests transformations ***" << std::endl;
|
|
{
|
|
// When a statement for the computation is generated, we might want to
|
|
// apply some optimizations to it. These transformations allow us to end up
|
|
// with a statement producing the same results, but more efficiently.
|
|
//
|
|
// Let's look at a couple of transformations that are used in NNC. We will
|
|
// begin with constructing a Block statement like we did before.
|
|
|
|
Tensor C =
|
|
Compute("C", {64, 32}, [&](const VarHandle& i, const VarHandle& j) {
|
|
return i * (j + 1);
|
|
});
|
|
BufHandle c_buf(C.buf());
|
|
Tensor D =
|
|
Compute("D", {64, 32}, [&](const VarHandle& i, const VarHandle& j) {
|
|
return c_buf.load(i, j) - i;
|
|
});
|
|
StmtPtr block = Block::make({C.stmt(), D.stmt()});
|
|
std::cout << "Stmt produced by 'Compute' API: " << std::endl
|
|
<< *block << std::endl;
|
|
// Prints:
|
|
// Stmt produced by 'Compute' API:
|
|
// {
|
|
// for (const auto i : c10::irange(64)) {
|
|
// for (const auto j : c10::irange(32)) {
|
|
// C[i, j] = i * (j + 1);
|
|
// }
|
|
// }
|
|
// for (const auto i_1 : c10::irange(64)) {
|
|
// for (const auto j_1 : c10::irange(32)) {
|
|
// D[i_1, j_1] = (C[i_1, j_1]) - i_1;
|
|
// }
|
|
// }
|
|
// }
|
|
|
|
// One transformation we can apply to this computation is inlining: i.e.
|
|
// taking the expression that defines values of C and substituting a load
|
|
// from C with it.
|
|
// To do that, we first need to create a special object called LoopNest -
|
|
// all transformations are methods of this class. To create a loopnest we
|
|
// need to provide a list of output buffers and the root statement:
|
|
LoopNest nest(block, {D.buf()});
|
|
|
|
// We can always retrieve the Stmt back from LoopNest:
|
|
std::cout << "LoopNest root stmt: " << std::endl
|
|
<< *nest.root_stmt() << std::endl;
|
|
// Prints:
|
|
// LoopNest root stmt:
|
|
// {
|
|
// for (const auto i : c10::irange(64)) {
|
|
// for (const auto j : c10::irange(32)) {
|
|
// C[i, j] = i * (j + 1);
|
|
// }
|
|
// }
|
|
// for (const auto i_1 : c10::irange(64)) {
|
|
// for (const auto j_1 : c10::irange(32)) {
|
|
// D[i_1, j_1] = (C[i_1, j_1]) - i_1;
|
|
// }
|
|
// }
|
|
// }
|
|
|
|
// Now we can apply the inlining transformation:
|
|
nest.computeInline(C.buf());
|
|
std::cout << "Stmt after inlining:" << std::endl
|
|
<< *nest.root_stmt() << std::endl;
|
|
// Prints:
|
|
// Stmt after inlining:
|
|
// {
|
|
// for (const auto i : c10::irange(64)) {
|
|
// for (const auto j : c10::irange(32)) {
|
|
// D[i, j] = i * (j + 1) - i;
|
|
// }
|
|
// }
|
|
// }
|
|
|
|
// We can also apply algebraic simplification to a statement:
|
|
StmtPtr simplified = IRSimplifier::simplify(nest.root_stmt());
|
|
std::cout << "Stmt after simplification:" << std::endl
|
|
<< *simplified << std::endl;
|
|
// Prints:
|
|
// Stmt after simplification:
|
|
// {
|
|
// for (const auto i : c10::irange(64)) {
|
|
// for (const auto j : c10::irange(32)) {
|
|
// D[i, j] = i * j;
|
|
// }
|
|
// }
|
|
// }
|
|
|
|
// Many loopnest transformations are stateless and can be applied without
|
|
// creating a LoopNest object. In fact, we plan to make all transformations
|
|
// stateless.
|
|
// splitWithTail is one such transformation: it splits an iteration space
|
|
// of a given loop into two with a given factor.
|
|
ForPtr outer_loop = to<For>(to<Block>(simplified)->stmts().front());
|
|
LoopNest::splitWithTail(outer_loop, 13);
|
|
// Call simplifier once more to fold some arithmetic.
|
|
simplified = IRSimplifier::simplify(simplified);
|
|
std::cout << "Stmt after splitWithTail:" << std::endl
|
|
<< *simplified << std::endl;
|
|
// Prints:
|
|
// Stmt after splitWithTail:
|
|
// {
|
|
// for (const auto i_outer : c10::irange(4)) {
|
|
// for (const auto i_inner : c10::irange(13)) {
|
|
// for (const auto j : c10::irange(32)) {
|
|
// D[i_inner + 13 * i_outer, j] = i_inner * j + 13 * (i_outer * j);
|
|
// }
|
|
// }
|
|
// }
|
|
// for (const auto i_tail : c10::irange(12)) {
|
|
// for (const auto j : c10::irange(32)) {
|
|
// D[i_tail + 52, j] = i_tail * j + 52 * j;
|
|
// }
|
|
// }
|
|
// }
|
|
|
|
// NNC supports a wide range of loop nest transformations, which we are not
|
|
// listing here. Please refer to documentation in
|
|
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/tensorexpr/loopnest.h
|
|
// for more details.
|
|
}
|
|
|
|
std::cout << "*** Codegen ***" << std::endl;
|
|
{
|
|
// An ultimate goal of tensor expressions is to be provide a mechanism to
|
|
// execute a given computation in the fastest possible way. So far we've
|
|
// looked at how we could describe what computation we're interested in, but
|
|
// we haven't looked at how to actually execute it.
|
|
//
|
|
// All we've been dealing with was just symbols with no actual data
|
|
// associated, in this section we would look at how we can bridge that gap.
|
|
|
|
// Let's start by constructing a simple computation for us to work with:
|
|
BufHandle A("A", {64, 32}, kInt);
|
|
BufHandle B("B", {64, 32}, kInt);
|
|
Tensor X =
|
|
Compute("X", {64, 32}, [&](const VarHandle& i, const VarHandle& j) {
|
|
return A.load(i, j) + B.load(i, j);
|
|
});
|
|
|
|
// And let's lower it to a loop nest, as we did in the previous section. We
|
|
// can pass Tensor object directly:
|
|
LoopNest loopnest({X});
|
|
std::cout << *loopnest.root_stmt() << std::endl;
|
|
// Prints:
|
|
// {
|
|
// for (const auto i : c10::irange(64)) {
|
|
// for (const auto j : c10::irange(32)) {
|
|
// X[i, j] = (A[i, j]) + (B[i, j]);
|
|
// }
|
|
// }
|
|
|
|
// Now imagine that we have two actual tensors 64x32 that we want sum
|
|
// together, how do we pass those tensors to the computation and how do we
|
|
// carry it out?
|
|
//
|
|
// Codegen object is aimed at providing exactly that functionality. Codegen
|
|
// is an abstract class and concrete codegens are derived from it.
|
|
// Currently, we have three codegens:
|
|
// 1) Simple Evaluator,
|
|
// 2) LLVM Codegen for CPU,
|
|
// 3) CUDA Codegen.
|
|
// In this example we will be using Simple Evaluator, since it's available
|
|
// everywhere.
|
|
|
|
// To create a codegen, we need to provide the statement - it specifies the
|
|
// computation we want to perform - and a list of placeholders and tensors
|
|
// used in the computation. The latter part is crucial since that's the only
|
|
// way the codegen could use to correlate symbols in the statement to actual
|
|
// data arrays that we will be passing when we will actually be performing
|
|
// the computation.
|
|
//
|
|
// Let's create a Simple IR Evaluator codegen for our computation:
|
|
SimpleIREvaluator ir_eval(loopnest.root_stmt(), {A, B, X});
|
|
|
|
// We are using the simplest codegen and in it almost no work is done at the
|
|
// construction step. Real codegens such as CUDA and LLVM perform
|
|
// compilation during that stage so that when we're about to run the
|
|
// computation everything is ready.
|
|
|
|
// Let's now create some inputs and run our computation with them:
|
|
std::vector<int> data_A(64 * 32, 3); // This will be the input A
|
|
std::vector<int> data_B(64 * 32, 5); // This will be the input B
|
|
std::vector<int> data_X(64 * 32, 0); // This will be used for the result
|
|
|
|
// Now let's invoke our codegen to perform the computation on our data. We
|
|
// need to provide as many arguments as how many placeholders and tensors we
|
|
// passed at the codegen construction time. A position in these lists would
|
|
// define how real data arrays from the latter call (these arguments are
|
|
// referred to as 'CallArg's in our codebase) correspond to symbols
|
|
// (placeholders and tensors) used in the tensor expressions we constructed
|
|
// (these are referred to as 'BufferArg').
|
|
// Thus, we will provide three arguments: data_A, data_B, and data_X. data_A
|
|
// contains data for the placeholder A, data_B - for the placeholder B, and
|
|
// data_X would be used for contents of tensor X.
|
|
ir_eval(data_A, data_B, data_X);
|
|
|
|
// Let's print one of the elements from each array to verify that the
|
|
// computation did happen:
|
|
std::cout << "A[10] = " << data_A[10] << std::endl
|
|
<< "B[10] = " << data_B[10] << std::endl
|
|
<< "X[10] = A[10] + B[10] = " << data_X[10] << std::endl;
|
|
// Prints:
|
|
// A[10] = 3
|
|
// B[10] = 5
|
|
// X[10] = A[10] + B[10] = 8
|
|
}
|
|
|
|
std::cout << "*** Lowering TorchScript IR to TensorExpr IR ***" << std::endl;
|
|
{
|
|
// This section requires a LLVM-enabled PyTorch build, so we have to use a
|
|
// guard:
|
|
#ifdef TORCH_ENABLE_LLVM
|
|
|
|
// Often we would like to convert a TorchScript IR to TE rather than
|
|
// construct TE IR from scratch. NNC provides an API to perform such
|
|
// lowering: it takes a TorchScript graph and returns an object that can be
|
|
// used to invoke the generated kernel.
|
|
// This API is currently used by the TorchScript JIT fuser and can also be
|
|
// used ahead of time to pre-compile parts of a model.
|
|
//
|
|
// To get familiar with this API let's first start with defining a simple
|
|
// TorchScript graph:
|
|
const auto graph_string = R"IR(
|
|
graph(%A : Float(5, 3, strides=[3, 1], device=cpu),
|
|
%B : Float(5, 3, strides=[3, 1], device=cpu)):
|
|
%AB : Float(5, 3, strides=[3, 1]) = aten::mul(%A, %B)
|
|
%one : int = prim::Constant[value=1]()
|
|
%AAB : Float(5, 3, strides=[3, 1]) = aten::mul(%A, %AB)
|
|
%AAB_plus_B: Float(5, 3, strides=[3, 1]) = aten::add(%AAB, %B, %one)
|
|
return (%AAB_plus_B))IR";
|
|
auto graph = std::make_shared<torch::jit::Graph>();
|
|
parseIR(graph_string, &*graph);
|
|
|
|
// This graph defines a simple computation of A*A*B + B where A and B are
|
|
// input 5x3 tensors.
|
|
|
|
// To lower this TorchScript graph to TE, we just need to create a
|
|
// TensorExprKernel object. In its constructor it constructs the
|
|
// corresponding TE IR and compiles it for the given backend (in this
|
|
// example for CPU using LLVM compiler).
|
|
TensorExprKernel kernel(graph);
|
|
|
|
// We can retrieve the generated TE stmt from the kernel object:
|
|
StmtPtr kernel_stmt = kernel.getCodeGenStmt();
|
|
std::cout << "TE Stmt constructed from TorchScript: " << std::endl
|
|
<< *kernel_stmt << std::endl;
|
|
// Prints:
|
|
// TE Stmt constructed from TorchScript:
|
|
// {
|
|
// for (const auto v : c10::irange(5)) {
|
|
// for (const auto _tail_tail : c10::irange(3)) {
|
|
// aten_add[_tail_tail + 3 * v] = (tA[_tail_tail + 3 * v]) *
|
|
// ((tA[_tail_tail + 3 * v]) * (tB[_tail_tail + 3 * v])) +
|
|
// (tB[_tail_tail + 3 * v]);
|
|
// }
|
|
// }
|
|
// }
|
|
|
|
// We can also examine generated LLVM IR and assembly code:
|
|
std::cout << "Generated LLVM IR: " << std::endl;
|
|
auto ir_str = kernel.getCodeText("ir");
|
|
printLinesToFrom(ir_str, 15, 20);
|
|
// Prints:
|
|
// Generated LLVM IR:
|
|
// %9 = bitcast float* %2 to <8 x float>*
|
|
// %10 = load <8 x float>, <8 x float>* %9 ...
|
|
// %11 = bitcast float* %5 to <8 x float>*
|
|
// %12 = load <8 x float>, <8 x float>* %11 ...
|
|
// %13 = fmul <8 x float> %10, %12
|
|
// %14 = fmul <8 x float> %10, %13
|
|
|
|
std::cout << "Generated assembly: " << std::endl;
|
|
auto asm_str = kernel.getCodeText("asm");
|
|
printLinesToFrom(asm_str, 10, 15);
|
|
// Prints:
|
|
// Generated assembly:
|
|
// vmulps %ymm1, %ymm0, %ymm2
|
|
// vfmadd213ps %ymm1, %ymm0, %ymm2
|
|
// vmovups %ymm2, (%rax)
|
|
// vmovss 32(%rcx), %xmm0
|
|
// vmovss 32(%rdx), %xmm1
|
|
// vmulss %xmm1, %xmm0, %xmm2
|
|
|
|
// We can also execute the generated kernel:
|
|
auto A =
|
|
at::ones({5, 3}, torch::TensorOptions(torch::kCPU).dtype(at::kFloat)) *
|
|
2.0;
|
|
auto B =
|
|
at::ones({5, 3}, torch::TensorOptions(torch::kCPU).dtype(at::kFloat)) *
|
|
3.0;
|
|
std::vector<at::Tensor> inputs = {A, B};
|
|
std::vector<torch::IValue> stack = torch::fmap<torch::IValue>(inputs);
|
|
kernel.run(stack);
|
|
auto R = stack[0].toTensor();
|
|
|
|
// Let's print one of the elements from the result tensor to verify that the
|
|
// computation did happen and was correct:
|
|
std::cout << "R[2][2] = " << R[2][2] << std::endl;
|
|
// Prints:
|
|
// R[2][2] = 15
|
|
// [ CPUFloatType{} ]
|
|
#endif
|
|
}
|
|
return 0;
|
|
}
|
|
|
|
// Prints a window of lines from the multi-line string `input_str` to
// std::cout, each followed by '\n'.
//
// Lines are counted from zero. A line is printed when its index is strictly
// greater than `from`, and reading stops right after the first line whose
// index is strictly greater than `to` has been processed. For
// 0 <= from <= to this means the zero-based lines from + 1 through to + 1 are
// printed (fewer if the input runs out of lines first).
//
// @param input_str  text to split on newlines
// @param from       lines with an index <= from are skipped
// @param to         the line with index to + 1 is the last one considered
void printLinesToFrom(const std::string& input_str, int from, int to) {
  std::istringstream lines(input_str);
  std::string line;
  int idx = 0;
  while (std::getline(lines, line)) {
    if (idx > from) {
      std::cout << line << "\n";
    }
    // The comparison uses the index of the line just handled; the increment
    // happens afterwards, so the line at index to + 1 is still emitted above.
    if (idx++ > to) {
      break;
    }
  }
}
|