#pragma once

#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include <torch/csrc/WindowsTorchApiMacro.h>

namespace torch {
namespace jit {
namespace tensorexpr {

class Expr;
class Var;
class Buf;
class Tensor;
class Function;
class Stmt;
class For;
class Block;
class Store;
class Dtype;

class TORCH_API LoopNest {
 public:
  // A constructor for building a LoopNest from a list of Tensors
  LoopNest(
      const std::vector<Tensor*>& output_tensors,
      const std::vector<Tensor*>& tensors_to_compute);

  // A convenience constructor for the case when all tensors are output
  // tensors.
  LoopNest(const std::vector<Tensor*>& output_tensors);

  // A constructor for building a LoopNest from a Stmt and a list of output
  // buffers.
  LoopNest(Stmt* stmt, std::unordered_set<const Buf*> output_bufs);

  // A constructor for building a LoopNest from another loopnest. It clones
  // the other loopnest's stmt.
  LoopNest(const LoopNest& other);

  Stmt* root_stmt() const {
    return root_stmt_;
  }
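
  // Example usage (a hedged sketch; the Tensor* `t` and how it is created
  // are assumptions, not something this header defines):
  //   LoopNest nest({t});            // `t` is computed and is an output
  //   Stmt* root = nest.root_stmt(); // the whole nest as a single Stmt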

  std::vector<For*> getLoopStmtsFor(Tensor*) const;
  std::vector<For*> getLoopStmtsFor(const Buf*) const;
  std::vector<For*> getLoopStmtsFor(Stmt*) const;
  Stmt* getLoopBodyFor(Tensor*) const;
  Stmt* getLoopBodyFor(const Buf*) const;

  // Returns the For stmt that immediately encloses the given stmt.
  static For* getParentLoop(const Stmt* st);

  // Returns the list of For stmts corresponding to the loopnest that
  // encloses the given stmt.
  static std::vector<For*> getEnclosingLoopNest(const Stmt* st);

  // Returns a list of all Stmts that write to the given buf.
  std::vector<const Stmt*> getAllWritesToBuf(const Buf*) const;

  // The following methods return the For loops that contain writes to
  // the given buf.
  //
  // For example, consider the following code:
  //   for i1
  //     for j1
  //       a[i1,j1] =
  //   for i2
  //     for j2
  //       for k2
  //         a[i2,j2] =
  //     for j3
  //       a[i2,j3] =

  // Returns a list of For loops which directly contain a Stmt that writes
  // to buf.
  // For the above example:
  //   getAllInnermostLoopsWritingToBuf(a) => {j1, k2, j3}
  std::vector<For*> getAllInnermostLoopsWritingToBuf(const Buf*) const;

  // Returns a list of For loopnests which contain a Stmt that writes to
  // the given buf. Each loopnest here is a vector of For loops.
  // For the above example:
  //   getAllLoopNestsWritingToBuf(a) => {{i1,j1}, {i2,j2,k2}, {i2,j3}}
  std::vector<std::vector<For*>> getAllLoopNestsWritingToBuf(const Buf*) const;

  Stmt* simplify();

  bool computeInline(Stmt* s);
  bool computeInline(const Buf* b);
  void inlineIntermediateBufs(bool allow_duplicated_work);
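
  // Example (a hedged sketch; `producer` and `out` are assumed Tensor*
  // handles, and `buf()` as the accessor is an assumption):
  //   LoopNest nest({out}, {producer, out});
  //   nest.computeInline(producer->buf()); // fold producer into its users
  //   nest.simplify();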

  // Optimizes conditionals.
  //
  // Currently, only the following pattern of conditionals is optimized.
  // This corresponds to the conditional format that is generated to handle
  // `aten::cat` op.
  //
  //   for (int i = 0; i < 20; i++) {
  //     A[i] = IfThenElse(i<5 ? 1 : 0, B[i], C[i-5])
  //   }
  //
  // Constraints that must be satisfied for this optimization:
  //   * All conditions should be of the form "var < expr".
  //   * All conditions should have the same variable, say v.
  //   * The condition variable found should be the same as the inner-most
  //     loop variable. TODO: Remove this constraint.
  //   * If there are multiple stores that contain conditionals using the
  //     same loop variable, only the first conditional will be optimized.
  //     TODO: Remove this constraint.
  bool optimizeConditionals();

  // Splits the given loop into 2 nested loops with the given factor as the
  // inner loop bound. If the factor does not evenly divide the loop bound,
  // then the remaining iterations are extracted into a tail loop that is
  // added after the given loop.
  //
  // For example, consider the following code:
  //   for (int i = 0; i < 100; ++i) {
  //     A[i] =
  //   }
  //
  // splitWithTail(i, 8, ...) will result in:
  //   for (int i_outer = 0; i_outer < 12; ++i_outer) {
  //     for (int i_inner = 0; i_inner < 8; ++i_inner) {
  //       A[i_outer * 8 + i_inner] =
  //     }
  //   }
  //   for (int i_tail = 0; i_tail < 4; ++i_tail) {
  //     A[i_tail + 96] =
  //   }
  //
  // The given loop will be transformed to the outer loop after splitting.
  // So, the pointer to the input loop should be valid after splitting and
  // will point to the outer loop. The `inner` and `tail` parameters will be
  // set to point to the inner and tail loops that are generated.
  static void splitWithTail(For* f, int factor, For** inner, For** tail);
  // A convenience wrapper when the caller does not need to access the
  // split loops.
  static void splitWithTail(For* f, int factor);
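
  // Example call sequence (a hedged sketch; `nest` and `t` are assumed
  // handles):
  //   std::vector<For*> loops = nest.getLoopStmtsFor(t);
  //   For* inner;
  //   For* tail;
  //   LoopNest::splitWithTail(loops[0], 8, &inner, &tail);
  //   // loops[0] now points to the outer loop of the split.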

  // Splits the given loop into 2 nested loops with the given factor as the
  // inner loop bound. If the factor does not evenly divide the loop bound,
  // then a conditional is inserted into the body to handle the remaining
  // iterations appropriately.
  //
  // For example, consider the following code:
  //   for (int i = 0; i < 100; ++i) {
  //     A[i] =
  //   }
  //
  // splitWithMask(i, 8, ...) will result in:
  //   for (int i_outer = 0; i_outer < 13; ++i_outer) {
  //     for (int i_inner = 0; i_inner < 8; ++i_inner) {
  //       if (i_outer * 8 + i_inner < 100) {
  //         A[i_outer * 8 + i_inner] =
  //       }
  //     }
  //   }
  //
  // The given loop will be transformed to the outer loop after splitting.
  // So, the pointer to the input loop should be valid after splitting and
  // will point to the outer loop. The `inner` parameter will be set to point
  // to the inner loop that is generated.
  static void splitWithMask(For* f, int factor, For** inner);
  // A convenience wrapper when the caller does not need to access the
  // split loops.
  static void splitWithMask(For* f, int factor);

  // The following methods support loop distribution.
  // For example, consider the following code. This will be used to
  // demonstrate the methods below.
  //
  //   S1: for i
  //   S2:   A[i] = 0
  //   S3:   for j
  //   S4:     A[i] = A[i] +
  //   S5:   B[i] = A[i]
  //   S6:   for k
  //   S7:     B[i] = B[i] +

  // This method distributes the given loop over its body by splitting
  // after every given pivot stmt.
  //
  // NOTE: Pivot stmts that are not in the given loop's body will be ignored.
  //
  // For the above example:
  //   distributeLoop(S1, {S3, S5})
  // will result in:
  //   S1: for i
  //   S2:   A[i] = 0
  //   S3:   for j
  //   S4:     A[i] = A[i] +
  //     : for i
  //   S5:   B[i] = A[i]
  //     : for i
  //   S6:   for k
  //   S7:     B[i] = B[i] +
  static std::vector<For*> distributeLoop(
      For* loop,
      const std::unordered_set<Stmt*>& pivots);

  // This method distributes the given loop over every stmt in its body.
  //
  // For the above example:
  //   distributeLoop(S1)
  // will result in:
  //   S1: for i
  //   S2:   A[i] = 0
  //     : for i
  //   S3:   for j
  //   S4:     A[i] = A[i] +
  //     : for i
  //   S5:   B[i] = A[i]
  //     : for i
  //   S6:   for k
  //   S7:     B[i] = B[i] +
  static std::vector<For*> distributeLoop(For* loop);

  // This method distributes the given loop over its body by splitting
  // after every For stmt in its body.
  //
  // For the above example:
  //   distributeLoopOverInnerLoops(S1)
  // will result in:
  //   S1: for i
  //   S2:   A[i] = 0
  //   S3:   for j
  //   S4:     A[i] = A[i] +
  //     : for i
  //   S5:   B[i] = A[i]
  //   S6:   for k
  //   S7:     B[i] = B[i] +
  static std::vector<For*> distributeLoopOverInnerLoops(For* loop);
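
  // Example (a hedged sketch; `loop` is assumed to be the `for i` above):
  //   std::vector<For*> new_loops = LoopNest::distributeLoop(loop);
  //   // One resulting `for i` loop per top-level stmt of the old body.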

  // This method performs loop fusion.
  // For example, consider the following code.
  //
  //   S1: for m
  //   S2:   A[m] = 0
  //   S3:   for j
  //   S4:     A[m] = A[m] +
  //   S5: for n
  //   S6:   B[n] = A[n]
  //   S7:   for k
  //   S8:     B[n] = B[n] +
  //
  // fuseLoops({S1, S5}), will return the following loop:
  //   S1: for m
  //   S2:   A[m] = 0
  //   S3:   for j
  //   S4:     A[m] = A[m] +
  //   S6:   B[m] = A[m]
  //   S7:   for k
  //   S8:     B[m] = B[m] +
  //
  // Loop fusion is done only when all the conditions below are satisfied.
  //   * All the loops have the same parent.
  //   * There are no statements between these loops in their parent body.
  //   * The start bounds are the same for all loops.
  //   * The stop bounds are the same for all loops.
  //   * Fusing the loops does not violate or add any dependencies.
  static bool fuseLoops(const std::vector<For*>& loops, For** fused);
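
  // Example (a hedged sketch; `loop_m` and `loop_n` stand for the two
  // adjacent loops above):
  //   For* fused = nullptr;
  //   if (LoopNest::fuseLoops({loop_m, loop_n}, &fused)) {
  //     // `fused` now points to the single merged loop.
  //   }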

  // Reorders the two given loops, exchanging their positions in the
  // loopnest.
  void reorderAxis(For* a, For* b);

  // Reorder the given list of loops according to the permutation specified.
  // Here permutation[i] represents the location of the loop i in the result.
  //
  // For example, consider the following code:
  //   for p
  //     for q
  //       for r
  //         for s
  //           A[p,q,r,s] =
  //
  // reorder({p, q, r, s}, {2, 3, 0, 1}) will return the list of loops in the
  // following form:
  //   for r
  //     for s
  //       for p
  //         for q
  //           A[p,q,r,s] =
  static std::vector<For*> reorder(
      const std::vector<For*>& loops,
      const std::vector<size_t>& permutation);
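
  // Example (a hedged sketch; `loops` is assumed to be the perfectly nested
  // {p, q, r, s} above):
  //   std::vector<For*> reordered = LoopNest::reorder(loops, {2, 3, 0, 1});
  //   // reordered[0] is the new outermost loop (`for r`).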

  // Returns true if the given loops are perfectly nested, i.e., every loop
  // (except the innermost) should have exactly one statement in its body
  // and that statement must be the next inner loop.
  static bool areLoopsPerfectlyNested(const std::vector<For*>& loops);

  // Returns true if the given loop has a loop-carried dependence.
  static bool hasLoopCarriedDependence(For* loop);

  // Unrolls the given loop, replacing it with its fully expanded body.
  // `unrolled` is set to point to the resulting statement.
  static void unroll(For* f, Stmt** unrolled);
  static void unroll(For* f);

  // Normalizes the given loop so that its iteration starts at 0.
  // Returns true if the loop was transformed.
  static bool normalize(For* f);
  static bool isNormalized(For* f);

  // Flattens the given perfectly nested loops into a single loop over the
  // combined iteration space. `flattened` is set to the resulting loop.
  // Returns true if flattening succeeded.
  static bool flatten(const std::vector<For*>& f, For** flattened);
  static bool flatten(const std::vector<For*>& f);
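
  // Example (a hedged sketch; `loop` is an assumed For* handle):
  //   LoopNest::normalize(loop);  // ensure the loop starts at 0
  //   Stmt* unrolled = nullptr;
  //   LoopNest::unroll(loop, &unrolled);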

  // Compresses the given buffer based on its use in the given Stmts.
  //
  // For example, given the input:
  //
  //   for (int i = 0; i < 100; ++i) {
  //     for (int j = 0; j < 200; ++j) {
  //       A[i,j] = sin(i*j)
  //     }
  //     for (int j = 0; j < 199; ++j) {
  //       B[i,j] = A[i,j] + A[i, j+1]
  //     }
  //   }
  //
  // compressBuffer(A, ...) will compress buffer A from
  // [100, 200] to [1, 200] and modify the code as follows:
  //
  //   for (int i = 0; i < 100; ++i) {
  //     for (int j = 0; j < 200; ++j) {
  //       A[0,j] = sin(i*j)
  //     }
  //     for (int j = 0; j < 199; ++j) {
  //       B[i,j] = A[0,j] + A[0, j+1]
  //     }
  //   }
  static void compressBuffer(Buf* buf, Stmt* stmt);

  // Get 'num' loops from the loopnest starting at 'f'.
  static std::vector<For*> getLoopStmtsInLoopNest(For* f, size_t num);

  // Slices off the first `factor` iterations of the given loop into a
  // separate "head" loop. LoopOptions are propagated to tail.
  void sliceHead(For* f, int factor, For** head, For** tail);
  void sliceHead(For* f, int factor);
  // Slices off the last `factor` iterations of the given loop into a
  // separate "tail" loop. LoopOptions are propagated to head.
  void sliceTail(For* f, int factor, For** head, For** tail);
  void sliceTail(For* f, int factor);

  // Bind the given loop to the GPU block / thread axis with index `idx`.
  void setGPUBlockIndex(For* f, int idx);
  void setGPUThreadIndex(For* f, int idx);

  using AccessResult = std::pair<const Buf*, Stmt*>;
  // Insert a cache for the consumer's usages of the buffer produced in
  // consumer, and redirect reads and writes in the consumer to that cache.
  // Returns a pair of the new cache buffer, and the new rewritten consumer.
  AccessResult cacheAccesses(
      const Buf* producer,
      const std::string& name,
      Stmt* consumer);
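
  // Example (a hedged sketch; `a_buf` and `consumer_stmt` are assumed
  // handles):
  //   LoopNest::AccessResult res =
  //       nest.cacheAccesses(a_buf, "A_local", consumer_stmt);
  //   const Buf* cache = res.first;  // the new cache buffer
  //   Stmt* rewritten = res.second;  // consumer with accesses redirected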

  // Insert a temporary computation of statement S in the scope of loop AT.
  // S is assumed to be a Store or a Block containing a Store. Along with the
  // computation itself, this transformation inserts Alloc/Free statements
  // for the temporary buffer used in the computation.
  void computeAt(Stmt* s, For* at);

  // Rfactor a reduction axis into a normal axis.
  //
  // Requirements:
  //   * S is the reduction store
  //   * S is the only statement in the innermost loop
  //   * There are at least two reduction arguments in S
  //   * OUTER_REDUCTION_FOR loop corresponds to the outermost reduction
  //     variable used in the store and all other reduction variables are
  //     index variables of children loops of OUTER_REDUCTION_FOR
  //   * OUTER_REDUCTION_FOR is a perfect loop nest, i.e. it has only loops
  //     corresponding to the other reduction variables and the store, nested
  //     into each other
  //
  // What it does:
  //   * Introduce a new buffer with an extra dimension of a size equal to
  //     the span of the loop OUTER_REDUCTION_FOR (the new buffer is returned
  //     via RFAC_BUF_PTR)
  //   * Insert an initialization store for the new buffer in
  //     OUTER_REDUCTION_FOR before its nested loop
  //   * Replace the reduction store to the original buffer with the
  //     reduction store to the temp buffer, removing the index var of
  //     OUTER_REDUCTION_FOR from reduction arguments
  //   * Insert a final reduction store over the extra dimension of the new
  //     buffer to the original buffer
  //   * Returns TRUE if the transformation succeeded and FALSE otherwise
  //
  // Example:
  // Original IR:
  //   S1: for i      # normal axis
  //   S2:   X[i] = 0
  //   S3:   for j    # reduction axis
  //   S4:     for k  # reduction axis
  //   S5:       X[i] = ReduceOp(X[i] + Y[i,j,k], reduce_axis={j,k})
  //
  // After RFACTOR(S5, S3):
  //   S1: for i      # normal axis
  //   S2:   X[i] = 0
  //   S3:   for j    # reduction axis for X, normal axis for X_rfac
  //           X_rfac[i,j] = 0
  //   S4:     for k  # reduction axis
  //             X_rfac[i,j] = ReduceOp(X_rfac[i,j] + Y[i,j,k], reduce_axis={k})
  //           X[i] = ReduceOp(X[i] + X_rfac[i,j], reduce_axis={j})
  bool rfactor(Stmt* s, For* outer_reduction_for);
  bool rfactor(Stmt* s, For* outer_reduction_for, Buf** rfac_buf_ptr);
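
  // Example (a hedged sketch; `reduce_store` and `j_loop` correspond to S5
  // and S3 above):
  //   Buf* rfac_buf = nullptr;
  //   if (nest.rfactor(reduce_store, j_loop, &rfac_buf)) {
  //     // `rfac_buf` is the new X_rfac buffer.
  //   }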

  // Vectorize the given loop. This method requires that the given loop
  // does not perform a reduction.
  // It returns true if vectorization is successful and false otherwise.
  static bool vectorize(For*);

  // Find the inner-most loops and vectorize them. Currently, this only works
  // for the LLVM backend, when no reductions are involved.
  void vectorizeInnerLoops();
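
  // Example (a hedged sketch; `loop` is an assumed For* handle): split to a
  // fixed width, then vectorize the resulting inner loop.
  //   For* inner = nullptr;
  //   LoopNest::splitWithMask(loop, 8, &inner);
  //   bool vectorized = LoopNest::vectorize(inner);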

  void setBufferMap(
      For* f,
      const std::unordered_map<std::string, const Buf*>& map);

  void eliminateDeadStores();
  void prepareForCodegen();

  const std::unordered_set<const Buf*> getInputBufs() const;
  const std::unordered_set<const Buf*> getOutputBufs() const {
    return output_bufs_;
  }

 private:
  void initialize(
      const std::vector<Tensor*>& output_tensors,
      const std::vector<Tensor*>& tensors_to_compute);
  Stmt* insertAllocFree(Stmt* stmt);
  const std::unordered_set<const Buf*> getIntermediateBufs() const;

  Stmt* root_stmt_;

  std::unordered_set<const Buf*> output_bufs_;
};

TORCH_API Stmt* FlattenIndexes(Stmt* s);

// TODO: Revisit this once we decide on how dependencies analysis should look
// like. Maybe we would choose to use a different API and BufUse would be
// removed, or if we decide to keep it we need to properly document its API.
struct BufLoadOrStoreUse {
  Stmt* s;
  bool isStore;
};

/*
 * Returns a map ( Buf -> uses of this Buf), where the uses are represented
 * as vectors of BufLoadOrStoreUse elements, which are a Stmt* and a bool
 * isStore flag. The order of uses in the vectors reflects the order in
 * which the uses appear in the given statement.
 */
std::unordered_map<const Buf*, std::vector<BufLoadOrStoreUse>>
findLoadOrStoreUses(Stmt* s);
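
// Example (a hedged sketch; `root` is an assumed Stmt*):
//   auto uses = findLoadOrStoreUses(root);
//   for (const auto& kv : uses) {
//     const Buf* buf = kv.first;
//     for (const BufLoadOrStoreUse& use : kv.second) {
//       // use.s is the use site; use.isStore distinguishes store vs. load.
//     }
//   }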

} // namespace tensorexpr
} // namespace jit
} // namespace torch