#include <torch/csrc/jit/tensorexpr/loopnest.h>

#include <algorithm>
#include <iostream>
#include <stdexcept>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include <c10/util/Logging.h>
#include <c10/util/irange.h>

#include <ATen/core/functional.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/tensorexpr/analysis.h>
#include <torch/csrc/jit/tensorexpr/bounds_inference.h>
#include <torch/csrc/jit/tensorexpr/eval.h>
#include <torch/csrc/jit/tensorexpr/expr.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_cloner.h>
#include <torch/csrc/jit/tensorexpr/ir_mutator.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/ir_verifier.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>

namespace torch::jit::tensorexpr {

LoopNest::LoopNest(const LoopNest& other)
    : root_stmt_(Stmt::clone(other.root_stmt_)),
      output_bufs_(other.output_bufs_) {
  GRAPH_DEBUG("Origin Stmt in LoopNest:\n", std::to_string(root_stmt_));
  verify(root_stmt_);
}

LoopNest::LoopNest(StmtPtr stmt, std::unordered_set<BufPtr> output_bufs)
    : root_stmt_(std::move(stmt)), output_bufs_(std::move(output_bufs)) {
  GRAPH_DEBUG("Origin Stmt in LoopNest:\n", std::to_string(root_stmt_));
  verify(root_stmt_);
}

LoopNest::LoopNest(
    const std::vector<Tensor>& output_tensors,
    const std::vector<Tensor>& tensors_to_compute) {
  initialize(output_tensors, tensors_to_compute);
  GRAPH_DEBUG("Origin Stmt in LoopNest:\n", std::to_string(root_stmt_));
  verify(root_stmt_);
}

LoopNest::LoopNest(const std::vector<Tensor>& output_tensors) {
  initialize(output_tensors, output_tensors);
  GRAPH_DEBUG("Origin Stmt in LoopNest:\n", std::to_string(root_stmt_));
  verify(root_stmt_);
}

std::vector<BufPtr> LoopNest::getIntermediateBufs() const {
  std::vector<BufPtr> result;
  std::unordered_set<BufPtr> result_set;
  auto input_bufs = getInputBufs();
  auto bufs = NodeFinder<Buf>::find(root_stmt_);
  for (const auto& buf : bufs) {
    if (!output_bufs_.count(buf) && !input_bufs.count(buf) &&
        !result_set.count(buf)) {
      result.push_back(buf);
      result_set.insert(buf);
    }
  }
  return result;
}

const std::unordered_set<BufPtr> LoopNest::getInputBufs() const {
  std::unordered_set<BufPtr> result;
  auto buf_load_store_uses = findLoadOrStoreUses(root_stmt_);
  for (auto& kv : buf_load_store_uses) {
    bool has_store = false;
    for (auto& use : kv.second) {
      if (use.isStore) {
        has_store = true;
        break;
      }
    }
    if (!has_store) {
      result.insert(kv.first);
    }
  }
  return result;
}

class IndexFlattener : public IRMutator {
 public:
  StmtPtr flatten(const StmtPtr& s) {
    return s->accept_mutator(this);
  }

  ExprPtr mutate(const LoadPtr& v) override {
    if (v->indices().size() == 1) {
      return v;
    }
    return alloc<Load>(
        v->dtype(),
        v->buf(),
        std::vector<ExprPtr>({flatten_index(
            v->buf()->dims(), v->indices(), v->buf()->strides())}));
  }

  StmtPtr mutate(const StorePtr& v) override {
    ExprPtr value = v->value();
    ExprPtr new_value = value->accept_mutator(this);
    if (v->indices().size() == 1 && value == new_value) {
      return v;
    }
    std::vector<ExprPtr> indices = {
        flatten_index(v->buf()->dims(), v->indices(), v->buf()->strides())};
    v->set_indices(indices);
    v->set_value(new_value);
    return v;
  }
};

static bool isValidIdentifierChar(char c, size_t pos) {
  return islower(c) || isupper(c) || c == '_' || (pos > 0 && isdigit(c));
}

// replaces all invalid characters with underscore
std::string sanitizeName(const std::string& input_name) {
  std::stringstream sanitized_name;
  for (size_t i = 0; i < input_name.size(); ++i) {
    if (isValidIdentifierChar(input_name[i], i)) {
      sanitized_name << input_name[i];
    } else {
      if (i == 0) {
        // Don't start names with underscore
        sanitized_name << "v";
      }
      sanitized_name << "_";
    }
  }
  return sanitized_name.str();
}

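// Illustrative sketch (editorial example, not from the original source):
//   sanitizeName("x.y")  -> "x_y"   ('.' is replaced with '_')
//   sanitizeName("1st")  -> "v_st"  (a digit is invalid at position 0, so
//                                    "v" is prepended before the '_')
//   sanitizeName("a_b1") -> "a_b1"  (already a valid identifier; unchanged)
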
class VarNameSanitizer : public IRMutator {
 public:
  ExprPtr mutate(const BufPtr& v) override {
    if (seen_bufs_.count(v)) {
      return v;
    }
    const std::string& name = v->name_hint();
    auto new_name = sanitizeName(name);
    if (taken_names_.count(new_name)) {
      new_name = getNextAvailableName(new_name);
    }
    v->set_name_hint(new_name);
    taken_names_.insert(new_name);
    seen_bufs_.insert(v);
    return v;
  }

  ExprPtr mutate(const VarPtr& v) override {
    if (seen_vars_.count(v)) {
      return v;
    }
    const std::string& name = v->name_hint();
    auto new_name = sanitizeName(name);
    if (taken_names_.count(new_name)) {
      new_name = getNextAvailableName(new_name);
    }
    v->set_name_hint(new_name);
    taken_names_.insert(new_name);
    seen_vars_.insert(v);
    return v;
  }

  StmtPtr mutate(const ForPtr& v) override {
    auto new_name = getNextAvailableName(getIndexVarNameAtLevel(level_));
    if (seen_index_vars_.count(v->var())) {
      auto new_var = alloc<Var>("", v->var()->dtype());
      Substitute(v, {{v->var(), new_var}});
    }
    v->var()->set_name_hint(new_name);
    seen_index_vars_.insert(v->var());
    seen_vars_.insert(v->var());
    taken_names_.insert(new_name);
    level_++;
    v->body()->accept_mutator(this);
    level_--;
    v->start()->accept_mutator(this);
    v->stop()->accept_mutator(this);
    return v;
  }

  std::string getIndexVarNameAtLevel(int level_) {
    auto names_num = index_var_names_.size();
    auto counter = level_ / names_num;
    if (counter == 0) {
      return index_var_names_[level_ % names_num];
    } else {
      return index_var_names_[level_ % names_num] + std::to_string(counter);
    }
  }
  std::string getNextAvailableName(const std::string& base_name) {
    std::string name = base_name;
    int counter = 0;
    while (taken_names_.count(name)) {
      counter++;
      name = base_name + "_" + std::to_string(counter);
    }
    return name;
  }

 private:
  std::vector<std::string> index_var_names_ =
      {"i", "j", "k", "l", "m", "n", "o", "p"};
  std::unordered_set<std::string> taken_names_;
  std::unordered_set<VarPtr> seen_index_vars_;
  std::unordered_set<VarPtr> seen_vars_;
  std::unordered_set<BufPtr> seen_bufs_;
  int level_ = 0;
};

StmtPtr LoopNest::sanitizeNames(StmtPtr s) {
  VarNameSanitizer r;
  s->accept_mutator(&r);
  return s;
}

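// Illustrative sketch (editorial example, not from the original source):
// loop index variables are renamed by nesting level, cycling through
// index_var_names_ and appending a counter once the eight base names are
// exhausted:
//   level 0 -> "i", level 1 -> "j", ..., level 7 -> "p",
//   level 8 -> "i1", level 9 -> "j1", ...
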
class Vectorizer : public IRMutator {
 public:
  StmtPtr vectorize(ForPtr v) {
    StmtPtr body = v->body();
    VarPtr var = v->var();
    ExprPtr start = v->start();
    ExprPtr stop = v->stop();

    auto start_imm = intValue(start);
    auto stop_imm = intValue(stop);
    if (!start_imm) {
      // Can't vectorize due to non-constant loop start!
      success_ = false;
      return v;
    }

    if (!stop_imm) {
      // Can't vectorize due to non-constant loop stop!
      success_ = false;
      return v;
    }

    var_ = var;
    start_ = immLike(start, *start_imm);
    lanes_ = *stop_imm;

    StmtPtr new_body = body->accept_mutator(this);
    if (new_body == body) {
      // Vectorization failed!
      success_ = false;
      return v;
    }

    return new_body;
  }

  bool success() const {
    return success_;
  }

  ExprPtr mutate(const AddPtr& v) override {
    std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
    return try_vectorize(v, inputs, [&]() {
      return ExprHandle(inputs[0]) + ExprHandle(inputs[1]);
    });
  }

  ExprPtr mutate(const SubPtr& v) override {
    std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
    return try_vectorize(v, inputs, [&]() {
      return ExprHandle(inputs[0]) - ExprHandle(inputs[1]);
    });
  }

  ExprPtr mutate(const MulPtr& v) override {
    std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
    return try_vectorize(v, inputs, [&]() {
      return ExprHandle(inputs[0]) * ExprHandle(inputs[1]);
    });
  }

  ExprPtr mutate(const DivPtr& v) override {
    std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
    return try_vectorize(v, inputs, [&]() {
      return ExprHandle(inputs[0]) / ExprHandle(inputs[1]);
    });
  }

  ExprPtr mutate(const ModPtr& v) override {
    std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
    return try_vectorize(v, inputs, [&]() {
      return ExprHandle(inputs[0]) % ExprHandle(inputs[1]);
    });
  }

  ExprPtr mutate(const AndPtr& v) override {
    std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
    return try_vectorize(v, inputs, [&]() {
      return ExprHandle(inputs[0]) & ExprHandle(inputs[1]);
    });
  }

  ExprPtr mutate(const OrPtr& v) override {
    std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
    return try_vectorize(v, inputs, [&]() {
      return ExprHandle(inputs[0]) | ExprHandle(inputs[1]);
    });
  }

  ExprPtr mutate(const XorPtr& v) override {
    std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
    return try_vectorize(v, inputs, [&]() {
      return ExprHandle(inputs[0]) ^ ExprHandle(inputs[1]);
    });
  }

  ExprPtr mutate(const LshiftPtr& v) override {
    std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
    return try_vectorize(v, inputs, [&]() {
      return ExprHandle(inputs[0]) << ExprHandle(inputs[1]);
    });
  }

  ExprPtr mutate(const RshiftPtr& v) override {
    std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
    return try_vectorize(v, inputs, [&]() {
      return ExprHandle(inputs[0]) >> ExprHandle(inputs[1]);
    });
  }

  ExprPtr mutate(const MaxPtr& v) override {
    std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
    return try_vectorize(v, inputs, [&]() {
      return Max::make(
          ExprHandle(inputs[0]), ExprHandle(inputs[1]), v->propagate_nans());
    });
  }

  ExprPtr mutate(const MinPtr& v) override {
    std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
    return try_vectorize(v, inputs, [&]() {
      return Min::make(
          ExprHandle(inputs[0]), ExprHandle(inputs[1]), v->propagate_nans());
    });
  }

  ExprPtr mutate(const CompareSelectPtr& v) override {
    std::vector<ExprPtr> inputs = {
        v->lhs(), v->rhs(), v->ret_val1(), v->ret_val2()};
    return try_vectorize(v, inputs, [&]() {
      return CompareSelect::make(
          ExprHandle(inputs[0]),
          ExprHandle(inputs[1]),
          ExprHandle(inputs[2]),
          ExprHandle(inputs[3]),
          v->compare_select_op(),
          v->bias());
    });
  }

  ExprPtr mutate(const BitCastPtr& v) override {
    std::vector<ExprPtr> inputs = {v->src_value()};
    return try_vectorize(v, inputs, [&]() {
      return BitCast::make(
          Dtype(v->dtype().scalar_type(), lanes_), ExprHandle(inputs[0]));
    });
  }

  ExprPtr mutate(const CastPtr& v) override {
    std::vector<ExprPtr> inputs = {v->src_value()};
    return try_vectorize(v, inputs, [&]() {
      return Cast::make(
          Dtype(v->dtype().scalar_type(), lanes_), ExprHandle(inputs[0]));
    });
  }

  ExprPtr mutate(const VarPtr& v) override {
    if (v == var_) {
      return Ramp::make(
                 ExprHandle(start_), ExprHandle(immLike(start_, 1)), lanes_)
          .node();
    }

    return v;
  }

  ExprPtr mutate(const RampPtr& v) override {
    ExprPtr base = v->base();
    ExprPtr stride = v->stride();

    ExprPtr base_new = base->accept_mutator(this);
    ExprPtr stride_new = stride->accept_mutator(this);

    if (base_new == base && stride_new == stride) {
      return v;
    }

    // Can't vectorize a Ramp!
    success_ = false;
    return v;
  }

  ExprPtr mutate(const LoadPtr& v) override {
    Dtype dtype(v->dtype().scalar_type(), lanes_);
    BufPtr buf = v->buf();
    std::vector<ExprPtr> inputs = {v->flat_index()};
    return try_vectorize(v, inputs, [&]() {
      return Load::make(dtype, BufHandle(buf), {ExprHandle(inputs[0])});
    });
  }

  ExprPtr mutate(const ReduceOpPtr& v) override {
    Dtype dtype(v->dtype().scalar_type(), lanes_);

    std::vector<ExprPtr> inputs = {v->body()};

    auto out = try_vectorize(v, inputs, [&]() {
      return ExprHandle(
          alloc<ReduceOp>(inputs[0], v->reduce_args(), v->reducer()));
    });
    return out;
  }

  ExprPtr mutate(const BroadcastPtr& v) override {
    ExprPtr val = v->value();
    ExprPtr new_val = val->accept_mutator(this);
    if (new_val == val) {
      return v;
    }

    // Can't vectorize a Broadcast!
    success_ = false;
    return v;
  }

  ExprPtr mutate(const IfThenElsePtr& v) override {
    ExprPtr condition = v->condition();
    ExprPtr new_condition = condition->accept_mutator(this);
    if (new_condition != condition) {
      // Can't vectorize an IfThenElse condition!
      success_ = false;
      return v;
    }

    std::vector<ExprPtr> inputs = {v->true_value(), v->false_value()};
    return try_vectorize(v, inputs, [&]() {
      return IfThenElse::make(
          ExprHandle(condition), ExprHandle(inputs[0]), ExprHandle(inputs[1]));
    });
  }

  ExprPtr mutate(const IntrinsicsPtr& v) override {
    std::vector<ExprPtr> inputs = v->params();
    return try_vectorize(v, inputs, [&]() {
      return ExprHandle(alloc<Intrinsics>(v->op_type(), inputs));
    });
  }

  StmtPtr mutate(const StorePtr& v) override {
    BufPtr buf = v->buf();
    std::vector<ExprPtr> inputs = {v->flat_index(), v->value()};
    return try_vectorize(v, inputs, [&]() {
      return Store::make(
          BufHandle(buf), {ExprHandle(inputs[0])}, ExprHandle(inputs[1]));
    });
  }

  StmtPtr mutate(const ForPtr& v) override {
    VarPtr var = v->var();
    ExprPtr start = v->start();
    ExprPtr stop = v->stop();
    LoopOptions loop_options = v->loop_options();

    ExprPtr new_start = start->accept_mutator(this);
    ExprPtr new_stop = stop->accept_mutator(this);

    if (new_start != start || new_stop != stop) {
      // Can't vectorize nested For with dependent loop bounds!
      success_ = false;
      return v;
    }

    StmtPtr body = v->body();
    StmtPtr new_body = body->accept_mutator(this);

    if (new_body == body) {
      return (ForPtr)v;
    }

    return alloc<For>(var, new_start, new_stop, new_body, loop_options);
  }

  StmtPtr mutate(const BlockPtr& v) override {
    // IRMutator does in-place mutations. But the logic in vectorization checks
    // for success by looking for a new stmt. So, we override the in-place
    // mutations and create a clone here if any of its statements change.
    // TODO: Can we change the logic of vectorizer so that we don't need this?
    bool any_change = false;
    std::vector<StmtPtr> stmts;
    for (const StmtPtr& stmt : *v) {
      StmtPtr stmt_new = stmt->accept_mutator(this);
      if (stmt != stmt_new) {
        any_change = true;
      } else {
        stmt_new = Stmt::clone(stmt);
      }
      if (stmt_new) {
        stmts.push_back(stmt_new);
      }
    }
    if (any_change) {
      return alloc<Block>(stmts);
    }
    return v;
  }

  template <typename T>
  ExprPtr try_vectorize(ExprPtr e, std::vector<ExprPtr>& inputs, T&& vec_ctor) {
    bool vectorize = vectorize_inputs(inputs);
    if (vectorize) {
      return vec_ctor().node();
    }

    return e;
  }

  template <typename T>
  StmtPtr try_vectorize(StmtPtr s, std::vector<ExprPtr>& inputs, T&& vec_ctor) {
    bool vectorize = vectorize_inputs(inputs);
    if (vectorize) {
      return vec_ctor();
    }

    return s;
  }

  bool vectorize_inputs(std::vector<ExprPtr>& inputs) {
    bool any_vectorized = false;
    std::vector<ExprPtr> new_inputs;

    // Attempt to vectorize each input.
    for (ExprPtr& in : inputs) {
      ExprPtr new_in = in->accept_mutator(this);
      new_inputs.push_back(new_in);
      if (new_in != in) {
        any_vectorized = true;
      }
    }

    // If none of them vectorized, then don't vectorize this.
    if (!any_vectorized) {
      return false;
    }

    // Insert broadcasts for any inputs that weren't vectorized.
    for (size_t i = 0; i < inputs.size(); ++i) {
      if (inputs[i] == new_inputs[i]) {
        inputs[i] = Broadcast::make(ExprHandle(inputs[i]), lanes_).node();
      } else {
        inputs[i] = new_inputs[i];
      }
    }

    // And then vectorize this node.
    return true;
  }

  VarPtr var_ = nullptr;
  int64_t lanes_ = 0;
  ExprPtr start_ = nullptr;
  bool success_ = true;
};

bool LoopNest::vectorize(const ForPtr& f) {
  BlockPtr b = to<Block>(f->get_parent());
  if (!b) {
    return false;
  }

  // Can't vectorize reduction axes.
  auto reductions = NodeFinder<ReduceOp>::find(f);
  for (const auto& r : reductions) {
    if (std::find(r->reduce_args().begin(), r->reduce_args().end(), f->var()) !=
        r->reduce_args().end()) {
      return false;
    }
  }

  Vectorizer v;
  StmtPtr new_f = nullptr;
  new_f = Stmt::clone(f);
  normalize(to<For>(new_f));
  new_f = FlattenIndexes(new_f);
  new_f = v.vectorize(to<For>(new_f));
  if (!v.success()) {
    // We clone f before vectorizing. So, any partial vectorization will
    // have modified the clone. In case of an exception, we can continue
    // using f.
    new_f = f;
  }

  if (new_f != f) {
    b->replace_stmt(f, IRSimplifier::simplify(new_f));
    return true;
  }

  // Vectorization was not successful.
  return false;
}

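// Illustrative sketch (editorial example, not from the original source):
// given a normalized loop with constant bounds,
//   for (int i = 0; i < 8; i++) {
//     A[i] = B[i] + 1;
//   }
// the Vectorizer replaces uses of the index var with a Ramp and scalar
// operands with Broadcasts, producing roughly:
//   A[Ramp(0, 1, 8)] = B[Ramp(0, 1, 8)] + Broadcast(1, 8);
// Loops with non-constant bounds, and reduction axes, are rejected.
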
void LoopNest::initialize(
    const std::vector<Tensor>& output_tensors,
    const std::vector<Tensor>& tensors_to_compute) {
  for (const auto& t : output_tensors) {
    output_bufs_.insert(t.buf());
  }

  std::vector<StmtPtr> loops;
  for (const Tensor& t : tensors_to_compute) {
    StmtPtr loop = t.stmt();
    if (loop->get_parent()) {
      std::cerr << "Error: creating a loopnest from already used Tensors\n";
      loops = {};
      break;
    }
    // Flatten initializers.
    if (BlockPtr block = to<Block>(loop)) {
      for (const auto& s : block->stmts()) {
        block->remove_stmt(s);
        loops.push_back(s);
      }
    } else {
      loops.push_back(loop);
    }
  }

  root_stmt_ = alloc<Block>(loops);
}

class FunctionInliner : public IRMutator {
 public:
  FunctionInliner(StorePtr producer, std::unordered_set<BufPtr> outputs)
      : buf_(producer->buf()),
        producer_(std::move(producer)),
        outputs_(std::move(outputs)) {
    for (const auto& i : producer_->indices()) {
      if (auto index_var = to<Var>(i)) {
        index_vars_.insert(index_var);
        producer_index_vars_.push_back(index_var);
      } else {
        // If the index can be a constant, then that dimension must have size 1
        // (since we don't support in-place writes). Resolves issue 52581.
        auto index_val = evalInt(i);
        if (!index_val || *index_val != 0) {
          success_ = false;
          break;
        }
        producer_index_vars_.push_back(nullptr);
      }
    }
  }

  bool success() const {
    return success_;
  }

 private:
  ExprPtr mutate_loads(const BufPtr& buf, std::vector<ExprPtr> dims) {
    std::vector<VarPtr> index_vars;
    if (buf->ndim() != producer_index_vars_.size()) {
      // Dimensions of producer and consumer expressions do not match in the
      // inliner in the fuser.
      success_ = false;
      return nullptr;
    }
    for (const auto i : c10::irange(buf->ndim())) {
      VarPtr func_callee_arg = producer_index_vars_.at(i);
      ExprPtr func_caller_param = dims.at(i);
      if (func_callee_arg == nullptr) {
        continue;
      }
      auto iter = inline_mapping_.find(func_callee_arg);
      if (iter != inline_mapping_.end()) {
        // Duplicated variables.
        success_ = false;
        return nullptr;
      }
      // Add a mapping for each function parameter to its source name.
      inline_mapping_[func_callee_arg] = func_caller_param;
      GRAPH_DEBUG(
          "ComputeInline: Inline mapping: ",
          std::to_string(func_callee_arg),
          " -> ",
          std::to_string(func_caller_param));
      index_vars.push_back(func_callee_arg);
    }

    // Call the actual replacement.
    ExprPtr body = producer_->value();
    GRAPH_DEBUG("ComputeInline: Before rewriting body: ", std::to_string(body));
    ExprPtr result = Expr::clone(body)->accept_mutator(this);
    GRAPH_DEBUG(
        "ComputeInline: After rewriting body: ", std::to_string(result));

    // Remove the mappings we created for these function parameters.
    for (const auto& v : index_vars) {
      for (auto& pair : random_bindings_) {
        if (pair.second.erase(v)) {
          ExprPtr inlined = inline_mapping_[v];
          for (const auto& nv : VarFinder::find(inlined)) {
            pair.second.insert(nv);
          }
        }
      }
      GRAPH_DEBUG("ComputeInline: Inline mapping: erasing ", std::to_string(v));
      inline_mapping_.erase(v);
    }
    return result;
  }

  ExprPtr mutate(const LoadPtr& v) override {
    if (!success()) {
      return v;
    }
    BufPtr buf = v->buf();
    if (buf != buf_) {
      return IRMutator::mutate(v);
    }

    if (v->indices().size() != buf->ndim()) {
      // The number of indices doesn't match the buf rank in the fuser.
      success_ = false;
      return v;
    }
    auto result = mutate_loads(buf, v->indices());
    if (!result) {
      // If we don't inline successfully, return the given load.
      success_ = false;
      return v;
    }
    return result;
  }

  // Replace the target variable with the caller expressions.
  ExprPtr mutate(const VarPtr& v) override {
    if (!success()) {
      return v;
    }
    auto iter = inline_mapping_.find(v);
    if (iter == inline_mapping_.end()) {
      return v;
    } else {
      ExprPtr expr = iter->second;
      // Continue to transform the value from the lookup table.
      return expr->accept_mutator(this);
    }
  }

  // Handle random intrinsics which should be cached.
  ExprPtr mutate(const IntrinsicsPtr& v) override {
    if (!success()) {
      return v;
    }
    if (!in_producer_ || v->op_type() != kRand) {
      return IRMutator::mutate(v);
    }

    // Create a new Let Statement for the random variable, which we can refer
    // to multiple times and resolve to the same value (i.e. store it in a
    // scalar rather than the Tensor).
    const std::string& name = buf_->name_hint();
    VarPtr new_var = alloc<Var>(name, v->dtype());
    random_bindings_[alloc<Let>(new_var, v)] = index_vars_;
    GRAPH_DEBUG(
        "ComputeInline: created random bindings for ", std::to_string(new_var));
    return new_var;
  }

  // Remove the buffer write from the inlined function.
  StmtPtr mutate(const StorePtr& v) override {
    if (!success()) {
      return v;
    }
    // If the buf_ is in the outputs set, keep its statement intact. Otherwise,
    // remove it.
    if (v == producer_ && !outputs_.count(buf_)) {
      in_producer_ = true;
      producer_ = to<Store>(IRMutator::mutate(v));
      if (!producer_) {
        // The producer statement for an output buf should remain non-null in
        // the fuser.
        success_ = false;
        return v;
      }
      in_producer_ = false;
      return nullptr;
    } else {
      return IRMutator::mutate(v);
    }
  }

  // Any Random Intrinsics that were turned into vars must be inserted here.
  StmtPtr mutate(const BlockPtr& v) override {
    if (!success()) {
      return v;
    }
    std::vector<StmtPtr> stmts;
    for (const StmtPtr& stmt : *v) {
      StmtPtr stmt_new = stmt->accept_mutator(this);
      if (!stmt_new) {
        continue;
      }

      if (stmt == stmt_new) {
        stmt_new = Stmt::clone(stmt);
      }

      stmts.push_back(stmt_new);
    }

    return Block::make(stmts);
  }

  StmtPtr mutate(const ForPtr& v) override {
    if (!success()) {
      return v;
    }
    ForPtr res = to<For>(IRMutator::mutate(v));
    if (!res) {
      return nullptr;
    }

    // Find any random bindings that should be defined in this loop's body.
    std::vector<LetPtr> bindings_this_loop;
    VarPtr fv = v->var();
    for (auto& pair : random_bindings_) {
      auto& index_var = pair.second;
      if (index_var.erase(fv)) {
        bindings_this_loop.push_back(pair.first);
      }
    }

    for (const auto& l : bindings_this_loop) {
      res->body()->prepend_stmt(l);
      random_bindings_.erase(l);
    }
    return res;
  }

 private:
  BufPtr buf_;
  StorePtr producer_;

  // Index Vars present in the producer.
  std::unordered_set<VarPtr> index_vars_;
  std::vector<VarPtr> producer_index_vars_;

  std::unordered_map<VarPtr, ExprPtr> inline_mapping_;

  // In the producer's scope - we need to bind any calls to rand().
  bool in_producer_ = false;
  std::unordered_map<LetPtr, std::unordered_set<VarPtr>> random_bindings_;
  std::unordered_set<BufPtr> outputs_;
  bool success_ = true;
};

static StmtPtr computeInlineImpl(
    const BufPtr& b,
    const StmtPtr& stmt,
    const std::unordered_set<BufPtr>& output_bufs) {
  // If buf is used or defined in an ExternalCall, we cannot inline it.
  auto buf_load_store_uses = findLoadOrStoreUses(stmt);
  if (!buf_load_store_uses.count(b)) {
    return nullptr;
  }
  for (auto& use : buf_load_store_uses.at(b)) {
    StmtPtr s = use.s;
    if (to<ExternalCall>(s) || to<ExternalCallWithAlloc>(s)) {
      return nullptr;
    }
  }

  // Find producers.
  StorePtr relevant_store{nullptr};
  auto stores = NodeFinder<Store>::find(stmt);
  for (const auto& s : stores) {
    if (s->buf() == b) {
      auto reductions = NodeFinder<ReduceOp>::find(s);
      if (!reductions.empty()) {
        // Cannot inline a reduction computation
        return nullptr;
      }
      if (relevant_store != nullptr) {
        // Cannot inline Buf with multiple Tensors
        return nullptr;
      }
      relevant_store = s;
    }
  }

  if (!relevant_store) {
    // Cannot find a relevant store to inline a buf in the fuser
    return nullptr;
  }

  GRAPH_DEBUG("ComputeInline: Def: ", std::to_string(relevant_store));
  FunctionInliner inliner(relevant_store, output_bufs);
  auto result = stmt->accept_mutator(&inliner);
  if (inliner.success()) {
    return result;
  }
  return nullptr;
}

bool LoopNest::computeInline(const BufPtr& b) {
  // Inlining may not always be successful. Since all mutations now happen
  // in-place, an unsuccessful inlining transformation might leave the IR
  // in an invalid state. To get around this problem, we clone the root stmt,
  // try inlining on the clone, and if it succeeds, we proceed to perform
  // inlining on the actual root stmt. This way the root stmt will always be
  // in a valid state.
  auto stmt_copy = Stmt::clone(root_stmt_);
  auto try_inline = computeInlineImpl(b, stmt_copy, output_bufs_);
  if (!try_inline) {
    return false;
  }
  root_stmt_ = computeInlineImpl(b, root_stmt_, output_bufs_);
  return true;
}

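// Illustrative sketch (editorial example, not from the original source):
// given a producer and a consumer of an intermediate buf A,
//   for (int i = 0; i < N; i++) { A[i] = B[i] + 1; }
//   for (int j = 0; j < N; j++) { C[j] = A[j] * 2; }
// computeInline(A) substitutes the producer's body into the consumer and
// removes the producer store, yielding roughly:
//   for (int j = 0; j < N; j++) { C[j] = (B[j] + 1) * 2; }
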
bool LoopNest::computeInline(const StmtPtr& s) {
  auto s_store = to<Store>(s);
  if (s_store == nullptr) {
    // Could not find buffer producer to inline
    return false;
  }
  return computeInline(s_store->buf());
}

// Inlining buffers with multiple uses can create duplicated work, which can
// slow down CPU code generation but is enabled on GPU because it avoids
// difficult synchronization logic across blocks. Inlining trivial reads does
// not duplicate work.
void LoopNest::inlineIntermediateBufs(bool allow_duplicated_work) {
  std::unordered_set<BufPtr> bufs_to_inline;

  auto intermediate_bufs = getIntermediateBufs();
  if (allow_duplicated_work) {
    bufs_to_inline.insert(intermediate_bufs.begin(), intermediate_bufs.end());
  } else {
    auto buf_load_store_uses = findLoadOrStoreUses(root_stmt_);
    auto input_bufs = getInputBufs();

    for (const auto& buf : intermediate_bufs) {
      TORCH_INTERNAL_ASSERT(
          buf_load_store_uses.count(buf),
          buildErrorMessage(
              "Could not find uses of buf '" + buf->name_hint() +
              "' in the fuser."));
      std::vector<BufLoadOrStoreUse>& uses = buf_load_store_uses[buf];
      auto stores = c10::filter(
          uses, [](const BufLoadOrStoreUse& use) { return use.isStore; });

      // If the intermediate is the buffer formed from reading in the input
      // tensors, always inline, because we are not duplicating any work
      // and we are avoiding an intermediary buffer.
      if (stores.size() == 1) {
        if (auto store = to<Store>(stores[0].s)) {
          auto input_as_load = to<Load>(store->value());
          if (input_as_load && input_bufs.count(input_as_load->buf())) {
            bufs_to_inline.insert(buf);
            continue;
          }
        } else {
          // If the statement is not a Store, it must be an ExternalCall.
          TORCH_INTERNAL_ASSERT(
              to<ExternalCall>(stores[0].s) ||
                  to<ExternalCallWithAlloc>(stores[0].s),
              buildErrorMessage(
                  "Expected stmt: " + std::to_string(stores[0].s) +
                  "\nto be either a Store or an ExternalCall in the fuser."));
        }
      }

      // All bufs will have at least one store (if they have more than one,
      // they can't be inlined anyway).
      size_t reads = uses.size() - 1;
      // If there is only one read, we can inline without duplicating work.
      if (reads <= 1) {
        bufs_to_inline.insert(buf);
      }
    }
  }

  if (allow_duplicated_work) {
    bufs_to_inline.insert(output_bufs_.begin(), output_bufs_.end());
  }

  for (const auto& b : bufs_to_inline) {
    computeInline(b);
  }
}

// TODO: Unify with DepTracker
class LoadOrStoreUseFinder : public IRVisitor {
 public:
  std::unordered_map<BufPtr, std::vector<BufLoadOrStoreUse>> findUses(
      const StmtPtr& s) {
    uses_.clear();
    s->accept(this);
    return uses_;
  }

 private:
  void visit(const StorePtr& v) override {
    if (stores_[v->buf()].insert(last_stmt_).second) {
      uses_[v->buf()].push_back({(StmtPtr)v, true});
    }
    last_stmt_ = (StmtPtr)v;
    IRVisitor::visit(v);
  }

  void visit(const ExternalCallPtr& v) override {
    if (stores_[v->buf()].insert(last_stmt_).second) {
      uses_[v->buf()].push_back({(StmtPtr)v, true});
    }
    last_stmt_ = (StmtPtr)v;

    for (const BufPtr& input_buf : v->buf_args()) {
      if (loads_[input_buf].insert(last_stmt_).second) {
        uses_[input_buf].push_back({last_stmt_, false});
      }
    }

    IRVisitor::visit(v);
  }

  void visit(const ExternalCallWithAllocPtr& v) override {
    for (const auto& out_buf : v->buf_out_args()) {
      if (stores_[out_buf].insert(last_stmt_).second) {
        uses_[out_buf].push_back({(StmtPtr)v, true});
      }
    }
    last_stmt_ = (StmtPtr)v;

    for (const auto& input_buf : v->buf_args()) {
      if (loads_[input_buf].insert(last_stmt_).second) {
        uses_[input_buf].push_back({last_stmt_, false});
      }
    }

    IRVisitor::visit(v);
  }

  void visit(const LoadPtr& v) override {
    if (loads_[v->buf()].insert(last_stmt_).second) {
      uses_[v->buf()].push_back({last_stmt_, false});
    }
    IRVisitor::visit(v);
  }

  StmtPtr last_stmt_ = nullptr;
  std::unordered_map<BufPtr, std::vector<BufLoadOrStoreUse>> uses_;

  // Sets of loads and stores in order to keep the results unique
  std::unordered_map<BufPtr, std::unordered_set<StmtPtr>> loads_;
  std::unordered_map<BufPtr, std::unordered_set<StmtPtr>> stores_;
};

std::unordered_map<BufPtr, std::vector<BufLoadOrStoreUse>> findLoadOrStoreUses(
    const StmtPtr& s) {
  LoadOrStoreUseFinder uf;
  return uf.findUses(s);
}

class ContainedStmtsFinder : public IRVisitor {
 public:
  // Simply lists all Stores, ExternalCalls, and Blocks that are children of
  // the given stmt.
  const std::unordered_set<StmtPtr>& findContainedStmts(const StmtPtr& s) {
    contained_.clear();
    s->accept(this);
    return contained_;
  }

 private:
  void visit(const StorePtr& v) override {
    contained_.insert((StmtPtr)v);
    IRVisitor::visit(v);
  }
  void visit(const ExternalCallPtr& v) override {
    contained_.insert((StmtPtr)v);
    IRVisitor::visit(v);
  }
  void visit(const ExternalCallWithAllocPtr& v) override {
    contained_.insert((StmtPtr)v);
    IRVisitor::visit(v);
  }
  void visit(const BlockPtr& v) override {
    contained_.insert((StmtPtr)v);
    IRVisitor::visit(v);
  }

  std::unordered_set<StmtPtr> contained_;
};

class StmtDeleter : public IRMutator {
 public:
  StmtDeleter(const std::unordered_set<StmtPtr>& targets) : targets_(targets) {}

 private:
  StmtPtr mutate(const BlockPtr& v) override {
    std::vector<StmtPtr> stmts;

    for (const auto& s : v->stmts()) {
      if (targets_.count(s) == 0) {
        StmtPtr ns = s->accept_mutator(this);
        if (ns) {
          stmts.push_back(Stmt::clone(ns));
        }
      }
    }

    return Block::make(stmts);
  }

  const std::unordered_set<StmtPtr>& targets_;
};

void LoopNest::eliminateDeadStores() {
  using namespace analysis;
  MemDependencyChecker checker(getInputBufs(), getOutputBufs());
  root_stmt_->accept(&checker);

  std::unordered_set<StmtPtr> deadStores;
  std::vector<std::shared_ptr<AccessInfo>> outputAccesses;
  for (const auto& o : getOutputBufs()) {
    outputAccesses.push_back(checker.output(o));
  }

  for (auto& info : checker.getHistory()) {
    if (!info->isWrite()) {
      continue;
    }
    bool found = false;

    for (auto& output : outputAccesses) {
      if (checker.dependsIndirectly(output, info)) {
        found = true;
        break;
      }
    }

    if (!found) {
      deadStores.insert(info->stmt());
    }
  }

  StmtDeleter deleter(deadStores);
  root_stmt_ = root_stmt_->accept_mutator(&deleter);
}

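// Illustrative sketch (editorial example, not from the original source):
// if O is the only output buf, then in
//   for (int i = 0; i < N; i++) { T[i] = B[i] * 2; }   // T is never read
//   for (int i = 0; i < N; i++) { O[i] = B[i] + 1; }
// no output access depends (even indirectly) on the store to T, so that
// store is collected in deadStores and deleted.
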
void LoopNest::prepareForCodegen() {
  // Expand reduction ops.
  ReductionExpander reduceExpander;
  root_stmt_ = reduceExpander.expand(root_stmt_);

  root_stmt_ = FlattenIndexes(root_stmt_);
}

namespace {

// This is extended from IRCloner instead of IRMutator because we want all
// the rest of the IR nodes (the ones not touched directly) to be cloned.
class IfThenElseReplacer : public IRCloner {
 public:
  IfThenElseReplacer(IfThenElsePtr to_replace, ExprPtr new_expr)
      : to_replace_(std::move(to_replace)), new_expr_(std::move(new_expr)) {}

  ExprPtr mutate(const IfThenElsePtr& i) override {
    if (i == to_replace_) {
      return new_expr_;
    }
    return IRCloner::mutate(i);
  }

 private:
  IfThenElsePtr to_replace_;
  ExprPtr new_expr_;
};

// Checks if the given condition is optimizable.
// Specifically, this function looks for the following pattern:
//    "var < expr"
//
// If this pattern is found, then this function:
//   * sets `cond_var` to `var`,
//   * sets `compared_value` to `expr`, and
//   * returns true.
bool isConditionOptimizable(
    const ExprPtr& condition,
    VarPtr* cond_var,
    ExprPtr* compared_value) {
  auto cs = to<CompareSelect>(condition);
  if (cs && cs->compare_select_op() == kLT) {
    auto var = to<Var>(cs->lhs());
    if (var) {
      *cond_var = var;
      *compared_value = cs->rhs();
      return true;
    }
  }
  return false;
}

// Checks if the given if-then-else expression is a conditional that is
// generated from `aten::cat`.
//
// The expected format of conditionals is:
//   IfThenElse(var < val1? 1 : 0,
//     IfThenElse (var < val2? 1 : 0,
//       IfThenElse (var < val3? 1 : 0,
//         sub-expr1,
//         sub-expr2),
//       sub-expr3),
//     sub-expr4)
//
// If such a conditional is found, this function also sets:
//   * cond_var to the condition variable found in this expression.
//   * comp_values to the list of compared values in the condition expressions.
//   * sub_exprs to the list of sub-expressions that are the result of this
//     if-then-else expression.
bool isConditionalFromCat(
    const IfThenElsePtr& ite,
    VarPtr* cond_var,
    std::vector<ExprPtr>* comp_values,
    std::vector<ExprPtr>* sub_exprs) {
  VarPtr var = nullptr;
  ExprPtr comp_value;
  if (isConditionOptimizable(ite->condition(), &var, &comp_value)) {
    if (*cond_var == nullptr) {
      *cond_var = var;
    } else if (*cond_var != var) {
      // Different condition variables found in nested if-then-else
      // expressions. Cannot optimize such cases.
      return false;
    }
    auto true_ite = to<IfThenElse>(ite->true_value());
    if (true_ite) {
      if (!isConditionalFromCat(true_ite, cond_var, comp_values, sub_exprs)) {
        return false;
      }
    } else {
      sub_exprs->push_back(ite->true_value());
    }
    auto false_ite = to<IfThenElse>(ite->false_value());
    if (false_ite) {
      return false;
    }
    comp_values->push_back(comp_value);
    sub_exprs->push_back(ite->false_value());
    return true;
  }
  return false;
}

bool areConstantsAndSorted(const std::vector<ExprPtr>& comp_values) {
  std::vector<int> comp_consts;
  comp_consts.reserve(comp_values.size());
  for (const auto& c : comp_values) {
    if (!c->isConstant()) {
      return false;
    }
    comp_consts.push_back(immediateAs<int>(c));
  }
  return std::is_sorted(comp_consts.begin(), comp_consts.end());
}

} // namespace

bool LoopNest::optimizeConditionals() {
  // Consider every store in the root_stmt_ and try to optimize the
  // conditionals in that store.
  auto stores = NodeFinder<Store>::find(root_stmt_);
  std::unordered_set<ForPtr> split_fors;
  for (const auto& store : stores) {
    VarPtr cond_var = nullptr;
    // `comp_values` represents the list of compared values that will be
    // collected as we check for the expected pattern. Since that will
    // only include the RHS of the conditions in the if-then-else expressions,
    // we need to start with `0`, which is the initial bound, given that we
    // only handle normalized loops (the check for this is done below).
    std::vector<ExprPtr> comp_values;
    std::vector<ExprPtr> sub_exprs;
    auto ifthenelse_exprs = NodeFinder<IfThenElse>::find(store);
    if (ifthenelse_exprs.empty()) {
      continue;
    }
    // We only check if the first if-then-else expression in this store
    // corresponds to a conditional of the required format. If there is more
    // than one such conditional, optimizing them requires checking if the
    // conditions are exactly the same across them and handling all of them
    // together. Currently, this is not handled.
    if (!isConditionalFromCat(
            ifthenelse_exprs.front(), &cond_var, &comp_values, &sub_exprs)) {
      continue;
    }
    TORCH_INTERNAL_ASSERT(
        !comp_values.empty(),
        buildErrorMessage(
            "Expected at least one expression in optimizeConditional in the fuser."));
    comp_values.insert(comp_values.begin(), immLike(comp_values[0], 0));

    auto fors = getLoopStmtsFor(store);
    if (cond_var != fors.back()->var()) {
      // Currently, we only handle the case where the condition variable
      // is the same as the inner-most loop variable.
      // TODO: Handle all other cases here.
      //
      // In order to handle all other cases, the method `clone_and_replace`
      // called below to clone the body of the loop with a new store needs
      // to recursively handle cloning of the loops and other blocks it
      // contains.
      continue;
    }

    auto for_to_split = fors.back();
    if (!LoopNest::isNormalized(for_to_split)) {
      // Do not optimize this conditional since the condition variable
      // refers to a loop that is not normalized.
      continue;
    }
    if (split_fors.count(for_to_split)) {
      // This loop has already been split while optimizing conditionals
      // earlier.
      //
      // Optimizing multiple conditionals that require splitting the same loop
      // is tricky. It requires checking if the conditions are exactly the same
      // across them and handling all of them together by splitting the loop
      // exactly once.
      //
      // Currently, this case is not supported.
      continue;
    }
    split_fors.insert(for_to_split);

    // `comp_values` needs to include the end bound, which is `for_to_split`'s
    // stop value.
    comp_values.push_back(for_to_split->stop());

    // Check if all `comp_values` are constants and they are sorted.
    if (!areConstantsAndSorted(comp_values)) {
      continue;
    }

    // Remove all the if-then-else expressions from this store and create
    // one loop per sub-expression.
    std::vector<StmtPtr> split_loops;
    auto cond_to_replace = ifthenelse_exprs.front();
    for (size_t i = 0; i < sub_exprs.size(); ++i) {
      IfThenElseReplacer ifthenelseReplacer(cond_to_replace, sub_exprs[i]);
      auto new_store = store->accept_mutator(&ifthenelseReplacer);
      auto new_for_body =
          for_to_split->body()->clone_and_replace(store, new_store);
      auto new_for = alloc<For>(
          for_to_split->var(),
          comp_values[i],
          comp_values[i + 1],
          new_for_body);
      LoopNest::normalize(new_for);
      split_loops.push_back(new_for);
    }
    auto par = to<Block>(for_to_split->get_parent());
    par->replace_stmt(for_to_split, alloc<Block>(split_loops));
  }
  root_stmt_ = IRSimplifier::simplify(root_stmt_);
  return true;
}

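// Illustrative sketch (editorial example, not from the original source):
// a store produced by aten::cat such as
//   for (int i = 0; i < 10; i++) {
//     O[i] = IfThenElse(i < 5 ? 1 : 0, B[i], C[i - 5]);
//   }
// has comp_values = [0, 5, 10] and two sub-expressions, so the loop is
// split at the compared values into (after normalization) roughly:
//   for (int i = 0; i < 5; i++) { O[i] = B[i]; }
//   for (int i = 0; i < 5; i++) { O[i + 5] = C[i]; }
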
void LoopNest::vectorizeInnerLoops() {
  std::vector<ForPtr> innerLoops;
  std::vector<ForPtr> worklist;

  // Find the outer-most For loops.
  if (ForPtr rootF = to<For>(root_stmt_)) {
    worklist.push_back(rootF);
  } else if (BlockPtr body = to<Block>(root_stmt_)) {
    std::vector<BlockPtr> blocks = {body};
    while (!blocks.empty()) {
      BlockPtr b = blocks.back();
      blocks.pop_back();

      for (const StmtPtr& s : *b) {
        if (const ForPtr& f = to<For>(s)) {
          worklist.push_back(f);
        } else if (BlockPtr b2 = to<Block>(s)) {
          blocks.push_back(b2);
        }
      }
    }
  }

  // Traverse the For loop nest to find the inner-most loops, which are
  // vectorization candidates.
  while (!worklist.empty()) {
    ForPtr f = worklist.back();
    worklist.pop_back();

    bool containsSubLoops = false;
    if (BlockPtr body = to<Block>(f->body())) {
      for (const StmtPtr& s2 : *body) {
        if (const ForPtr& f2 = to<For>(s2)) {
          containsSubLoops = true;
          worklist.push_back(f2);
        }
      }
    }

    if (!containsSubLoops) {
      innerLoops.push_back(f);
    }
  }

  // Vectorize the inner loops.
  for (const ForPtr& loop : innerLoops) {
    ForPtr split1;
    ForPtr tail1;

    static const int kBodyVectorWidth = 8;
    splitWithTail(loop, kBodyVectorWidth, &split1, &tail1);
    vectorize(split1);

    if (tail1) {
      ForPtr split2;
      ForPtr tail2;
      static const int kTailVectorWidth = 4;
      splitWithTail(tail1, kTailVectorWidth, &split2, &tail2);
      vectorize(split2);
    }
  }
}

void LoopNest::sliceHead(
    const ForPtr& f,
    int factor,
    ForPtr* head,
    ForPtr* tail) {
  if (!f) {
    throw malformed_input("sliceHead attempted on null loop");
  }

  if (intValue(f->start()) && intValue(f->stop())) {
    auto start_val = *intValue(f->start());
    auto stop_val = *intValue(f->stop());
    auto size_val = stop_val - start_val;
    if (factor >= size_val) {
      *head = f;
      *tail = nullptr;
      return;
    }
  }

  BlockPtr p = to<Block>(f->get_parent());
  if (!p) {
    throw malformed_input("sliceHead attempted on loop with no parent");
  }

  ExprPtr head_end = alloc<Min>(
      alloc<Add>(f->start(), immLike(f->stop(), factor)), f->stop(), true);
  *head = alloc<For>(f->var(), f->start(), head_end, Stmt::clone(f->body()));
  p->insert_stmt_before(*head, f);

  f->set_start(head_end);
  *tail = f;

  if (f->loop_options().is_gpu_block_index() ||
      f->loop_options().is_gpu_thread_index()) {
    LoopNest::normalize(*tail);
  }
}

void LoopNest::sliceHead(const ForPtr& f, int factor) {
  ForPtr head, tail;
  sliceHead(f, factor, &head, &tail);
}

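// Illustrative sketch (editorial example, not from the original source):
//   for (int i = 0; i < 100; i++) { A[i] = B[i]; }
// after sliceHead(f, 4, &head, &tail) becomes roughly:
//   for (int i = 0; i < 4; i++) { A[i] = B[i]; }     // head
//   for (int i = 4; i < 100; i++) { A[i] = B[i]; }   // tail (the input loop)
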
void LoopNest::sliceTail(
    const ForPtr& f,
    int factor,
    ForPtr* head,
    ForPtr* tail) {
  if (!f) {
    throw malformed_input("sliceTail attempted on null loop");
  }

  if (intValue(f->start()) && intValue(f->stop())) {
    auto start_val = *intValue(f->start());
    auto stop_val = *intValue(f->stop());
    auto size_val = stop_val - start_val;
    if (factor >= size_val) {
      *head = nullptr;
      *tail = f;
      return;
    }
  }

  BlockPtr p = to<Block>(f->get_parent());
  if (!p) {
    throw malformed_input("sliceTail attempted on loop with no parent");
  }

  ExprPtr tail_start = alloc<Max>(
      f->start(), alloc<Sub>(f->stop(), immLike(f->stop(), factor)), true);
  *tail = alloc<For>(f->var(), tail_start, f->stop(), Stmt::clone(f->body()));
  p->insert_stmt_after(*tail, f);

  f->set_stop(tail_start);
  *head = f;

  if (f->loop_options().is_gpu_block_index() ||
      f->loop_options().is_gpu_thread_index()) {
    LoopNest::normalize(*head);
  }
}

void LoopNest::sliceTail(const ForPtr& f, int factor) {
  ForPtr head, tail;
  sliceTail(f, factor, &head, &tail);
}

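// Illustrative sketch (editorial example, not from the original source):
//   for (int i = 0; i < 100; i++) { A[i] = B[i]; }
// after sliceTail(f, 4, &head, &tail) becomes roughly:
//   for (int i = 0; i < 96; i++) { A[i] = B[i]; }     // head (the input loop)
//   for (int i = 96; i < 100; i++) { A[i] = B[i]; }   // tail
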
void LoopNest::splitWithTail(const ForPtr& f, int factor) {
  ForPtr inner, tail;
  splitWithTail(f, factor, &inner, &tail);
}

void LoopNest::splitWithTail(
    const ForPtr& f,
    int factor,
    ForPtr* inner,
    ForPtr* tail) {
  if (!f) {
    throw malformed_input("splitWithTail attempted on null loop");
  }

  BlockPtr p = to<Block>(f->get_parent());
  if (!p) {
    throw malformed_input("splitWithTail attempted on loop with no parent");
  }

  // Normalize the loop to simplify start and stop bound computation.
  normalize(f);

  bool tail_is_needed = true;
  if (intValue(f->start()) && intValue(f->stop())) {
    auto const start_val = *intValue(f->start());
    auto const stop_val = *intValue(f->stop());
    auto const size_val = stop_val - start_val;
    auto const tail_size = size_val % factor;
    if (tail_size == 0) {
      tail_is_needed = false;
    }
  }

  ExprPtr factor_expr = immLike(f->stop(), factor);
  ExprPtr size = alloc<Sub>(f->stop(), f->start());
  ExprPtr split_count = alloc<Div>(size, factor_expr);
  ExprPtr tail_size = alloc<Mod>(size, factor_expr);

  const std::string& loop_var_name = f->var()->name_hint();
  Dtype loop_var_dtype = f->var()->dtype();

  VarPtr i_inner = alloc<Var>(loop_var_name + "_inner", loop_var_dtype);
  VarPtr i_outer = alloc<Var>(loop_var_name + "_outer", loop_var_dtype);

  // x -> x.outer * inner.size + x.inner
  ExprPtr combined_index1 =
      alloc<Add>(alloc<Mul>(i_outer, factor_expr), i_inner);

  if (tail_is_needed) {
    VarPtr i_tail = alloc<Var>(loop_var_name + "_tail", loop_var_dtype);
    // x -> x.tail + outer.size * inner.size
    ExprPtr combined_index2 =
        alloc<Add>(i_tail, alloc<Mul>(split_count, factor_expr));

    StmtPtr body_tail =
        SubstituteInClone(f->body(), {{f->var(), combined_index2}});
    *tail = alloc<For>(i_tail, immLike(tail_size, 0), tail_size, body_tail);

    p->insert_stmt_after(*tail, f);
  } else {
    *tail = nullptr;
  }

  StmtPtr body_inner =
      Substitute(f->removeBody(), {{f->var(), combined_index1}});

  *inner =
      alloc<For>(i_inner, immLike(factor_expr, 0), factor_expr, body_inner);
  // The input loop `f` will be the outer loop after split.
  f->set_var(i_outer);
  f->set_start(immLike(split_count, 0));
  f->set_stop(split_count);
  f->set_body(*inner);
}

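// Illustrative sketch (editorial example, not from the original source):
//   for (int i = 0; i < 100; i++) { A[i] = B[i]; }
// after splitWithTail(f, 8, &inner, &tail) becomes roughly:
//   for (int i_outer = 0; i_outer < 12; i_outer++) {   // the input loop
//     for (int i_inner = 0; i_inner < 8; i_inner++) {
//       A[i_outer * 8 + i_inner] = B[i_outer * 8 + i_inner];
//     }
//   }
//   for (int i_tail = 0; i_tail < 4; i_tail++) {
//     A[i_tail + 96] = B[i_tail + 96];
//   }
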
void LoopNest::splitWithMask(const ForPtr& f, int factor) {
  ForPtr inner;
  splitWithMask(f, factor, &inner);
}

void LoopNest::splitWithMask(const ForPtr& f, int factor, ForPtr* inner) {
  BlockPtr p = to<Block>(f->get_parent());
  if (!p) {
    std::cerr << "Parent is not a Block!\n";
    return;
  }

  bool tail_is_needed = true;
  ExprPtr start = IRSimplifier::simplify(f->start());
  ExprPtr stop = IRSimplifier::simplify(f->stop());
  if (start->isConstant() && stop->isConstant()) {
    auto start_val = *intValue(start);
    auto stop_val = *intValue(stop);
    auto size_val = stop_val - start_val;
    auto tail_size = size_val % factor;
    if (tail_size == 0) {
      tail_is_needed = false;
    }
  }

  auto factor_expr = immLike(f->stop(), factor);
  ExprPtr size = alloc<Sub>(f->stop(), f->start());
  // split_count = (size + factor - 1) / factor
  ExprPtr split_count = alloc<Div>(
      alloc<Sub>(alloc<Add>(size, factor_expr), immLike(size, 1)), factor_expr);

  const std::string& loop_var_name = f->var()->name_hint();
  Dtype loop_var_dtype = f->var()->dtype();

  VarPtr i_inner = alloc<Var>(loop_var_name + "_inner", loop_var_dtype);
  VarPtr i_outer = alloc<Var>(loop_var_name + "_outer", loop_var_dtype);

  // x -> x.outer * inner.size + x.inner
  ExprPtr combined_index =
      alloc<Add>(alloc<Mul>(i_outer, factor_expr), i_inner);

  StmtPtr body_inner = f->removeBody();
  // TODO: is it ok that we're doing it eagerly? In the other implementation we
  // are only materializing predicates at the last, lowering, step.
  if (tail_is_needed) {
    auto start = intValue(f->start());
    if (!start || *start != 0) {
      throw unimplemented_lowering();
    }

    ExprPtr predicate =
        CompareSelect::make(ExprHandle(f->var()), ExprHandle(f->stop()), kLT)
            .node();
    body_inner = Cond::make(ExprHandle(predicate), body_inner, nullptr);
  }
  body_inner = Substitute(body_inner, {{f->var(), combined_index}});

  *inner =
      alloc<For>(i_inner, immLike(factor_expr, 0), factor_expr, body_inner);
  // The input loop `f` will be the outer loop after split.
  f->set_var(i_outer);
  f->set_start(immLike(split_count, 0));
  f->set_stop(split_count);
  f->set_body(*inner);
}

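// Illustrative sketch (editorial example, not from the original source):
//   for (int i = 0; i < 100; i++) { A[i] = B[i]; }
// after splitWithMask(f, 8, &inner) becomes roughly (no tail loop; the
// last iteration is masked by a predicate instead):
//   for (int i_outer = 0; i_outer < 13; i_outer++) {
//     for (int i_inner = 0; i_inner < 8; i_inner++) {
//       if (i_outer * 8 + i_inner < 100) {
//         A[i_outer * 8 + i_inner] = B[i_outer * 8 + i_inner];
//       }
//     }
//   }
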
std::vector<ForPtr> LoopNest::distributeLoop(
    const ForPtr& loop,
    const std::unordered_set<StmtPtr>& pivots) {
  TORCH_INTERNAL_ASSERT(
      loop,
      buildErrorMessage(
          "Expected non-null loop in distributeLoop in the fuser."));
  auto root = loop->get_parent();
  if (root == nullptr) {
    throw malformed_input("Loop without parent: ", loop);
  }
  auto root_block = to<Block>(root);
  if (root_block == nullptr) {
    throw malformed_input(
        "Loop's parent must be a Block, instead found ", root);
  }

  // Extract bodies for all the loops after distribution.
  std::vector<BlockPtr> new_loop_bodies;
  auto new_loop_body = alloc<Block>(std::vector<StmtPtr>({}));
  while (!loop->body()->empty()) {
    auto s = loop->body()->front();
    loop->body()->remove_stmt(s);
    new_loop_body->append_stmt(s);
    if (pivots.count(s)) {
      new_loop_bodies.push_back(new_loop_body);
      new_loop_body = alloc<Block>(std::vector<StmtPtr>({}));
    }
  }
  if (!new_loop_body->empty()) {
    new_loop_bodies.push_back(new_loop_body);
  }

  // The first loop body has to be in the original loop.
  loop->body()->splice(loop->body()->begin(), new_loop_bodies.front());
  std::vector<ForPtr> new_loops = {loop};

  // Create loops for all the remaining blocks.
  // Add all the new loops to the parent block.
  for (size_t i = 1; i < new_loop_bodies.size(); ++i) {
    auto new_loop = loop->cloneWithNewBody(new_loop_bodies[i]);
    root_block->insert_stmt_after(new_loop, new_loops.back());
    new_loops.push_back(new_loop);
  }

  return new_loops;
}

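// Illustrative sketch (editorial example, not from the original source):
// with pivots = {S1} in
//   for (int i = 0; i < N; i++) {
//     S1: A[i] = 0;
//     S2: for (int j = 0; j < M; j++) { A[i] = A[i] + B[i, j]; }
//   }
// distributeLoop splits the body at the pivot into two loops:
//   for (int i = 0; i < N; i++) { A[i] = 0; }
//   for (int i = 0; i < N; i++) {
//     for (int j = 0; j < M; j++) { A[i] = A[i] + B[i, j]; }
//   }
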
std::vector<ForPtr> LoopNest::distributeLoop(const ForPtr& loop) {
  std::unordered_set<StmtPtr> stmtsInBlock(
      loop->body()->begin(), loop->body()->end());
  return distributeLoop(loop, stmtsInBlock);
}

std::vector<ForPtr> LoopNest::distributeLoopAndParents(const ForPtr& loop) {
  auto parentLoop = getParentLoop(loop);
  auto result = distributeLoop(loop);
  if (parentLoop) {
    return distributeLoopAndParents(parentLoop);
  }
  return result;
}

std::vector<ForPtr> LoopNest::distributeLoopOverInnerLoops(const ForPtr& loop) {
  auto loops = NodeFinder<For>::find(loop);
  std::unordered_set<StmtPtr> loopsSet(loops.begin(), loops.end());
  return distributeLoop(loop, loopsSet);
}

std::vector<ForPtr> LoopNest::distributeLoopAndParentsOverInnerLoops(
    const ForPtr& loop) {
  auto parentLoop = getParentLoop(loop);
  auto result = distributeLoopOverInnerLoops(loop);
  if (parentLoop) {
    return distributeLoopAndParentsOverInnerLoops(parentLoop);
  }
  return result;
}

static bool areEqual(const ExprPtr& expr1, const ExprPtr& expr2) {
  auto diff = IRSimplifier::simplify(alloc<Sub>(expr1, expr2));
  return diff->isConstant() && (immediateAs<int>(diff) == 0);
};

static bool doesExprContainAnyVar(
    const ExprPtr& expr,
    const std::unordered_set<VarPtr>& vars) {
  for (const auto& v : VarFinder::find(expr)) {
    if (vars.count(v)) {
      return true;
    }
  }
  return false;
}

// Returns true if the given lists of indices refer to two accesses
// that are loop-independent w.r.t. the given list of outer loop
// variables.
static bool areIndicesLoopIndependent(
    const std::vector<ExprPtr>& expr_list1,
    const std::vector<ExprPtr>& expr_list2,
    const std::unordered_set<VarPtr>& outer_loop_vars) {
  if (expr_list1.size() != expr_list2.size()) {
    return false;
  }
  for (size_t i = 0; i < expr_list1.size(); ++i) {
    const auto& expr1 = expr_list1[i];
    const auto& expr2 = expr_list2[i];
    if (doesExprContainAnyVar(expr1, outer_loop_vars) ||
        doesExprContainAnyVar(expr2, outer_loop_vars)) {
      if (!areEqual(expr1, expr2)) {
        return false;
      }
    }
  }
  return true;
}

bool LoopNest::hasLoopCarriedDependence(const ForPtr& loop) {
|
||
analysis::MemDependencyChecker analyzer;
|
||
loop->accept(&analyzer);
|
||
|
||
std::unordered_set<VarPtr> outer_loop_vars = {loop->var()};
|
||
auto outer_loops = LoopNest::getEnclosingLoopNest(loop);
|
||
for (const auto& l : outer_loops) {
|
||
outer_loop_vars.insert(l->var());
|
||
}
|
||
|
||
// High-level algorithm to check if two accesses to a buffer, A and B, one of
|
||
// which is a Store, result in a loop-carried dependence:
|
||
// 1. For every pair of index expressions, Ai and Bi, that refer to a dim
|
||
// of A and B, if one of the following conditions are satisfied:
|
||
// a) Ai and Bi are equal (OR)
|
||
// b) Both Ai and Bi do not contain any outer-loop variables
|
||
// then, the dependence between A and B is a loop-independent
|
||
// dependence. This is because, in the case of b), those index
|
||
// expressions do not affect the ordering of accesses A and B.
|
||
// 2. If condition 1) is not satisfied:
|
||
// a) if the bounds on the accesses overlap, then this is a
|
||
// loop-carried dependence.
|
||
// b) if the bounds on the accesses do not overlap, then there is no
|
||
// dependence.
|
||
//
|
||
// NOTE: Since we check for equality of index expressions whenever outer
|
||
// loop variables are involved, this may incorrectly report some cases as
|
||
// having a loop-carried dependence. It is impractical to handle all
|
||
// possible cases here, so, we are being conservative and allow for
|
||
// some false positives. While this will prevent some loop fusion
|
||
// opportunities, that should be a small fraction of the cases that are
|
||
// allowed.
|
||
//
|
||
// Implementation:
|
||
//
|
||
// For every pair of statements, S1 and S2, in the loop:
|
||
// * Get the loads and stores in S1 and S2.
|
||
// * For every store in S1 and load in S2 to the same buffer, if the index
|
||
// expressions are not equal and there is an overlap in accesses, return
|
||
// true to indicate a loop-carried dependence.
|
||
// * For every load in S1 and store in S2 to the same buffer, if the index
|
||
// expressions are not equal and there is an overlap in accesses, return
|
||
// true to indicate a loop-carried dependence.
|
||
// * For every store in S1 and store in S2 to the same buffer, if the index
|
||
// expressions are not equal and there is an overlap in accesses, return
|
||
// true to indicate a loop-carried dependence.
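  //
  // For example (illustrative), this loop has a loop-carried RAW dependence,
  // since iteration i reads the value stored by iteration i-1:
  //
  //   for i in 1..N:
  //     A[i] = A[i-1] + 1
  //
  // whereas `A[i] = A[i] + 1` has no loop-carried dependence: each iteration
  // only touches its own element.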
  for (auto it1 = loop->body()->begin(); it1 != loop->body()->end(); ++it1) {
    for (auto it2 = std::next(it1); it2 != loop->body()->end(); ++it2) {
      auto aStores = NodeFinder<Store>::find(*it1);
      auto aLoads = NodeFinder<Load>::find(*it1);
      auto bStores = NodeFinder<Store>::find(*it2);
      auto bLoads = NodeFinder<Load>::find(*it2);
      // ReadAfterWrite
      for (auto& aStore : aStores) {
        for (auto& bLoad : bLoads) {
          if (aStore->buf() == bLoad->buf()) {
            if (!areIndicesLoopIndependent(
                    aStore->indices(), bLoad->indices(), outer_loop_vars)) {
              if (isOverlapping(analyzer, aStore, bLoad)) {
                return true;
              }
            }
          }
        }
      }
      // WriteAfterRead
      for (auto& bStore : bStores) {
        for (auto& aLoad : aLoads) {
          if (bStore->buf() == aLoad->buf()) {
            if (!areIndicesLoopIndependent(
                    bStore->indices(), aLoad->indices(), outer_loop_vars)) {
              if (isOverlapping(analyzer, bStore, aLoad)) {
                return true;
              }
            }
          }
        }
      }
      // WriteAfterWrite
      for (auto& aStore : aStores) {
        for (auto& bStore : bStores) {
          if (aStore->buf() == bStore->buf()) {
            if (!areIndicesLoopIndependent(
                    aStore->indices(), bStore->indices(), outer_loop_vars)) {
              if (isOverlapping(analyzer, aStore, bStore)) {
                return true;
              }
            }
          }
        }
      }
    }
  }
  return false;
}
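
// A sketch (illustrative) of unsafe fusion: given adjacent loops with the
// same parent,
//
//   for i in 0..N:           for i in 0..N:
//     A[i] = f(i)     =>       A[i] = f(i)
//   for j in 0..N:             B[i] = g(i)
//     B[j] = g(j)
//
// the bodies of the later loops are moved into the first loop after
// substituting their loop variables with the first loop's variable. No
// legality checks are performed here; fuseLoops below is the safe variant.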
bool LoopNest::unsafeFuseLoops(
    const std::vector<ForPtr>& loops,
    ForPtr* fused) {
  if (loops.empty()) {
    return false;
  }
  if (loops.size() == 1) {
    *fused = loops.front();
    return true;
  }

  // Check if all the loops have the same parent.
  auto root = loops.front()->get_parent();
  for (const auto& l : loops) {
    auto par = l->get_parent();
    if (par == nullptr) {
      return false;
    }
    if (par != root) {
      return false;
    }
  }
  auto root_block = to<Block>(root);
  if (root_block == nullptr) {
    return false;
  }

  // Currently, we only handle cases where there are no statements between
  // the given loops in their parent's body. We could relax this constraint
  // by allowing statements that do not affect the loops being fused, after
  // performing some dependency analysis. TODO.
  auto it = root_block->begin();
  for (; it != root_block->end(); ++it) {
    if (*it == loops.front()) {
      break;
    }
  }
  TORCH_INTERNAL_ASSERT(
      it != root_block->end(),
      buildErrorMessage(
          "Could not find the given loop in the root stmt in unsafeFuseLoops in the fuser."));
  for (const auto& l : loops) {
    if (*it != l) {
      return false;
    }
    ++it;
  }

  const auto& first_loop = loops.front();
  // Fuse the loops by taking all the statements from the second loop
  // onwards and moving them into the first loop's body.
  // This way the final fused loop will be the same as the first loop.
  for (size_t i = 1; i < loops.size(); ++i) {
    auto body = to<Block>(SubstituteInClone(
        loops[i]->body(), {{loops[i]->var(), first_loop->var()}}));
    first_loop->body()->splice(first_loop->body()->end(), body);
    root_block->remove_stmt(loops[i]);
  }

  *fused = loops.front();
  return true;
}
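
// Example (illustrative): fuseLoops checks that all loops share bounds and
// that fusion would not create a loop-carried dependence before fusing:
//
//   for i in 0..100:               for i in 0..100:
//     A[i] = f(i)          =>        A[i] = f(i)
//   for j in 0..100:                 B[i] = A[i] * 2
//     B[j] = A[j] * 2
//
// whereas fusing `B[j] = A[j+1]` with the producer of A would be rejected,
// since iteration i of the fused loop would read A[i+1] before it is
// written.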
bool LoopNest::fuseLoops(const std::vector<ForPtr>& loops, ForPtr* fused) {
  if (loops.empty()) {
    return false;
  }
  if (loops.size() == 1) {
    *fused = loops.front();
    return true;
  }

  // Check if bounds are the same for all the loops.
  const auto& first_loop = loops.front();
  auto first_loop_start = IRSimplifier::simplify(first_loop->start());
  auto first_loop_stop = IRSimplifier::simplify(first_loop->stop());
  for (size_t i = 1; i < loops.size(); ++i) {
    const auto& curr_loop = loops[i];
    auto curr_loop_start = IRSimplifier::simplify(curr_loop->start());
    auto curr_loop_stop = IRSimplifier::simplify(curr_loop->stop());
    if (!areEqual(curr_loop_start, first_loop_start)) {
      return false;
    }
    if (!areEqual(curr_loop_stop, first_loop_stop)) {
      return false;
    }
  }

  // We need to check if fusing the loops results in a loop-carried dependence.
  // This check can be done only after the loops are fused into one. But if the
  // check is violated, we need to return the given loops in the original form.
  // So, we create a clone of all the loops, fuse them and check for this.
  std::vector<ForPtr> loops_copy;
  loops_copy.reserve(loops.size());
  BlockPtr parent = alloc<Block>(std::vector<StmtPtr>({}));
  for (auto& l : loops) {
    auto l_copy = Stmt::clone(l);
    loops_copy.push_back(to<For>(l_copy));
    parent->append_stmt(l_copy);
  }
  ForPtr fused_copy;
  bool ret = unsafeFuseLoops(loops_copy, &fused_copy);
  if (!ret || hasLoopCarriedDependence(fused_copy)) {
    return false;
  }

  // Now that all conditions are satisfied, we fuse the given loops.
  return unsafeFuseLoops(loops, fused);
}

ForPtr LoopNest::findOuterFor(ForPtr a, ForPtr b) {
  StmtPtr s = b; // guess that b is the later one.
  while (s != nullptr) {
    if (s == a) {
      // yes, b is after a.
      return a;
    }
    s = s->get_parent();
  }

  // check that the two are in the same loop nest.
  s = a;
  while (s != nullptr) {
    if (s == b) {
      // a is after b.
      return b;
    }
    s = s->get_parent();
  }

  // a and b have no relationship.
  return nullptr;
}

void LoopNest::reorderAxis(const ForPtr& a, const ForPtr& b) {
  if (a == b) {
    // nothing to do.
    return;
  }
  // find inner and outer.
  ForPtr outer = findOuterFor(a, b);
  if (outer == nullptr) {
    throw std::runtime_error("Reordered a loop not in LoopNest");
  }

  ForPtr inner = a == outer ? b : a;
  std::deque<ForPtr> internal_axes;

  // Find relevant axes, store reversed.
  StmtPtr s = inner;
  while (s != outer) {
    if (const ForPtr& f = to<For>(s)) {
      internal_axes.push_back(f);
    }

    s = s->get_parent();
  }

  internal_axes.push_back(outer);

  BlockPtr root = to<Block>(outer->get_parent());
  CHECK(root);

  // Do a shallow copy of the inner blocks.
  BlockPtr body = alloc<Block>(std::vector<StmtPtr>({}));
  body->splice(body->end(), inner->body());

  const ForPtr& before{outer};
  ForPtr after{nullptr};
  ForPtr last = internal_axes.front();
  StmtPtr newInner = body;

  s = inner;
  while (s != outer) {
    if (auto cond = to<Cond>(s->get_parent())) {
      if (s == cond->true_stmt()) {
        newInner = cond->cloneWithNewBody(newInner);
      } else {
        // s is the false branch of Cond
        newInner = cond->cloneWithNewBodies(
            alloc<Block>(std::vector<StmtPtr>({})), newInner);
      }
    }
    s = s->get_parent();
  }

  // This is the major complexity in loop reordering: handling statements not
  // in the straight line of the reorder. To handle this we partition the tree
  // into the section before the critical path and after the critical path.
  //
  // An example of this pattern is:
  //   for i in ..
  //     Statement A
  //     for j in ..
  //       Statement B
  //     Statement C
  //
  // When reordering loop i and j we need to ensure that Statement A and C are
  // still both executed with the loop extents of i, and that the three
  // statements are not reordered (as much as possible).
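  //
  // For the example above, reordering i and j yields (a sketch; bounds
  // elided):
  //   for i in ..
  //     Statement A
  //   for j in ..
  //     for i in ..
  //       Statement B
  //   for i in ..
  //     Statement C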
  for (const auto& loop : internal_axes) {
    // If the inner loop had a component after the loop we must wrap it in a
    // For loop matching this level of the tree.
    if (after != nullptr) {
      after = loop->cloneWithNewBody(after);
    }

    bool pastMidpoint = false;
    bool hadBeforeStmts = false;
    for (auto I = loop->body()->begin(), E = loop->body()->end(); I != E;) {
      // Be careful not to invalidate the iterator.
      StmtPtr s = *(I++);
      if (s == last) {
        // This is the midpoint.
        loop->body()->remove_stmt(s);
        if (!hadBeforeStmts) {
          // If there were no existing statements this loop does not need to be
          // preserved and we can roll it into the above loop.
          last = loop;
        }
        pastMidpoint = true;
      } else if (pastMidpoint) {
        // Statements after the reordered path must be moved to a new tree
        // after the reordered statement has occurred to preserve ordering.
        loop->body()->remove_stmt(s);
        if (after == nullptr) {
          after = loop->cloneWithNewBody(s);
        } else {
          after->body()->append_stmt(s);
        }
      } else {
        // We can leave any statements before the reordered loop alone, so long
        // as we preserve the loop structure.
        hadBeforeStmts = true;
      }
    }
  }

  // now we can actually reorder the chosen axes.
  std::swap(internal_axes.front(), internal_axes.back());

  // Create the reordered internals:
  for (const auto& loop : internal_axes) {
    newInner = loop->cloneWithNewBody(newInner);
  }

  // Append the new statements to the root of the tree.
  if (before->body()->nstmts() == 0) {
    // If the top level is now empty, eliminate it.
    root->replace_stmt(before, newInner);
  } else {
    root->insert_stmt_after(newInner, before);
  }

  if (after) {
    root->insert_stmt_after(after, newInner);
  }
}

static bool isTrivialPermutation(const std::vector<size_t>& permutation) {
  for (size_t i = 0; i < permutation.size(); ++i) {
    if (permutation[i] != i) {
      return false;
    }
  }
  return true;
}

static bool isValidPermutation(std::vector<size_t> permutation) {
  std::sort(permutation.begin(), permutation.end());
  return isTrivialPermutation(permutation);
}
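
// A sketch (illustrative) of reorder with permutation {2, 0, 1}: since
// result[i] = loops[permutation[i]], the loop at index 2 becomes outermost:
//
//   for i:               for k:
//     for j:       =>      for i:
//       for k:               for j:
//         body                 body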
std::vector<ForPtr> LoopNest::reorder(
    const std::vector<ForPtr>& loops,
    const std::vector<size_t>& permutation) {
  if (loops.size() != permutation.size()) {
    throw malformed_input("invalid permutation size");
  }
  if (isTrivialPermutation(permutation)) {
    return loops;
  }
  if (!isValidPermutation(permutation)) {
    throw malformed_input("invalid permutation for reorder");
  }
  if (loops.size() < 2) {
    return loops;
  }
  if (!areLoopsPerfectlyNested(loops)) {
    throw malformed_input("reorder is only allowed on perfectly nested loops");
  }

  auto parent = to<Block>(loops.front()->get_parent());
  if (parent == nullptr) {
    throw malformed_input("parent of the loops must be a Block");
  }

  // Reorder the loops according to the permutation.
  std::vector<ForPtr> result(loops.size());
  for (size_t i = 0; i < loops.size(); ++i) {
    result[i] = loops[permutation[i]];
  }

  // Remove the bodies from all the loops.
  auto innermost_body = loops.back()->removeBody();
  // We use an empty block statement to replace the outermost loop
  // so that we know the position where the outermost reordered loop
  // is to be inserted.
  auto empty_block = alloc<Block>(std::vector<StmtPtr>({}));
  parent->replace_stmt(loops.front(), empty_block);
  for (size_t i = 1; i < loops.size(); ++i) {
    auto block = to<Block>(loops[i]->get_parent());
    TORCH_INTERNAL_ASSERT(
        block,
        buildErrorMessage(
            "Expected parent stmt to be a non-null Block in the reorder transformation in the fuser."));
    block->remove_stmt(loops[i]);
  }

  // Set the new bodies after reorder for all the loops.
  for (size_t i = 0; i < result.size() - 1; ++i) {
    result[i]->set_body(result[i + 1]);
  }
  result.back()->set_body(innermost_body);
  parent->replace_stmt(empty_block, result.front());
  return result;
}
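
// Example (illustrative): given the nest
//   for i:          (root)
//     for j:        (stmt 0 of root's body)
//       S1          (stmt 0 of j's body)
//       for k:      (stmt 1 of j's body)
//         ...
// getLoopAt(root, {0, 1}) returns the k loop.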
ForPtr LoopNest::getLoopAt(ForPtr root, const std::vector<int>& indices) const {
  if (indices.empty()) {
    return root;
  }
  if (root == nullptr) {
    throw malformed_input("root loop is null");
  }

  ForPtr curr = std::move(root);
  for (auto i : indices) {
    if (i < 0 || curr->body()->nstmts() <= static_cast<size_t>(i)) {
      return nullptr;
    }
    std::list<StmtPtr>::iterator stmtp = curr->body()->begin();
    std::advance(stmtp, i);
    curr = to<For>(*stmtp);
    if (curr == nullptr) {
      return nullptr;
    }
  }

  return curr;
}
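
// A sketch (illustrative, with extents that divide evenly so no tail loops
// appear) of tile(x, y, 4, 2):
//
//   for x in 0..64:            for xo in 0..16:
//     for y in 0..32:    =>      for yo in 0..16:
//       body(x, y)                 for xi in 0..4:
//                                    for yi in 0..2:
//                                      body(xo*4+xi, yo*2+yi)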
ForPtr LoopNest::tile(
    const ForPtr& x,
    const ForPtr& y,
    int x_factor,
    int y_factor) {
  auto parent = to<Block>(x->get_parent());
  if (parent == nullptr) {
    throw malformed_input("parent of the loops must be a Block");
  }
  if (!areLoopsPerfectlyNested({x, y})) {
    throw malformed_input("two loops must be perfectly nested");
  }

  // Split x, y axes by x_factor and y_factor
  ForPtr yi, ytail;
  splitWithTail(y, y_factor, &yi, &ytail);
  ForPtr xi, xtail;
  splitWithTail(x, x_factor, &xi, &xtail);

  // Distribute xi over yo and ytail so we can manipulate the loop order of
  // {xo, xi, yo, yi}
  auto loops = distributeLoop(xi);

  // For {xi, yo, yi}, reorder the axes to be yo, xi, yi
  xi = loops.front();
  ForPtr yo = to<For>(xi->body()->stmts().front());
  CHECK(yo);
  reorder({xi, yo}, {1, 0});

  // For {xi, ytail}, reorder the axes to be ytail, xi
  if (loops.size() == 2) {
    xi = loops.back();
    ytail = to<For>(xi->body()->stmts().front());
    CHECK(ytail);
    reorder({xi, ytail}, {1, 0});
  }

  return xtail;
}

bool LoopNest::areLoopsPerfectlyNested(const std::vector<ForPtr>& loops) {
  if (loops.size() < 2) {
    return true;
  }
  for (size_t i = 0; i < loops.size() - 1; ++i) {
    auto loop_body = loops[i]->body();
    if (loop_body->nstmts() != 1 || loop_body->front() != loops[i + 1]) {
      return false;
    }
  }
  return true;
}
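
// Example (illustrative): fully unrolling a loop with constant bounds:
//
//   for i in 0..3:          A[0] = B[0];
//     A[i] = B[i]     =>    A[1] = B[1];
//                           A[2] = B[2];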
void LoopNest::fullUnroll(const ForPtr& f, StmtPtr* unrolled) {
  if (!f) {
    throw malformed_input("unroll attempted on null loop");
  }
  // Only look up the parent after 'f' is known to be non-null.
  BlockPtr p = to<Block>(f->get_parent());
  if (!p) {
    throw malformed_input("unroll attempted on loop with no parent");
  }

  auto start_expr = IRSimplifier::simplify(f->start());
  auto stop_expr = IRSimplifier::simplify(f->stop());
  if (!start_expr->isConstant()) {
    throw std::runtime_error("Can't unroll due to non-constant loop start!");
  }
  if (!stop_expr->isConstant()) {
    throw std::runtime_error("Can't unroll due to non-constant loop stop!");
  }

  std::vector<StmtPtr> unrolled_stmts;
  int start_val = immediateAs<int>(start_expr);
  int stop_val = immediateAs<int>(stop_expr);
  for (int current = start_val; current < stop_val; ++current) {
    for (const auto& stmt : f->body()->stmts()) {
      unrolled_stmts.push_back(SubstituteInClone(
          stmt, {{f->var(), getImmediateByType(f->var()->dtype(), current)}}));
    }
  }
  *unrolled = alloc<Block>(unrolled_stmts);
  *unrolled = IRSimplifier::simplify(*unrolled);

  p->replace_stmt(f, *unrolled);
}

void LoopNest::fullUnroll(const ForPtr& f) {
  StmtPtr unrolled;
  fullUnroll(f, &unrolled);
}

void LoopNest::unroll(const ForPtr& f, int factor, ForPtr* tail) {
  if (factor < 2) {
    return;
  }
  ForPtr inner;
  splitWithTail(f, factor, &inner, tail);
  fullUnroll(inner);
}

void LoopNest::unroll(const ForPtr& f, int factor) {
  ForPtr tail;
  unroll(f, factor, &tail);
}

bool LoopNest::isNormalized(const ForPtr& f) {
  if (f->start()->isConstant()) {
    return immediateAs<int>(f->start()) == 0;
  }
  return false;
}
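
// Example (illustrative): normalizing shifts the start to zero and offsets
// the body accordingly:
//
//   for i in 5..100:          for i in 0..95:
//     A[i] = B[i]      =>       A[i+5] = B[i+5]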
bool LoopNest::normalize(const ForPtr& f) {
  if (!f) {
    throw malformed_input("normalize attempted on null loop");
  }

  if (isNormalized(f)) {
    // No need to normalize anymore here.
    return false;
  }

  auto for_body_normalized = Substitute(
      f->body(),
      {{f->var(), (VarHandle(f->var()) + ExprHandle(f->start())).node()}});
  f->set_body(IRSimplifier::simplify(for_body_normalized));
  f->set_stop(IRSimplifier::simplify(alloc<Sub>(f->stop(), f->start())));
  f->set_start(immLike(f->stop(), 0));
  return true;
}

// This function expects that there are 'num' loops perfectly nested within
// and including 'f'.
std::vector<ForPtr> LoopNest::getLoopStmtsInLoopNest(
    const ForPtr& f,
    size_t num) {
  std::vector<ForPtr> loops(num);
  ForPtr curr_for = f;
  loops[0] = curr_for;
  for (size_t i = 1; i < num; ++i) {
    TORCH_INTERNAL_ASSERT(
        curr_for->body()->nstmts() == 1,
        buildErrorMessage("Expected a single stmt in the loop body."));
    curr_for = to<For>(curr_for->body()->front());
    TORCH_INTERNAL_ASSERT(
        curr_for,
        buildErrorMessage("Expected the only child stmt to be a For loop."));
    loops[i] = curr_for;
  }
  return loops;
}
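
// Example (illustrative): flattening a 2-D nest into a single loop, where
// the original indices are recovered with div/mod:
//
//   for i in 0..10:            for i_flat in 0..100:
//     for j in 0..10:    =>      A[i_flat / 10, i_flat % 10] = ...
//       A[i, j] = ...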
bool LoopNest::flatten(const std::vector<ForPtr>& loops, ForPtr* flattened) {
  if (loops.empty()) {
    throw malformed_input("flatten attempted on empty set of loops");
  }
  BlockPtr p = to<Block>(loops[0]->get_parent());
  if (!p) {
    throw malformed_input("flatten attempted on loops with no parent");
  }

  if (loops.size() == 1) {
    // This loop nest is already flattened.
    *flattened = loops[0];
    return false;
  }

  // Check if all the loops correspond to a perfect loopnest:
  //  * every loop except the inner-most should have only one stmt, the For.
  // Do not flatten, otherwise.
  // This check also ensures we do not flatten reduction loops.
  for (size_t i = 0; i < loops.size() - 1; ++i) {
    if ((loops[i]->body()->nstmts() != 1) ||
        (loops[i]->body()->front() != loops[i + 1])) {
      return false;
    }
  }

  // Normalize the loops before flattening.
  // We need to normalize them from inner-most to outer because once the outer
  // loop is normalized, the given pointers to inner loops point to old code.
  // For the same reason, we can't store the normalized inner loops until after
  // the outer-most loop is normalized.
  for (size_t i = 0; i < loops.size(); ++i) {
    size_t idx = loops.size() - i - 1;
    LoopNest::normalize(loops[idx]);
  }

  // Collect all the normalized loops.
  auto normalized_loops = getLoopStmtsInLoopNest(loops.front(), loops.size());

  auto flat_var = alloc<Var>(
      normalized_loops[0]->var()->name_hint() + "_flat",
      normalized_loops[0]->var()->dtype());
  VarMapping var_mapping;
  ExprPtr stop = immLike(flat_var, 1);
  for (size_t i = 0; i < normalized_loops.size(); ++i) {
    size_t idx = normalized_loops.size() - i - 1;
    auto curr_loop = normalized_loops[idx];
    ExprPtr div = alloc<Div>(flat_var, stop);
    ExprPtr sub_expr = idx == 0 ? div : alloc<Mod>(div, curr_loop->stop());
    var_mapping.emplace_back(curr_loop->var(), sub_expr);
    stop = alloc<Mul>(curr_loop->stop(), stop);
  }
  auto flattened_body =
      Substitute(normalized_loops.back()->removeBody(), var_mapping);

  normalized_loops.front()->set_var(flat_var);
  normalized_loops.front()->set_start(immLike(stop, 0));
  normalized_loops.front()->set_stop(stop);
  normalized_loops.front()->set_body(flattened_body);
  *flattened = normalized_loops.front();
  return true;
}

bool LoopNest::flatten(const std::vector<ForPtr>& loops) {
  ForPtr flattened;
  return flatten(loops, &flattened);
}

void LoopNest::compressBuffer(const BufPtr& buf, const StmtPtr& stmt) {
  // Loop iterations in NNC IR do not follow sequential semantics by default.
  // In other words, the iterations of the loops could be executed in any
  // random order without affecting correctness. This constraint in turn
  // implies that there can’t be any *inter-iteration* dependences
  // (or *loop-carried* dependences) in NNC loops. So, any NNC IR with such
  // dependences is considered invalid.
  //
  // Given the constraint above, for any pair of accesses to a buffer (where
  // at least one of the accesses is a write), the accesses must be
  // loop-independent on the innermost loop containing the accesses as well as
  // all the loops above it. So, any dimension that uses only those loop
  // variables to access the given buffer could be optimized away.
  //
  // Algorithm:
  //   * Find all the accesses to the given buf. (A)
  //   * Find the parent common to all accesses in A. (P)
  //   * Collect all the loops above P. (L)
  //   * Collect all the loop variables corresponding to L. (LV)
  //   * For every access a in A:
  //      * For the index I in every dimension of a:
  //         * If the variables in I are all in LV, mark this dimension
  //           for deletion.
  //   * For every dimension that is marked for deletion in ALL accesses in A:
  //      * Update the buffer to set the size of that dimension to 1.
  //      * Update all accesses in A to set the index in that dimension to 0.
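  //
  // Example (illustrative): A below is produced and consumed within the same
  // iteration of i, so its first dimension is indexed only by an enclosing
  // loop variable and can be compressed:
  //
  //   for i in 0..100:                   for i in 0..100:
  //     for j in 0..200:                   for j in 0..200:
  //       A[i,j] = sin(i*j)        =>        A[0,j] = sin(i*j)
  //     for j in 0..199:                   for j in 0..199:
  //       B[i,j] = A[i,j]+A[i,j+1]           B[i,j] = A[0,j]+A[0,j+1]
  //
  // after which A's dims become {1, 200}.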

  auto writes = WritesToBuf::find(stmt, buf);
  auto reads = StmtsReadingBuf::find(stmt, buf);

  // Find the parent common to all the buffer accesses.
  BlockPtr parent = to<Block>(writes.front()->get_parent());
  TORCH_INTERNAL_ASSERT(
      parent,
      buildErrorMessage(
          "Expected parent stmt to be a non-null block in compressBuffer in the fuser."));
  for (const auto& w : writes) {
    parent = Block::getSharedParent(parent, w);
  }
  for (const auto& r : reads) {
    parent = Block::getSharedParent(parent, r);
  }

  // Collect all the loops that are above the common parent.
  auto loops = LoopNest::getEnclosingLoopNest(parent);
  std::unordered_set<VarPtr> loop_vars;
  for (const auto& l : loops) {
    loop_vars.insert(l->var());
  }

  // TODO: Need to handle other Stmts / Exprs that read / write buffers.
  auto stores = NodeFinder<Store>::find(stmt);
  auto loads = NodeFinder<Load>::find(stmt);

  // Vector to indicate which dimensions could be compressed away.
  std::vector<bool> dims(buf->dims().size(), true);
  auto check_indices = [&](const std::vector<ExprPtr>& indices) {
    TORCH_INTERNAL_ASSERT(
        indices.size() == dims.size(),
        buildErrorMessage(
            "Expected ranks to match in compressBuffer in the fuser."));
    for (size_t i = 0; i < indices.size(); ++i) {
      auto index_vars = NodeFinder<Var>::find(indices[i]);
      for (const auto& iv : index_vars) {
        if (loop_vars.count(iv) == 0) {
          // A variable in this index is not in loop_vars.
          // This implies that this dimension cannot be optimized away.
          dims[i] = false;
          break;
        }
      }
    }
  };
  for (const auto& s : stores) {
    if (s->buf() == buf) {
      check_indices(s->indices());
    }
  }
  for (const auto& l : loads) {
    if (l->buf() == buf) {
      check_indices(l->indices());
    }
  }
  bool any_dim_to_compress = false;
  for (auto d : dims) {
    any_dim_to_compress |= d;
  }
  if (!any_dim_to_compress) {
    return;
  }

  // Compress buffer by removing the marked dims.
  std::vector<ExprPtr> new_dims(buf->dims());
  for (size_t i = 0; i < dims.size(); ++i) {
    if (dims[i]) {
      new_dims[i] = immLike(buf->dims()[i], 1);
    }
  }
  buf->set_dims(new_dims);

  // Modify all accesses to reflect the removed dims.
  auto get_new_indices = [&](const std::vector<ExprPtr>& indices) {
    TORCH_INTERNAL_ASSERT(
        indices.size() == dims.size(),
        buildErrorMessage(
            "Expected ranks to match in compressBuffer in the fuser."));
    std::vector<ExprPtr> new_indices(indices);
    for (size_t i = 0; i < dims.size(); ++i) {
      if (dims[i]) {
        new_indices[i] = immLike(indices[i], 0);
      }
    }
    return new_indices;
  };
  for (const auto& s : stores) {
    if (s->buf() == buf) {
      s->set_indices(get_new_indices(s->indices()));
    }
  }
  for (const auto& l : loads) {
    if (l->buf() == buf) {
      l->set_indices(get_new_indices(l->indices()));
    }
  }
}

void LoopNest::compressAllBuffers(const StmtPtr& stmt) {
  for (const auto& buf : BufFinder::find(stmt)) {
    compressBuffer(buf, stmt);
  }
}

std::vector<ForPtr> LoopNest::getLoopStmtsFor(const Tensor& t) const {
  StmtPtr cur_stmt = getLoopBodyFor(t);
  return getLoopStmtsFor(cur_stmt);
}

std::vector<ForPtr> LoopNest::getLoopStmtsFor(const BufPtr& buf) const {
  StmtPtr cur_stmt = getLoopBodyFor(buf);
  return getLoopStmtsFor(cur_stmt);
}

std::vector<ForPtr> LoopNest::getLoopStmtsFor(StmtPtr s) const {
  std::vector<ForPtr> result;

  while (s) {
    if (auto loop = to<For>(s)) {
      result.push_back(loop);
    }
    s = s->get_parent();
  }
  std::reverse(result.begin(), result.end());
  return result;
}

StmtPtr LoopNest::getLoopBodyFor(const Tensor& t) const {
  return getLoopBodyFor(t.buf());
}

StmtPtr LoopNest::getLoopBodyFor(BufPtr buf) const {
  auto writes = WritesToBuf::find(root_stmt_, std::move(buf));

  // Special case for reduction Tensors: ignore the initializer store when the
  // reduction is the only other write.
  if (writes.size() == 2) {
    if (StorePtr s = to<Store>(writes.back())) {
      if (ReduceOpPtr r = to<ReduceOp>(s->value())) {
        return (StmtPtr)s;
      }
    }
  }

  StmtPtr res = nullptr;
  for (const auto& s : writes) {
    if (!res) {
      res = s;
      continue;
    }

    res = Block::getSharedParent(res, s);
  }

  return (StmtPtr)res;
}

ForPtr LoopNest::getParentLoop(const StmtPtr& st) {
  if (st == nullptr) {
    return nullptr;
  }
  auto par = st->get_parent();
  if (auto f = to<For>(par)) {
    return f;
  }
  return getParentLoop(par);
}

std::vector<ForPtr> LoopNest::getEnclosingLoopNest(const StmtPtr& st) {
  std::vector<ForPtr> loops;
  auto f = getParentLoop(st);
  while (f) {
    loops.push_back(f);
    f = getParentLoop(f);
  }
  std::reverse(loops.begin(), loops.end());
  return loops;
}

std::vector<StmtPtr> LoopNest::getAllWritesToBuf(BufPtr buf) const {
  return WritesToBuf::find(root_stmt_, std::move(buf));
}

std::vector<ForPtr> LoopNest::getAllInnermostLoopsWritingToBuf(
    BufPtr buf) const {
  auto writes = getAllWritesToBuf(std::move(buf));
  std::vector<ForPtr> innermost_loops;
  innermost_loops.reserve(writes.size());
  for (const auto& w : writes) {
    innermost_loops.push_back(LoopNest::getParentLoop(w));
  }
  return innermost_loops;
}

std::vector<std::vector<ForPtr>> LoopNest::getAllLoopNestsWritingToBuf(
    BufPtr buf) const {
  auto writes = getAllWritesToBuf(std::move(buf));
  std::vector<std::vector<ForPtr>> loopnests;
  loopnests.reserve(writes.size());
  for (const auto& w : writes) {
    loopnests.emplace_back(LoopNest::getEnclosingLoopNest(w));
  }
  return loopnests;
}

StmtPtr LoopNest::simplify() {
  root_stmt_ = IRSimplifier::simplify(root_stmt_);
  return root_stmt_;
}

StmtPtr FlattenIndexes(const StmtPtr& s) {
  IndexFlattener idx_flattener;
  return idx_flattener.flatten(s);
}

// Auxiliary class for the rewriting we're doing in `compute_at`. See
// LoopNest::computeAt for more details.
class LoopComputeAtRewriter : public IRMutator {
 public:
  LoopComputeAtRewriter(
      BufPtr buf,
      BufPtr new_buf,
      std::vector<ExprPtr> offsets)
      : buf_(std::move(buf)),
        new_buf_(std::move(new_buf)),
        offsets_(std::move(offsets)) {}

 private:
  BufPtr buf_;
  BufPtr new_buf_;
  std::vector<ExprPtr> offsets_;

  ExprPtr mutate(const LoadPtr& v) override {
    if (v->buf() != buf_) {
      return v;
    }
    std::vector<ExprPtr> new_indices(v->indices().size());
    for (const auto i : c10::irange(v->indices().size())) {
      new_indices[i] =
          IRSimplifier::simplify(alloc<Sub>(v->indices()[i], offsets_[i]));
    }
    return alloc<Load>(v->dtype(), new_buf_, new_indices);
  }
};

static StorePtr getStoreStmtOfProducer(const StmtPtr& s) {
  if (StorePtr st = to<Store>(s)) {
    return st;
  }
  if (BlockPtr b = to<Block>(s)) {
    for (const StmtPtr& ss : *b) {
      if (StorePtr st = to<Store>(ss)) {
        return st;
      }
    }
  }
  return nullptr;
}

static std::vector<VarPtr> getOuterLoopIndexes(StmtPtr s) {
  std::vector<VarPtr> res;
  StmtPtr cur = std::move(s);
  while (cur) {
    if (auto l = to<For>(cur)) {
      res.push_back(l->var());
    }
    cur = cur->get_parent();
  }
  return res;
}

class CacheReplacer : public IRMutator {
 public:
  CacheReplacer(BufPtr buffer, BufPtr cache, std::vector<ExprPtr>& offsets)
      : buf_(std::move(buffer)), cache_(std::move(cache)), offsets_(offsets) {}

 private:
  ExprPtr mutate(const LoadPtr& v) override {
    BufPtr buf = v->buf();
    if (buf != buf_) {
      return IRMutator::mutate(v);
    }

    // Map indices to call-parameters.
    std::vector<ExprPtr> newIndices;
    TORCH_INTERNAL_ASSERT(
        offsets_.size() == v->indices().size(),
        buildErrorMessage(
            "Expected ranks to match in CacheReplacer in the fuser."));
    for (size_t i = 0; i < v->indices().size(); ++i) {
      ExprPtr index = v->indices()[i]->accept_mutator(this);
      ExprPtr offset = offsets_[i];
      ExprPtr sub = IRSimplifier::simplify(alloc<Sub>(index, offset));
      newIndices.push_back(sub);
    }
    v->set_buf(cache_);
    v->set_indices(newIndices);
    return v;
  }

  StmtPtr mutate(const StorePtr& v) override {
    BufPtr buf = v->buf();
    if (buf != buf_) {
      return IRMutator::mutate(v);
    }

    ExprPtr newValue = v->value()->accept_mutator(this);

    // Map indices to call-parameters.
    std::vector<ExprPtr> newIndices;
    TORCH_INTERNAL_ASSERT(
        offsets_.size() == v->indices().size(),
        buildErrorMessage(
            "Expected ranks to match in CacheReplacer in the fuser."));
    for (size_t i = 0; i < v->indices().size(); ++i) {
      ExprPtr index = v->indices()[i]->accept_mutator(this);
      ExprPtr offset = offsets_[i];
      ExprPtr sub = IRSimplifier::simplify(alloc<Sub>(index, offset));
      newIndices.push_back(sub);
    }
    v->set_buf(cache_);
    v->set_indices(newIndices);
    v->set_value(newValue);
    return v;
  }

  BufPtr buf_;
  BufPtr cache_;
  std::vector<ExprPtr>& offsets_;
};
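
// A rough sketch (illustrative; the generated loop vars are actually named
// i, j, ... and allocation/tail handling is elided) of
// cacheAccesses(A, "A_local", j_loop) for a read-only use of A:
//
//   for i in 0..64:               for i in 0..64:
//     for j in 0..64:               for u in 0..1:        (fill the cache)
//       B[i,j] = A[i,j]    =>         for v in 0..64:
//                                       A_local[u,v] = A[u+i, v]
//                                   for j in 0..64:
//                                     B[i,j] = A_local[0,j]
//
// Writes through the cache are synced back to the producer after the
// consumer, and reductions on a reduce axis get an init-to-0 loop plus a
// final ReduceOp back into the producer.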
LoopNest::AccessResult LoopNest::cacheAccesses(
    const BufPtr& producer,
    const std::string& name,
    const StmtPtr& consumer) {
  ReduceOpPtr reduceOp{nullptr};
  auto stores = NodeFinder<Store>::find(consumer);
  for (const auto& store : stores) {
    if (auto ro = to<ReduceOp>(store->value())) {
      if (store->buf() != producer) {
        continue;
      }

      if (reduceOp) {
        throw std::runtime_error(
            "can only cache accesses used by at most a single reduceOp");
      }

      reduceOp = ro;
    }
  }

  // Check bounds but don't care about AccessKind.
  auto consumer_bounds_info = inferBounds(consumer, false);
  auto bounds_it = consumer_bounds_info.find(producer);
  if (bounds_it == consumer_bounds_info.end()) {
    throw std::runtime_error("consumer does not use the Tensor produced");
  }

  TORCH_INTERNAL_ASSERT(
      bounds_it->second.size() == 1,
      buildErrorMessage(
          "Unexpected number of bound info entries in cacheAccesses in the fuser."));
  TensorAccessBoundsInfo& info = bounds_it->second[0];
  bool hasReads = info.kind == kLoad || info.kind == kMutate;
  bool hasWrites = info.kind == kStore || info.kind == kMutate;

  std::vector<std::string> var_names = {"i", "j", "k", "l", "m", "n", "o", "p"};
  std::vector<ExprPtr> tmp_dims;
  std::vector<VarPtr> new_loop_vars;
  std::vector<ExprPtr> new_loop_vars_expr;

  // Determine the size of the cache, and create a loop var for each dimension.
  for (size_t i = 0; i < info.start.size(); ++i) {
    ExprPtr dim = IRSimplifier::simplify(alloc<Add>(
        alloc<Sub>(info.stop[i], info.start[i]), immLike(info.stop[i], 1)));

    tmp_dims.push_back(dim);

    new_loop_vars.push_back(
        alloc<Var>(var_names[i % var_names.size()], info.stop[i]->dtype()));
    new_loop_vars_expr.push_back(new_loop_vars[i]);
  }

  // Create the cache buffer.
  BufPtr tmp_buf =
      alloc<Buf>(alloc<Var>(name, kHandle), tmp_dims, producer->dtype());

  // Determine the offsets for calls into the cache based on the loop start of
  // each axis.
  std::vector<ExprPtr> tmp_params;
  for (size_t i = 0; i < new_loop_vars.size(); ++i) {
    tmp_params.push_back(alloc<Add>(new_loop_vars[i], info.start[i]));
  }

  // Replace accesses to the producer in the consumer with the cache.
  CacheReplacer replacer(producer, tmp_buf, info.start);
  consumer->accept_mutator(&replacer);

  // The consumer was mutated in place; find the blocks around it so the cache
  // fill and sync loops can be inserted.
  BlockPtr consumer_block = to<Block>(consumer);
  BlockPtr parent_block = to<Block>(consumer->get_parent());
  // if the consumer is a block, we should mutate it in place.
  bool is_block = consumer_block != nullptr;

  // If there's a reduction and we are operating on the reduce axis, we need to
  // initialize the cache with 0s. Also, we can't just write the result
  // straight back to the original buffer, since after parallelism the writes
  // will race. Instead we need to create a new ReduceOp.
  bool on_reduce_axis = false;
  if (reduceOp) {
    std::set<VarPtr> reduce_args(
        reduceOp->reduce_args().begin(), reduceOp->reduce_args().end());
    std::set<VarPtr> enclosing_vars;
    for (const auto& enclosing_for_stmt : NodeFinder<For>::find(consumer)) {
      enclosing_vars.insert(enclosing_for_stmt->var());
    }
    for (const auto& reduce_arg : reduce_args) {
      if (enclosing_vars.find(reduce_arg) == enclosing_vars.end()) {
        on_reduce_axis = true;
      }
    }
  }
  if (reduceOp && on_reduce_axis) {
    // reduceOp means we had both loads and stores.

    // Init cache to 0.
    StmtPtr tmp_init = alloc<Store>(
        tmp_buf, new_loop_vars_expr, getImmediateByType(tmp_buf->dtype(), 0));

    for (int64_t i = new_loop_vars.size() - 1; i >= 0; --i) {
      tmp_init = alloc<For>(
          new_loop_vars[i], immLike(tmp_dims[i], 0), tmp_dims[i], tmp_init);
    }

    if (is_block) {
      consumer_block->prepend_stmt(tmp_init);
    } else {
      parent_block->insert_stmt_before(tmp_init, consumer);
    }

    // Reduce back to the original buffer:
    StmtPtr tmp_store = alloc<Store>(
        producer,
        tmp_params,
        reduceOp->reducer()(
            producer,
            alloc<Load>(tmp_buf, new_loop_vars_expr),
            tmp_params,
            {}));

    for (int64_t i = new_loop_vars.size() - 1; i >= 0; --i) {
      tmp_store = alloc<For>(
          new_loop_vars[i], immLike(tmp_dims[i], 0), tmp_dims[i], tmp_store);
    }

    if (is_block) {
      consumer_block->append_stmt(tmp_store);
    } else {
      parent_block->insert_stmt_after(tmp_store, consumer);
    }

    return std::make_pair(tmp_buf, consumer);
  }

  if (hasReads) {
    // Fill the cache with values from the consumer.
    StmtPtr tmp_store = alloc<Store>(
        tmp_buf, new_loop_vars_expr, alloc<Load>(producer, tmp_params));

    for (int64_t i = new_loop_vars.size() - 1; i >= 0; --i) {
      tmp_store = alloc<For>(
          new_loop_vars[i], immLike(tmp_dims[i], 0), tmp_dims[i], tmp_store);
    }

    if (is_block) {
      consumer_block->prepend_stmt(tmp_store);
    } else {
      parent_block->insert_stmt_before(tmp_store, consumer);
    }
  }

  if (hasWrites) {
    // Sync the cache back to the producer buf.
    StmtPtr tmp_store = alloc<Store>(
        producer, tmp_params, alloc<Load>(tmp_buf, new_loop_vars_expr));

    for (int64_t i = static_cast<int64_t>(new_loop_vars.size()) - 1; i >= 0;
         --i) {
      tmp_store = alloc<For>(
          new_loop_vars[i], immLike(tmp_dims[i], 0), tmp_dims[i], tmp_store);
    }

    if (is_block) {
      consumer_block->append_stmt(tmp_store);
    } else {
      parent_block->insert_stmt_after(tmp_store, consumer);
    }
  }

  return std::make_pair(tmp_buf, consumer);
}

/*
 * WHAT COMPUTE_AT DOES
 * ====================
 *
 * Suppose we have two loops:
 *
 * for i in 0..100:
 *   for j in 0..200:
 *     A[i,j] = sin(i*j)
 * for i in 0..100:
 *   for j in 0..199:
 *     B[i,j] = A[i,j] + A[i, j+1]
 *
 * If we compute these loops as is, we would have to allocate two buffers:
 * 100x200 for A and 100x199 for B. To decrease the memory usage one can use
 * the compute_inline primitive, which would result in the following:
 *
 * for i in 0..100:
 *   for j in 0..199:
 *     B[i,j] = sin(i*j) + sin(i*(j+1))
 *
 * We now need only one buffer - 100x199 for B. However, we're now doing some
 * redundant computations: we're calling `sin` twice as much as in the first
 * version.
 *
 * Ultimately, we need to choose at what point we prefer to compute values of
 * A[i,j] - we can do it in the very beginning for the entire buffer A (the
 * first option) or compute it on the fly when we compute B (the second
 * option). There are also options in between those two: we can compute a part
 * of A which is required for a computation of a part of B, e.g. for a single
 * row of B. The code would then look like:
 *
 * for i in 0..100:
 *   for j in 0..200:
 *     A[j] = sin(i*j)
 *   for j in 0..199:
 *     B[i,j] = A[j] + A[j+1]
 *
 * In this case we're only using 1x200 for A, and we're avoiding redundant
 * computations.
 *
 * The purpose of `compute_at` is to achieve exactly this transformation.
 *
 * compute_at requires us to specify What to compute and Where to compute: in
 * our example we would call compute_at(What=`A[i,j] = sin(i*j)`, Where=`for i
 * in 0..100`).
 *
 * More info about compute_at can be found in Halide's tutorials:
 * https://halide-lang.org/tutorials/tutorial_lesson_08_scheduling_2.html
 *
 * HOW COMPUTE_AT WORKS
 * ====================
 *
 * The most important part of compute_at is bounds inference: we need to figure
 * out what part of the used tensors we need to compute when we move the
 * computation to a new scope. In the example above, we need bounds inference
 * to tell us that in order to compute A at each iteration of the outer loop,
 * we need to compute A within indices [i:i+1,0:200].
 *
 * This info allows us to conclude that we need a temp buffer of size 1x200.
 *
 * Once this is known we need to insert statements for allocation and freeing
 * the temporary buffer and copy the original computation to fill the temp
 * buffer with proper values. When we copy the computation we also must rewrite
 * indices used in it: old indices are referring to the old loop and are not
 * valid in the new loop.
 *
 * To follow the logic more easily, let's examine an example. Suppose we start
 * from the following loop nest:
 *   for py in 0..100:
 *     for px in 0..100:
 *       producer[py,px] = py*px
 *   for cy in 0..100:
 *     for cx in 0..100:
 *       consumer[cy,cx] = producer[cy,cx]
 *
 * And then we're running `compute_at(producer, cy)`.
 *
 * What we would like to get is the following loop nest:
 *   for py in 0..100:
 *     for px in 0..100:
 *       producer[py,px] = py*px
 *   for cy in 0..100:
 *     Allocate(temp, {1, 100})
 *     for ty in 0..1:
 *       for tx in 0..100:
 *         temp[ty,tx] = (ty+cy)*(tx+0)
 *     for cx in 0..100:
 *       consumer[cy,cx] = temp[0,cx]
 *     Free(temp)
 *
 * NB: this loop nest can and should be simplified (e.g. the producer loop can
 * be removed since its result is no longer used), but this clean-up
 * optimization is performed separately (currently, not performed at all).
 *
 * If we examine the final loop nest, we can identify the following steps that
 * need to be performed:
 *   - Bounds inference needs to tell us that we need a 1x100 buffer for temp.
 *   - Allocate and Free statements for this buffer need to be inserted into
 *     the loop.
 *   - A new loop-nest should be inserted into the loop CY for computing `temp`
 *     and it should replicate the loopnest of producer (PY,PX loops). The
 *     indices in the loop body need to be offset by (cy, 0) - the offsets come
 *     from bounds inference too.
 *   - The computation of `consumer` needs to be rewritten so that it uses
 *     `temp` instead of `producer`. The indices in the corresponding accesses
 *     also need to be offset.
 */
void LoopNest::computeAt(const StmtPtr& s, const ForPtr& f) {
  StorePtr st = getStoreStmtOfProducer(s);
  if (!st) {
    return;
  }

  // Infer bounds info for all accesses that we make in the loop
  auto loop_bounds_info = inferBounds(f->body());

  // bounds_it holds bounds info for the store we're trying to move to
  // the loop. If its result isn't accessed in the loop at all - do nothing and
  // exit early.
  auto bounds_it = loop_bounds_info.find(st->buf());
  if (bounds_it == loop_bounds_info.end()) {
    return;
  }

  // Compute dimensions of the temp buffer we would need to allocate
  std::vector<ExprPtr> dims = getBoundExtents(bounds_it->second);

  // TODO: Use name-hint of the producer instead of "temp"
  BufPtr temp_buf = alloc<Buf>("temp", dims, st->value()->dtype());

  // Generate index variables for 'temp'
  std::vector<ExprPtr> temp_indices(dims.size());
  for (const auto i : c10::irange(dims.size())) {
    // TODO: Use name-hint of the producer indices instead of 'idx'
    temp_indices[i] =
        alloc<Var>(std::string("idx") + std::to_string(i), dims[i]->dtype());
  }

  // Prepare substitute rules for constructing the temp statement from the prod
  // statement
  // TODO: Instead of going up the loop nest we should go through the indices
  // in the original tensor expression. The loops in the nest might've been
  // modified (e.g. split or merged) so that the loop indices no longer
  // correspond to the indices of the original expression and even their number
  // might be different. In that case, the loop below would crash.
  std::vector<VarPtr> prod_indices = getOuterLoopIndexes(s);
  std::vector<std::pair<VarPtr, ExprPtr>> rewrite_indices_map;
  std::vector<ExprPtr> offsets;
  for (const TensorAccessBoundsInfo& p : bounds_it->second) {
    for (const auto i : c10::irange(p.start.size())) {
      if (offsets.size() <= i) {
        offsets.push_back(p.start[i]);
      } else {
        offsets[i] =
            IRSimplifier::simplify(alloc<Min>(offsets[i], p.start[i], true));
      }
    }
  }

  for (const auto i : c10::irange(prod_indices.size())) {
    rewrite_indices_map.emplace_back(
        prod_indices[i], alloc<Add>(temp_indices[i], offsets[i]));
  }

  // Construct the temp statement
  StmtPtr bd = alloc<Store>(
      temp_buf,
      temp_indices,
      SubstituteInClone(st->value(), rewrite_indices_map));

  // Construct the loop nest for the temp computation
  for (const auto i : c10::irange(dims.size())) {
    // We're creating loops from innermost to outermost, so we need to access
    // dimensions in reversed order.
    size_t dim_idx = dims.size() - 1 - i;
    bd = alloc<For>(
        to<Var>(temp_indices[dim_idx]),
        immLike(dims[dim_idx], 0),
        dims[dim_idx],
        bd);
  }

  // Add constructed stmts to the consumer loop
  f->body()->prepend_stmt(bd);

  // Rewrite accesses to producer in consumer with accesses to temp
  LoopComputeAtRewriter lr(st->buf(), temp_buf, offsets);
  StmtPtr new_f = f->accept_mutator(&lr);
  if (f != new_f) {
    BlockPtr bb = to<Block>(f->get_parent());
    bb->replace_stmt(f, new_f);
  }
}

class RfactorStoreRewriter : public IRMutator {
 public:
  RfactorStoreRewriter(
      BufPtr old_buf,
      const std::vector<ExprPtr>& old_indices,
      BufPtr new_buf,
      VarPtr reduction_var)
      : old_buf_(std::move(old_buf)),
        old_indices_(old_indices),
        new_buf_(std::move(new_buf)),
        reduction_var_(std::move(reduction_var)),
        new_indices_(old_indices) {
    new_indices_.push_back(reduction_var_);
  }

  ExprPtr mutate(const LoadPtr& v) override {
    if (v->buf() != old_buf_) {
      return IRMutator::mutate(v);
    }

    TORCH_INTERNAL_ASSERT(
        old_indices_.size() == v->indices().size(),
        buildErrorMessage(
            "Expected ranks to match in RfactorStoreRewriter in the fuser."));

    bool equal_indices = true;
    for (size_t i = 0; i < v->indices().size(); ++i) {
      if (!exprEquals(v->indices()[i], old_indices_[i])) {
        equal_indices = false;
        break;
      }
    }
    if (!equal_indices) {
      return IRMutator::mutate(v);
    }

    return alloc<Load>(new_buf_, new_indices_);
  }

  ExprPtr mutate(const ReduceOpPtr& v) override {
    ExprPtr body_new = v->body()->accept_mutator(this);

    std::vector<VarPtr> new_reduce_args;
    for (const auto& r : v->reduce_args()) {
      if (r != reduction_var_) {
        new_reduce_args.push_back(r);
      }
    }

    return alloc<ReduceOp>(body_new, new_reduce_args, v->reducer());
  }

  StmtPtr mutate(const StorePtr& v) override {
    if (v->buf() != old_buf_) {
      return IRMutator::mutate(v);
    }

    TORCH_INTERNAL_ASSERT(
        old_indices_.size() == v->indices().size(),
        buildErrorMessage(
            "Expected ranks to match in RfactorStoreRewriter in the fuser."));

    bool equal_indices = true;
    for (size_t i = 0; i < v->indices().size(); ++i) {
      if (!exprEquals(v->indices()[i], old_indices_[i])) {
        equal_indices = false;
        break;
      }
    }
    if (!equal_indices) {
      return IRMutator::mutate(v);
    }

    ExprPtr new_value = v->value()->accept_mutator(this);
    return alloc<Store>(new_buf_, new_indices_, new_value);
  }

 private:
  BufPtr old_buf_;
  const std::vector<ExprPtr>& old_indices_;
  BufPtr new_buf_;
  VarPtr reduction_var_;
  std::vector<ExprPtr> new_indices_;
};
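
// A sketch (illustrative) of rfactor on a 2-D reduction into a scalar,
// factoring along the m loop:
//
//   for m:                        for m:
//     for n:                        S_rfac[0,m] = init
//       S[0] = ReduceOp(            for n:
//         S[0] + A[m,n],    =>        S_rfac[0,m] = ReduceOp(
//         reduce_args={m,n})            S_rfac[0,m] + A[m,n],
//                                       reduce_args={n})
//                                   S[0] = ReduceOp(
//                                     S[0] + S_rfac[0,m], reduce_args={m})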
bool LoopNest::rfactor(const StmtPtr& st, const ForPtr& target_for) {
  BufPtr tmp_buf = nullptr;
  return rfactor(st, target_for, &tmp_buf);
}

bool LoopNest::rfactor(
    const StmtPtr& st,
    const ForPtr& outer_reduction_for,
    BufPtr* rfac_buf_ptr) {
  StorePtr reduction_store = to<Store>(st);
  ReduceOpPtr reduce_op = to<ReduceOp>(reduction_store->value());
  if (!reduce_op) {
    // Not a reduction store
    return false;
  }

  auto orig_buf = reduction_store->buf();
  auto orig_buf_indices = reduction_store->indices();
  VarPtr reduction_var = outer_reduction_for->var();

  std::set<VarPtr> reduce_args = {
      reduce_op->reduce_args().begin(), reduce_op->reduce_args().end()};

  if (reduce_args.size() < 2) {
    // Not enough reduction axes to do rfactor
    return false;
  }

  // Verify that outer_reduction_for is a perfect loop nest with all loops
  // being reductions
  StmtPtr cur = outer_reduction_for;
  while (ForPtr cur_for = to<For>(cur)) {
    if (!reduce_args.count(cur_for->var())) {
      // output axes inside outer_reduction_for are not allowed
      return false;
    }
    reduce_args.erase(cur_for->var());

    BlockPtr b = cur_for->body();
    if (b->nstmts() != 1) {
      return false;
    }
    cur = b->stmts().front();
  }
  if (cur != st) {
    // The reduction store is not a single stmt in the innermost loop - bail in
    // that case
    return false;
  }
  if (!reduce_args.empty()) {
    // This is not the outermost reduction axis
    return false;
  }

  // assert: the reduce axes match the loop vars from outer_reduction_for and
  // inside
  // assert: no other stmts in outer_reduction_for or its child loops

  std::vector<ExprPtr> rfac_dims = orig_buf->dims();
  ExprPtr extra_dim = IRSimplifier::simplify(
      alloc<Sub>(outer_reduction_for->stop(), outer_reduction_for->start()));
  rfac_dims.push_back(extra_dim);
  ExprPtr rfac_init =
      alloc<Cast>(reduce_op->dtype(), reduce_op->reducer().initializer());

  *rfac_buf_ptr = alloc<Buf>(
      orig_buf->name_hint() + "_rfac",
      rfac_dims,
      reduce_op->dtype(),
      rfac_init);
  BufPtr rfac_buf = *rfac_buf_ptr;

  // Rewrite the original reduction store to use the temporary rfac buffer:
  //   1) X[*indexes] --> T[*indexes + {reduction_var}]
  //   2) reduce_axis -= {reduction_var}
  RfactorStoreRewriter rfac_rewriter(
      orig_buf, orig_buf_indices, rfac_buf, reduction_var);
  to<Block>(st->get_parent())
      ->replace_stmt(st, st->accept_mutator(&rfac_rewriter));

  // Insert a store for the final reduction over the temp buffer into the
  // original buffer:
  //   X[*indexes] = ReduceOp(X[*indexes] + T[*indexes + {reduction_var}],
  //                          reduce_axis={reduction_var})
  BlockPtr b = outer_reduction_for->body();
  TORCH_INTERNAL_ASSERT(
      b->nstmts() == 1,
      buildErrorMessage(
          "Expected to have a single stmt in the block in the rfactor transformation in the fuser."));
  StmtPtr first_reduction_loop = b->stmts().front();
  auto rfac_buf_indices = orig_buf_indices;
  rfac_buf_indices.emplace_back(reduction_var);

  ExprPtr final_reduce_load = alloc<Load>(rfac_buf, rfac_buf_indices);
  outer_reduction_for->body()->insert_stmt_after(
      alloc<Store>(
          orig_buf,
          orig_buf_indices,
          reduce_op->reducer()(
              orig_buf, final_reduce_load, orig_buf_indices, {reduction_var})),
      first_reduction_loop);

  // Insert an initialization store for the temp buffer:
  //   T[a,b,c] = init
  outer_reduction_for->body()->insert_stmt_before(
      alloc<Store>(rfac_buf, rfac_buf_indices, rfac_init),
      first_reduction_loop);
  return true;
}

} // namespace torch::jit::tensorexpr