[nnc] Removed const from all fields in IR. (#62336)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62336

This PR was generated by removing `const` from all node types in the NNC IR and fixing the compilation errors that resulted from that change.

This is the first step in making all NNC mutations in-place.
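As a minimal sketch of what this enables (using hypothetical stand-in types, not the actual NNC classes): a pass that holds a non-const node pointer can rewrite the node in place instead of rebuilding it.

#include <string>

// Hypothetical stand-ins for NNC IR nodes, for illustration only.
struct Expr {
  virtual ~Expr() = default;
};

struct Var : Expr {
  explicit Var(std::string name) : name_hint(std::move(name)) {}
  std::string name_hint;
};

// Before this change, passes received nodes as const pointers:
//   void rename(const Var* v);      // cannot assign to v->name_hint
// With const removed, in-place mutation becomes possible:
void renameInPlace(Var* v, const std::string& new_name) {
  v->name_hint = new_name;  // legal only because v is non-const
}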

Test Plan: Imported from OSS

Reviewed By: iramazanli

Differential Revision: D30049829

Pulled By: navahgar

fbshipit-source-id: ed14e2d2ca0559ffc0b92ac371f405579c85dd63
Author: Raghavan Raman
Date: 2021-08-03 11:43:07 -07:00
Committed by: Facebook GitHub Bot
Parent: 474d7ec43b
Commit: 59dd12042e
60 changed files with 2364 additions and 2484 deletions
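Every hunk below follows the same mechanical pattern: `const T*` becomes `T*` in member fields, container element types, and visitor/mutator signatures. A sketch of the recurring signature change (abbreviated, not the real torch::jit::tensorexpr declarations):

struct For;  // forward declaration standing in for the real IR node

// Old form, which each hunk removes:
struct IRVisitorBefore {
  virtual ~IRVisitorBefore() = default;
  virtual void visit(const For* v) = 0;
};

// New form, which each hunk introduces:
struct IRVisitorAfter {
  virtual ~IRVisitorAfter() = default;
  virtual void visit(For* v) = 0;
};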

View File

@ -14,8 +14,8 @@ using namespace torch::jit::tensorexpr;
TEST(CppPrinter, AllocateOnStackThenFree) {
KernelScope kernel_scope;
std::vector<const Expr*> dims = {new IntImm(2), new IntImm(3)};
const Buf* buf = new Buf("x", dims, kInt);
std::vector<Expr*> dims = {new IntImm(2), new IntImm(3)};
Buf* buf = new Buf("x", dims, kInt);
Allocate* alloc = new Allocate(buf);
Free* free = new Free(buf);
Block* block = Block::make({alloc, free});
@ -33,9 +33,8 @@ TEST(CppPrinter, AllocateOnStackThenFree) {
TEST(CppPrinter, AllocateOnHeapThenFree) {
KernelScope kernel_scope;
std::vector<const Expr*> dims = {
new IntImm(20), new IntImm(50), new IntImm(3)};
const Buf* buf = new Buf("y", dims, kLong);
std::vector<Expr*> dims = {new IntImm(20), new IntImm(50), new IntImm(3)};
Buf* buf = new Buf("y", dims, kLong);
Allocate* alloc = new Allocate(buf);
Free* free = new Free(buf);
Block* block = Block::make({alloc, free});

View File

@ -705,7 +705,7 @@ TEST(Cuda, SharedMemReduce_1_CUDA) {
VarHandle n("n", kInt);
std::vector<Stmt*> block;
std::vector<const Expr*> dims;
std::vector<Expr*> dims;
dims.push_back(ExprHandle(N).node());
BufHandle c{new Buf("c", dims, kFloat)};
{

View File

@ -313,13 +313,13 @@ TEST(Expr, IntrinsicsDtypes) {
TEST(Expr, Substitute01) {
KernelScope kernel_scope;
const Var* x = new Var("x", kFloat);
const Var* y = new Var("y", kFloat);
const Expr* e = new Mul(new Sub(x, new FloatImm(1.0f)), new Add(x, y));
Var* x = new Var("x", kFloat);
Var* y = new Var("y", kFloat);
Expr* e = new Mul(new Sub(x, new FloatImm(1.0f)), new Add(x, y));
const Var* z = new Var("z", kFloat);
const Expr* e2 = Substitute(e, {{x, new Add(z, new FloatImm(5.0f))}});
const Expr* e2_ref = new Mul(
Var* z = new Var("z", kFloat);
Expr* e2 = Substitute(e, {{x, new Add(z, new FloatImm(5.0f))}});
Expr* e2_ref = new Mul(
new Sub(new Add(z, new FloatImm(5.0f)), new FloatImm(1.0f)),
new Add(new Add(z, new FloatImm(5.0f)), y));
std::ostringstream oss;
@ -663,7 +663,7 @@ void testStmtClone() {
// original statement hasn't changed while the cloned one has.
Stmt* body_addition = a_buf.store({index}, 33);
Block* cloned_body =
static_cast<Block*>(static_cast<const For*>(cloned_loop)->body());
static_cast<Block*>(static_cast<For*>(cloned_loop)->body());
cloned_body->append_stmt(body_addition);
std::vector<int> orig_loop_results_after_mutation(N);

View File

@ -18,8 +18,8 @@ using namespace torch::jit::tensorexpr;
TEST(IRVerifier, BitwiseOps) {
KernelScope kernel_scope;
const Var* X = new Var("x", kInt);
const Var* Y = new Var("y", kFloat);
Var* X = new Var("x", kInt);
Var* Y = new Var("y", kFloat);
{
auto a = new And(X, Y);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
@ -49,8 +49,8 @@ TEST(IRVerifier, BitwiseOps) {
TEST(IRVerifier, CompareSelect) {
KernelScope kernel_scope;
const Expr* X = new IntImm(1);
const Expr* Y = new FloatImm(3.14f);
Expr* X = new IntImm(1);
Expr* Y = new FloatImm(3.14f);
{
auto a = new CompareSelect(X, X, X, Y, kEQ);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
@ -65,8 +65,8 @@ TEST(IRVerifier, CompareSelect) {
TEST(IRVerifier, Ramp) {
KernelScope kernel_scope;
const Var* I = new Var("i", kInt);
const Var* J = new Var("j", kFloat);
Var* I = new Var("i", kInt);
Var* J = new Var("j", kFloat);
{
auto a = new Ramp(I, J, 4);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
@ -76,10 +76,10 @@ TEST(IRVerifier, Ramp) {
TEST(IRVerifier, Load) {
KernelScope kernel_scope;
const Var* I = new Var("i", kInt);
const Var* J = new Var("j", kLong);
const Var* K = new Var("k", kFloat);
const Buf* B = new Buf("b", {new IntImm(10), new IntImm(20)}, kFloat);
Var* I = new Var("i", kInt);
Var* J = new Var("j", kLong);
Var* K = new Var("k", kFloat);
Buf* B = new Buf("b", {new IntImm(10), new IntImm(20)}, kFloat);
{
// Indices with different int dtypes (kInt, kLong) are ok
auto a = new Load(B, {I, J});
@ -103,9 +103,9 @@ TEST(IRVerifier, Load) {
TEST(IRVerifier, IfThenElse) {
KernelScope kernel_scope;
const Var* I = new Var("i", kInt);
const Var* J = new Var("j", kLong);
const Var* K = new Var("k", kFloat);
Var* I = new Var("i", kInt);
Var* J = new Var("j", kLong);
Var* K = new Var("k", kFloat);
{
// Condition must be integral
auto a = new IfThenElse(K, I, I);
@ -128,8 +128,8 @@ TEST(IRVerifier, IfThenElse) {
TEST(IRVerifier, For) {
KernelScope kernel_scope;
const Var* I = new Var("i", kInt);
const Var* J = new Var("j", kInt);
Var* I = new Var("i", kInt);
Var* J = new Var("j", kInt);
Stmt* body = new Block({});
{
// Can't have nullptr as a Var
@ -141,8 +141,8 @@ TEST(IRVerifier, For) {
TEST(IRVerifier, Block) {
KernelScope kernel_scope;
const Var* I = new Var("i", kInt);
const Buf* B = new Buf("B", {new IntImm(10)}, kInt);
Var* I = new Var("i", kInt);
Buf* B = new Buf("B", {new IntImm(10)}, kInt);
{
Stmt* store = new Store(B, {I}, I);
// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
@ -158,10 +158,10 @@ TEST(IRVerifier, Block) {
TEST(IRVerifier, Store) {
KernelScope kernel_scope;
const Var* I = new Var("i", kInt);
const Var* J = new Var("j", kLong);
const Var* K = new Var("k", kFloat);
const Buf* B = new Buf("b", {new IntImm(10), new IntImm(20)}, kFloat);
Var* I = new Var("i", kInt);
Var* J = new Var("j", kLong);
Var* K = new Var("k", kFloat);
Buf* B = new Buf("b", {new IntImm(10), new IntImm(20)}, kFloat);
{
// Indices with different int dtypes (kInt, kLong) are ok
auto a = new Store(B, {I, J}, K);

View File

@ -2246,7 +2246,7 @@ class LoopOrderHelper : public IRVisitor {
}
// NOLINTNEXTLINE(cppcoreguidelines-explicit-virtual-functions,modernize-use-override)
void visit(const For* v) {
void visit(For* v) {
ordering << v->var()->name_hint() << ",";
IRVisitor::visit(v);
}

View File

@ -815,9 +815,9 @@ TEST(MemDependency, MemDependencyCheckerLoopBounds) {
// much.
auto history = analyzer.getHistory();
ASSERT_EQ(history.size(), 10);
const Var* aVar = a.node()->base_handle();
const Var* bVar = b.node()->base_handle();
const Var* cVar = c.node()->base_handle();
Var* aVar = a.node()->base_handle();
Var* bVar = b.node()->base_handle();
Var* cVar = c.node()->base_handle();
// The first access is the input A.
ASSERT_EQ(history[0]->type(), AccessType::Input);
@ -989,8 +989,8 @@ TEST(MemDependency, MemDependencyCheckerLoopBoundsIndexShift) {
// Now let's look at the bounds of each access.
auto history = analyzer.getHistory();
ASSERT_EQ(history.size(), 12);
const Var* aVar = a.node()->base_handle();
const Var* bVar = b.node()->base_handle();
Var* aVar = a.node()->base_handle();
Var* bVar = b.node()->base_handle();
// The first access is the input A.
ASSERT_EQ(history[0]->type(), AccessType::Input);
@ -3119,7 +3119,7 @@ TEST(MemDependency, MemDependencyCheckerComputeGEMM) {
history_before[i]->bounds(), history_after[i]->bounds()));
} else {
ASSERT_EQ(history_after[i]->bounds().size(), 1);
const Expr* flat_bounds = new IntImm(1);
Expr* flat_bounds = new IntImm(1);
for (auto& b : history_before[i]->bounds()) {
flat_bounds = new Mul(flat_bounds, new Add(b.end, new IntImm(1)));
@ -3129,7 +3129,7 @@ TEST(MemDependency, MemDependencyCheckerComputeGEMM) {
}
flat_bounds = IRSimplifier::simplify(flat_bounds);
const Expr* after_bounds = IRSimplifier::simplify(
Expr* after_bounds = IRSimplifier::simplify(
new Add(history_after[i]->bounds()[0].end, new IntImm(1)));
ASSERT_TRUE(exprEquals(flat_bounds, after_bounds));
}

View File

@ -149,7 +149,7 @@ TEST(Simplify, ConstantFoldWithVar) {
ExprHandle body = x * (ExprHandle(2) + ExprHandle(4));
ExprHandle newF = IRSimplifier::simplify(body);
const Mul* root = newF.AsNode<Mul>();
Mul* root = newF.AsNode<Mul>();
ASSERT_NE(root, nullptr);
ASSERT_NE(dynamic_cast<const IntImm*>(root->lhs()), nullptr);
@ -163,7 +163,7 @@ TEST(Simplify, ConstantFoldWithVar) {
ExprHandle body = x * (ExprHandle(2.f) + ExprHandle(4.f));
ExprHandle newF = IRSimplifier::simplify(body);
const Mul* root = newF.AsNode<Mul>();
Mul* root = newF.AsNode<Mul>();
ASSERT_NE(root, nullptr);
ASSERT_NE(dynamic_cast<const FloatImm*>(root->rhs()), nullptr);
@ -296,7 +296,7 @@ TEST(Simplify, UnFoldableExpr) {
ExprHandle body = (ExprHandle(3) * x) + (ExprHandle(5) * y);
ExprHandle newF = IRSimplifier::simplify(body);
const Add* root = newF.AsNode<Add>();
Add* root = newF.AsNode<Add>();
ASSERT_NE(root, nullptr);
ASSERT_EQ(dynamic_cast<const FloatImm*>(root->lhs()), nullptr);
ASSERT_EQ(dynamic_cast<const FloatImm*>(root->rhs()), nullptr);
@ -334,7 +334,7 @@ TEST(Simplify, HashEquivalence) {
VarHandle y("y", kFloat);
ExprHandle f = (x * y) + (x * y);
const Add* root = f.AsNode<Add>();
Add* root = f.AsNode<Add>();
ASSERT_NE(root, nullptr);
HashProvider hasher;
@ -370,7 +370,7 @@ TEST(Simplify, HashEquivalenceRand) {
ExprHandle f =
Intrinsics::make(kRand, kFloat) + Intrinsics::make(kRand, kInt);
const Add* root = f.AsNode<Add>();
Add* root = f.AsNode<Add>();
ASSERT_NE(root, nullptr);
HashProvider hasher;
@ -415,7 +415,7 @@ TEST(Simplify, HashDifferenceTypes) {
KernelScope kernel_scope;
HashProvider hasher;
std::vector<const Expr*> immediates;
std::vector<Expr*> immediates;
immediates.push_back(new DoubleImm(1));
immediates.push_back(new FloatImm(1));
@ -546,9 +546,9 @@ TEST(Simplify, SimplifyAdd) {
ExprHandle body = (ExprHandle(2) + x) + ExprHandle(4);
ExprHandle simplified = IRSimplifier::simplify(body);
const Add* root = simplified.AsNode<Add>();
Add* root = simplified.AsNode<Add>();
ASSERT_NE(root, nullptr);
const Var* lhs = dynamic_cast<const Var*>(root->lhs());
Var* lhs = dynamic_cast<Var*>(root->lhs());
ASSERT_NE(lhs, nullptr);
ASSERT_EQ(lhs->name_hint(), "x");
const IntImm* rhs = dynamic_cast<const IntImm*>(root->rhs());
@ -563,12 +563,12 @@ TEST(Simplify, SimplifySub) {
ExprHandle body = (ExprHandle(2) - x) - ExprHandle(4);
ExprHandle simplified = IRSimplifier::simplify(body);
const Sub* root = simplified.AsNode<Sub>();
Sub* root = simplified.AsNode<Sub>();
ASSERT_NE(root, nullptr);
const IntImm* lhs = dynamic_cast<const IntImm*>(root->lhs());
ASSERT_NE(lhs, nullptr);
ASSERT_EQ(lhs->value(), -2.f);
const Var* rhs = dynamic_cast<const Var*>(root->rhs());
Var* rhs = dynamic_cast<Var*>(root->rhs());
ASSERT_NE(rhs, nullptr);
ASSERT_EQ(rhs->name_hint(), "x");
}
@ -594,12 +594,12 @@ TEST(Simplify, SimplifyMultiTerm) {
(ExprHandle(2) * ((ExprHandle(3) * x)) - (x * ExprHandle(4)));
ExprHandle simplified = IRSimplifier::simplify(body);
const Mul* root = simplified.AsNode<Mul>();
Mul* root = simplified.AsNode<Mul>();
ASSERT_NE(root, nullptr);
const IntImm* lhs = dynamic_cast<const IntImm*>(root->lhs());
ASSERT_NE(lhs, nullptr);
ASSERT_EQ(lhs->value(), 2);
const Var* rhs = dynamic_cast<const Var*>(root->rhs());
Var* rhs = dynamic_cast<Var*>(root->rhs());
ASSERT_NE(rhs, nullptr);
ASSERT_EQ(rhs->name_hint(), "x");
}
@ -612,12 +612,12 @@ TEST(Simplify, SimplifyCasts) {
(ExprHandle(2) * ((ExprHandle(3) * x)) - (x * ExprHandle(4)));
ExprHandle simplified = IRSimplifier::simplify(body);
const Mul* root = simplified.AsNode<Mul>();
Mul* root = simplified.AsNode<Mul>();
ASSERT_NE(root, nullptr);
const LongImm* lhs = dynamic_cast<const LongImm*>(root->lhs());
ASSERT_NE(lhs, nullptr);
ASSERT_EQ(lhs->value(), 2);
const Var* rhs = dynamic_cast<const Var*>(root->rhs());
Var* rhs = dynamic_cast<Var*>(root->rhs());
ASSERT_NE(rhs, nullptr);
ASSERT_EQ(rhs->name_hint(), "x");
}
@ -629,7 +629,7 @@ TEST(Simplify, SimplifyEliminatesNoOps) {
ExprHandle body = (x + ExprHandle(0)) * 1;
ExprHandle simplified = IRSimplifier::simplify(body);
const Var* root = simplified.AsNode<Var>();
Var* root = simplified.AsNode<Var>();
ASSERT_NE(root, nullptr);
ASSERT_EQ(root->name_hint(), "x");
}
@ -643,16 +643,16 @@ TEST(Simplify, SimplifyMultiVar) {
ExprHandle simplified = IRSimplifier::simplify(body);
const Add* root = simplified.AsNode<Add>();
Add* root = simplified.AsNode<Add>();
ASSERT_NE(root, nullptr);
const Mul* lhs = dynamic_cast<const Mul*>(root->lhs());
Mul* lhs = dynamic_cast<Mul*>(root->lhs());
ASSERT_NE(lhs, nullptr);
const Var* varX = dynamic_cast<const Var*>(lhs->rhs());
Var* varX = dynamic_cast<Var*>(lhs->rhs());
ASSERT_NE(varX, nullptr);
ASSERT_EQ(varX->name_hint(), "y");
const Mul* rhs = dynamic_cast<const Mul*>(root->rhs());
Mul* rhs = dynamic_cast<Mul*>(root->rhs());
ASSERT_NE(rhs, nullptr);
const Var* varY = dynamic_cast<const Var*>(rhs->rhs());
Var* varY = dynamic_cast<Var*>(rhs->rhs());
ASSERT_NE(varY, nullptr);
ASSERT_EQ(varY->name_hint(), "x");
}
@ -665,7 +665,7 @@ TEST(Simplify, DISABLED_SimplifyReorderings) {
ExprHandle body = x + 2 + y;
ExprHandle simplified = IRSimplifier::simplify(body);
const Add* root = simplified.AsNode<Add>();
Add* root = simplified.AsNode<Add>();
ASSERT_NE(root, nullptr);
IS_NODE_WITH_NAME(Add, root->lhs(), rhs);

View File

@ -127,13 +127,13 @@ int main(int argc, char* argv[]) {
// Let's start with defining a domain. We do this by creating a Buf object.
// First, let's specify the sizes:
std::vector<const Expr*> dims = {
std::vector<Expr*> dims = {
new IntImm(64), new IntImm(32)}; // IntImm stands for Integer Immediate
// and represents an integer constant
// Now we can create a Buf object by providing a name, dimensions, and a
// data type of the elements:
const Buf* buf = new Buf("X", dims, kInt);
Buf* buf = new Buf("X", dims, kInt);
// Next we need to specify the computation. We can do that by either
// constructing a complete tensor statement for it (statements are
@ -144,9 +144,9 @@ int main(int argc, char* argv[]) {
// Let's define two variables, i and j - they will be axes in our
// computation.
const Var* i = new Var("i", kInt);
const Var* j = new Var("j", kInt);
std::vector<const Var*> args = {i, j};
Var* i = new Var("i", kInt);
Var* j = new Var("j", kInt);
std::vector<Var*> args = {i, j};
// Now we can define the body of the tensor computation using these
// variables. What this means is that values in our tensor are:

View File

@ -19,7 +19,7 @@ class HasRand : public IRVisitor {
}
private:
void visit(const Intrinsics* v) override {
void visit(Intrinsics* v) override {
if (v->op_type() == IntrinsicsOp::kRand) {
has_rand_ = true;
} else {
@ -34,18 +34,18 @@ template <typename Node>
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
class NodeFinder : public IRVisitor {
public:
void visit(const Node* v) override {
void visit(Node* v) override {
nodes.push_back((Node*)v);
IRVisitor::visit(v);
}
static std::vector<Node*> find(const Stmt* s) {
static std::vector<Node*> find(Stmt* s) {
NodeFinder<Node> nf;
s->accept(&nf);
return nf.nodes;
}
static std::vector<Node*> find(const Expr* e) {
static std::vector<Node*> find(Expr* e) {
NodeFinder<Node> nf;
e->accept(&nf);
return nf.nodes;
@ -57,108 +57,108 @@ class NodeFinder : public IRVisitor {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
class VarFinder : public IRVisitor {
public:
void visit(const Var* v) override {
void visit(Var* v) override {
vars_.insert(v);
IRVisitor::visit(v);
}
static std::unordered_set<const Var*> find(Stmt* s) {
static std::unordered_set<Var*> find(Stmt* s) {
VarFinder nf;
s->accept(&nf);
return nf.vars();
}
static std::unordered_set<const Var*> find(const Expr* e) {
static std::unordered_set<Var*> find(Expr* e) {
VarFinder nf;
e->accept(&nf);
return nf.vars();
}
const std::unordered_set<const Var*>& vars() {
const std::unordered_set<Var*>& vars() {
return vars_;
}
private:
std::unordered_set<const Var*> vars_;
std::unordered_set<Var*> vars_;
};
class BufFinder : public IRVisitor {
public:
void visit(const Buf* v) override {
void visit(Buf* v) override {
bufs_.insert(v);
IRVisitor::visit(v);
}
static std::unordered_set<const Buf*> find(Stmt* s) {
static std::unordered_set<Buf*> find(Stmt* s) {
BufFinder nf;
s->accept(&nf);
return nf.bufs();
}
static std::unordered_set<const Buf*> find(const Expr* e) {
static std::unordered_set<Buf*> find(Expr* e) {
BufFinder nf;
e->accept(&nf);
return nf.bufs();
}
const std::unordered_set<const Buf*>& bufs() {
const std::unordered_set<Buf*>& bufs() {
return bufs_;
}
private:
std::unordered_set<const Buf*> bufs_;
std::unordered_set<Buf*> bufs_;
};
// Finds all kinds of write operations to the provided Buf.
class WritesToBuf : public IRVisitor {
public:
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
WritesToBuf(const Buf* target) : target_(target) {}
WritesToBuf(Buf* target) : target_(target) {}
std::vector<const Stmt*> writes() {
std::vector<Stmt*> writes() {
return writes_;
}
static std::vector<const Stmt*> find(Stmt* s, const Buf* b) {
static std::vector<Stmt*> find(Stmt* s, Buf* b) {
WritesToBuf finder(b);
s->accept(&finder);
return finder.writes();
}
private:
void visit(const Store* v) override {
void visit(Store* v) override {
if (v->buf() == target_) {
writes_.push_back(v);
}
}
void visit(const AtomicAdd* v) override {
void visit(AtomicAdd* v) override {
if (v->buf() == target_) {
writes_.push_back(v);
}
}
const Buf* target_;
std::vector<const Stmt*> writes_;
Buf* target_;
std::vector<Stmt*> writes_;
};
class StmtsReadingBuf : public IRVisitor {
public:
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
StmtsReadingBuf(const Buf* target) : target_(target) {}
StmtsReadingBuf(Buf* target) : target_(target) {}
std::vector<const Stmt*> reads() {
std::vector<Stmt*> reads() {
return reads_;
}
static std::vector<const Stmt*> find(Stmt* s, const Buf* b) {
static std::vector<Stmt*> find(Stmt* s, Buf* b) {
StmtsReadingBuf finder(b);
s->accept(&finder);
return finder.reads();
}
private:
bool readsBuffer(const Stmt* s) {
bool readsBuffer(Stmt* s) {
auto loads = NodeFinder<Load>::find(s);
for (auto l : loads) {
if (l->buf() == target_) {
@ -168,40 +168,40 @@ class StmtsReadingBuf : public IRVisitor {
return false;
}
void visit(const Store* v) override {
void visit(Store* v) override {
if (readsBuffer(v)) {
reads_.push_back(v);
}
}
void visit(const Let* v) override {
void visit(Let* v) override {
if (readsBuffer(v)) {
reads_.push_back(v);
}
}
void visit(const Cond* v) override {
void visit(Cond* v) override {
if (readsBuffer(v)) {
reads_.push_back(v);
}
}
void visit(const AtomicAdd* v) override {
void visit(AtomicAdd* v) override {
if (readsBuffer(v)) {
reads_.push_back(v);
}
}
const Buf* target_;
std::vector<const Stmt*> reads_;
Buf* target_;
std::vector<Stmt*> reads_;
};
// Traverses the IR to determine if a particular Var is modified within it.
class ModifiesVarChecker : public IRVisitor {
public:
ModifiesVarChecker(const Var* v) : var_(v) {}
ModifiesVarChecker(Var* v) : var_(v) {}
static bool check(const Stmt* s, const Var* v) {
static bool check(Stmt* s, Var* v) {
ModifiesVarChecker checker(v);
s->accept(&checker);
return checker.found();
@ -212,7 +212,7 @@ class ModifiesVarChecker : public IRVisitor {
}
private:
void visit(const Store* v) override {
void visit(Store* v) override {
if (v->buf()->base_handle() == var_) {
found_ = true;
return;
@ -220,7 +220,7 @@ class ModifiesVarChecker : public IRVisitor {
IRVisitor::visit(v);
}
void visit(const AtomicAdd* v) override {
void visit(AtomicAdd* v) override {
if (v->buf()->base_handle() == var_) {
found_ = true;
return;
@ -228,7 +228,7 @@ class ModifiesVarChecker : public IRVisitor {
IRVisitor::visit(v);
}
void visit(const Let* v) override {
void visit(Let* v) override {
if (v->var() == var_) {
found_ = true;
return;
@ -236,7 +236,7 @@ class ModifiesVarChecker : public IRVisitor {
IRVisitor::visit(v);
}
void visit(const For* v) override {
void visit(For* v) override {
if (v->var() == var_) {
found_ = true;
return;
@ -244,7 +244,7 @@ class ModifiesVarChecker : public IRVisitor {
IRVisitor::visit(v);
}
const Var* var_;
Var* var_;
bool found_{false};
};
@ -252,26 +252,26 @@ class ModifiesVarChecker : public IRVisitor {
// It creates a map of multi-dim buffers and their flat versions
class CreateBufferMap : public IRVisitor {
public:
const std::unordered_map<std::string, const Buf*>& getBufferMap() const {
const std::unordered_map<std::string, Buf*>& getBufferMap() const {
return map_input_to_tensor_bufs_;
}
private:
void visit(const Store* v) override {
auto load_node = dynamic_cast<const Load*>(v->value());
void visit(Store* v) override {
auto load_node = dynamic_cast<Load*>(v->value());
if (load_node) {
auto t_buf = load_node->buf();
map_input_to_tensor_bufs_.emplace(t_buf->name_hint(), v->buf());
} else {
auto add_node = dynamic_cast<const Add*>(v->value());
auto mul_node = dynamic_cast<const Mul*>(v->value());
auto add_node = dynamic_cast<Add*>(v->value());
auto mul_node = dynamic_cast<Mul*>(v->value());
// This means for now, v->value() can be Add or Mul
TORCH_INTERNAL_ASSERT((add_node || mul_node));
map_input_to_tensor_bufs_.emplace(v->buf()->name_hint(), v->buf());
}
v->value()->accept(this);
}
std::unordered_map<std::string, const Buf*> map_input_to_tensor_bufs_;
std::unordered_map<std::string, Buf*> map_input_to_tensor_bufs_;
};
} // namespace tensorexpr

View File

@ -32,8 +32,7 @@ std::string blockDtypeCppString(const Dtype& dtype) {
}
}
bool BlockAnalysis::areBufsInMap(
const std::unordered_set<const Buf*>& bufs) const {
bool BlockAnalysis::areBufsInMap(const std::unordered_set<Buf*>& bufs) const {
for (auto const& arg : bufs) {
auto got = map_input_to_tensor_bufs_.find(arg->name_hint());
if (got == map_input_to_tensor_bufs_.end()) {
@ -43,7 +42,7 @@ bool BlockAnalysis::areBufsInMap(
return true;
}
const Buf* BlockAnalysis::getMultiDimBuf(const Buf* buf) const {
Buf* BlockAnalysis::getMultiDimBuf(Buf* buf) const {
auto input_ = map_input_to_tensor_bufs_.find(buf->name_hint());
if (input_ != map_input_to_tensor_bufs_.end()) {
return input_->second;
@ -52,7 +51,7 @@ const Buf* BlockAnalysis::getMultiDimBuf(const Buf* buf) const {
}
}
std::string BlockAnalysis::getInputName(const Buf* buf) const {
std::string BlockAnalysis::getInputName(Buf* buf) const {
auto input_ = map_input_to_tensor_bufs_.find(buf->name_hint());
if (input_ != map_input_to_tensor_bufs_.end()) {
return input_->second->name_hint();
@ -61,23 +60,23 @@ std::string BlockAnalysis::getInputName(const Buf* buf) const {
}
}
void BlockAnalysis::visit(const Store* v) {
void BlockAnalysis::visit(Store* v) {
store_targets_.insert(v->buf());
v->value()->accept(this);
}
void BlockAnalysis::visit(const Load* v) {
void BlockAnalysis::visit(Load* v) {
loads_.insert(v->buf());
}
void BlockAnalysis::visit(const For* v) {
void BlockAnalysis::visit(For* v) {
const LoopOptions& loop_options = v->loop_options();
if (loop_options.is_gpu_block_index()) {
map_input_to_tensor_bufs_ = loop_options.get_buffer_mapping();
v->body()->accept(this);
} else if (loop_options.is_gpu_thread_index()) {
auto block_size = v->stop();
block_size_ = dynamic_cast<const IntImm*>(block_size)->value();
block_size_ = dynamic_cast<IntImm*>(block_size)->value();
v->body()->accept(this);
} else {
IRVisitor::visit(v);
@ -91,26 +90,26 @@ void BlockAnalysis::visit(const For* v) {
// TODO: When handling fused ops d = a + b + c, the correct
// way would be to mutate the expression to Block version and print.
void BlockPrinter::visit(const Add* v) {
void BlockPrinter::visit(Add* v) {
emitIndent();
os() << "add(";
v->lhs()->accept(this);
v->rhs()->accept(this);
}
void BlockPrinter::visit(const Mul* v) {
void BlockPrinter::visit(Mul* v) {
emitIndent();
os() << "mul(";
v->lhs()->accept(this);
v->rhs()->accept(this);
}
void BlockPrinter::visit(const For* v) {
void BlockPrinter::visit(For* v) {
const LoopOptions& loop_options = v->loop_options();
auto buf_reads = block_analysis_->loads();
auto buf_writes = block_analysis_->stores();
std::unordered_set<const Buf*> bufs(buf_reads.begin(), buf_reads.end());
std::unordered_set<Buf*> bufs(buf_reads.begin(), buf_reads.end());
bufs.insert(buf_writes.begin(), buf_writes.end());
if (loop_options.is_gpu_block_index()) {
@ -146,9 +145,9 @@ void BlockPrinter::visit(const For* v) {
}
}
void BlockPrinter::PrintTensorInfo(const std::unordered_set<const Buf*>& bufs) {
void BlockPrinter::PrintTensorInfo(const std::unordered_set<Buf*>& bufs) {
os() << "tensors {";
for (const auto& buf : bufs) {
for (auto& buf : bufs) {
os() << std::endl;
emitIndent();
emitIndent();
@ -162,7 +161,7 @@ void BlockPrinter::PrintTensorInfo(const std::unordered_set<const Buf*>& bufs) {
os() << "}";
}
for (const auto& buf : bufs) {
for (auto& buf : bufs) {
os() << std::endl;
emitIndent();
emitIndent();
@ -179,19 +178,19 @@ void BlockPrinter::PrintTensorInfo(const std::unordered_set<const Buf*>& bufs) {
os() << "}" << std::endl << std::endl;
}
void BlockPrinter::PrintArguments(const std::unordered_set<const Buf*>& bufs) {
for (const auto& buf : bufs) {
void BlockPrinter::PrintArguments(const std::unordered_set<Buf*>& bufs) {
for (auto& buf : bufs) {
auto multidimbuf = block_analysis_->getMultiDimBuf(buf);
auto num_dims = multidimbuf->dims().size();
// The dims for the multi-dim tensors
for (unsigned long d = 0; d < num_dims; d++) {
auto dim_val = dynamic_cast<const IntImm*>(multidimbuf->dim(d));
auto dim_val = dynamic_cast<IntImm*>(multidimbuf->dim(d));
this->dim_values_map.emplace(this->dim_names[d], dim_val->value());
}
// The dimensions for the flattened tensors
auto val = dynamic_cast<const IntImm*>(buf->dim(0));
auto val = dynamic_cast<IntImm*>(buf->dim(0));
if (block_analysis_->is_buf_store_target(buf)) {
this->dim_values_map.emplace(
this->flat_dim_names[num_dims - 1], val->value());
@ -217,10 +216,10 @@ void BlockPrinter::PrintArguments(const std::unordered_set<const Buf*>& bufs) {
os() << "}" << std::endl << std::endl;
}
void BlockPrinter::PrintBufferInfo(const std::unordered_set<const Buf*>& bufs) {
void BlockPrinter::PrintBufferInfo(const std::unordered_set<Buf*>& bufs) {
emitIndent();
os() << "buffers {";
for (const auto& read : bufs) {
for (auto& read : bufs) {
os() << std::endl;
emitIndent();
emitIndent();
@ -234,11 +233,10 @@ void BlockPrinter::PrintBufferInfo(const std::unordered_set<const Buf*>& bufs) {
os() << "}" << std::endl << std::endl;
}
void BlockPrinter::PrintDistribution(
const std::unordered_set<const Buf*>& bufs) {
void BlockPrinter::PrintDistribution(const std::unordered_set<Buf*>& bufs) {
emitIndent();
os() << "distribution {" << std::endl;
for (const auto& buf : bufs) {
for (auto& buf : bufs) {
emitIndent();
emitIndent();
auto buf_name = buf->name_hint();
@ -249,12 +247,12 @@ void BlockPrinter::PrintDistribution(
}
void BlockPrinter::PrintLoop(
const std::unordered_set<const Buf*>& bufs,
const std::unordered_set<Buf*>& bufs,
bool block_idx) {
emitIndent();
os() << "loop (";
auto trip = 0;
for (const auto& buf : bufs) {
for (auto& buf : bufs) {
if (trip > 0) {
os() << ",";
}
@ -267,9 +265,9 @@ void BlockPrinter::PrintLoop(
}
void BlockPrinter::PrintReshapeInfo(
const std::unordered_set<const Buf*>& bufs,
const std::unordered_set<Buf*>& bufs,
bool reverse) {
for (const auto& buf : bufs) {
for (auto& buf : bufs) {
emitIndent();
os() << "reshape("
<< (reverse ? block_analysis_->getFlatInputName(buf)
@ -281,17 +279,16 @@ void BlockPrinter::PrintReshapeInfo(
}
}
void BlockPrinter::PrintDMAs(const std::unordered_set<const Buf*>& bufs) {
for (const auto& read : bufs) {
void BlockPrinter::PrintDMAs(const std::unordered_set<Buf*>& bufs) {
for (auto& read : bufs) {
emitIndent();
os() << "dma_in(";
os() << block_analysis_->getFlatInputName(read);
os() << ")" << std::endl;
}
}
void BlockPrinter::PrintAdjustBuffers(
const std::unordered_set<const Buf*>& bufs) {
for (const auto& read : bufs) {
void BlockPrinter::PrintAdjustBuffers(const std::unordered_set<Buf*>& bufs) {
for (auto& read : bufs) {
emitIndent();
os() << "adjust_buffer(";
os() << block_analysis_->getFlatInputName(read);
@ -299,16 +296,16 @@ void BlockPrinter::PrintAdjustBuffers(
}
}
void BlockPrinter::visit(const Load* v) {
void BlockPrinter::visit(Load* v) {
os() << block_analysis_->getFlatInputName(v->buf()) << ".buffer, ";
}
void BlockPrinter::visit(const Store* v) {
void BlockPrinter::visit(Store* v) {
emitIndent();
os() << *v->value() << block_analysis_->getFlatInputName(v->buf())
<< ".tensor)" << std::endl;
}
void BlockPrinter::visit(const Block* v) {
void BlockPrinter::visit(Block* v) {
os() << "{" << std::endl;
indent_++;
for (Stmt* s : v->stmts()) {
@ -338,7 +335,7 @@ void BlockCodeGen::Initialize() {
auto buf_reads = block_analysis_->loads();
auto buf_writes = block_analysis_->stores();
// Ensure all Bufs in reads/writes are in the map
std::unordered_set<const Buf*> bufs(buf_reads.begin(), buf_reads.end());
std::unordered_set<Buf*> bufs(buf_reads.begin(), buf_reads.end());
bufs.insert(buf_writes.begin(), buf_writes.end());
if (!block_analysis_->areBufsInMap(bufs)) {
throw std::runtime_error("BlockCodeGen: Entry not in input/Buffer map");

View File

@ -20,15 +20,15 @@ namespace tensorexpr {
// A class that analyzes the given program relevant for Block backend.
class BlockAnalysis : public IRVisitor {
public:
bool is_buf_store_target(const Buf* buf) const {
bool is_buf_store_target(Buf* buf) const {
return store_targets_.count(buf) > 0;
}
const std::unordered_set<const Buf*>& loads() const {
const std::unordered_set<Buf*>& loads() const {
return loads_;
}
const std::unordered_set<const Buf*>& stores() const {
const std::unordered_set<Buf*>& stores() const {
return store_targets_;
}
@ -36,64 +36,62 @@ class BlockAnalysis : public IRVisitor {
return block_size_;
}
bool areBufsInMap(const std::unordered_set<const Buf*>& bufs) const;
bool areBufsInMap(const std::unordered_set<Buf*>& bufs) const;
const Buf* getMultiDimBuf(const Buf* buf) const;
Buf* getMultiDimBuf(Buf* buf) const;
std::string getInputName(const Buf* buf) const;
std::string getInputName(Buf* buf) const;
std::string getFlatInputName(const Buf* buf) const {
std::string getFlatInputName(Buf* buf) const {
return getInputName(buf) + "_flat";
}
std::unordered_map<std::string, const Buf*> getBufferMap() const {
std::unordered_map<std::string, Buf*> getBufferMap() const {
return map_input_to_tensor_bufs_;
}
private:
void visit(const Store* v) override;
void visit(const Load* v) override;
void visit(const For* v) override;
void visit(Store* v) override;
void visit(Load* v) override;
void visit(For* v) override;
std::unordered_map<std::string, const Buf*> map_input_to_tensor_bufs_;
std::unordered_set<const Buf*> store_targets_;
std::unordered_set<const Buf*> loads_;
std::unordered_map<std::string, Buf*> map_input_to_tensor_bufs_;
std::unordered_set<Buf*> store_targets_;
std::unordered_set<Buf*> loads_;
int block_size_ = 32;
};
// A class that overrides the underlying IRPrinter to produce Block.
class BlockPrinter : public IRPrinter {
public:
BlockPrinter(std::ostream* os, const BlockAnalysis* block_analysis)
BlockPrinter(std::ostream* os, BlockAnalysis* block_analysis)
: IRPrinter(*os), block_analysis_(block_analysis) {}
using IRPrinter::name_manager;
using IRPrinter::visit;
private:
const BlockAnalysis* block_analysis_;
BlockAnalysis* block_analysis_;
std::unordered_map<std::string, int> dim_values_map;
std::vector<std::string> dim_names = {"N", "H", "W", "C"};
std::vector<std::string> flat_dim_names = {"N", "NH", "NHW", "NHWC"};
void PrintTensorInfo(const std::unordered_set<const Buf*>& bufs);
void PrintArguments(const std::unordered_set<const Buf*>& bufs);
void PrintBufferInfo(const std::unordered_set<const Buf*>& bufs);
void PrintDistribution(const std::unordered_set<const Buf*>& bufs);
void PrintLoop(
const std::unordered_set<const Buf*>& bufs,
bool block_idx = true);
void PrintTensorInfo(const std::unordered_set<Buf*>& bufs);
void PrintArguments(const std::unordered_set<Buf*>& bufs);
void PrintBufferInfo(const std::unordered_set<Buf*>& bufs);
void PrintDistribution(const std::unordered_set<Buf*>& bufs);
void PrintLoop(const std::unordered_set<Buf*>& bufs, bool block_idx = true);
void PrintReshapeInfo(
const std::unordered_set<const Buf*>& bufs,
const std::unordered_set<Buf*>& bufs,
bool reverse = false);
void PrintDMAs(const std::unordered_set<const Buf*>& bufs);
void PrintAdjustBuffers(const std::unordered_set<const Buf*>& bufs);
void PrintDMAs(const std::unordered_set<Buf*>& bufs);
void PrintAdjustBuffers(const std::unordered_set<Buf*>& bufs);
void visit(const For* v) override;
void visit(const Load* v) override;
void visit(const Store* v) override;
void visit(const Block* v) override;
void visit(const Add* v) override;
void visit(const Mul* v) override;
void visit(For* v) override;
void visit(Load* v) override;
void visit(Store* v) override;
void visit(Block* v) override;
void visit(Add* v) override;
void visit(Mul* v) override;
};
class TORCH_API BlockCodeGen : public CodeGen {

View File

@ -19,7 +19,7 @@ using namespace analysis;
template <typename Container>
BoundsInfo mergeTensorAccesses(
const Container& accesses,
const std::unordered_map<const Var*, const Buf*>& varToBuf,
const std::unordered_map<Var*, Buf*>& varToBuf,
bool distinctAccessKinds) {
BoundsInfo ret;
for (auto& access : accesses) {
@ -30,7 +30,7 @@ BoundsInfo mergeTensorAccesses(
auto vtbIt = varToBuf.find(access->var());
TORCH_INTERNAL_ASSERT(vtbIt != varToBuf.end());
const Buf* buf = vtbIt->second;
Buf* buf = vtbIt->second;
std::vector<TensorAccessBoundsInfo>& infos = ret[buf];
bool added = false;
@ -70,20 +70,20 @@ BoundsInfo mergeTensorAccesses(
return ret;
}
std::unordered_map<const Var*, const Buf*> getAllBufs(Stmt* s) {
std::unordered_map<const Var*, const Buf*> varToBuf;
std::unordered_map<Var*, Buf*> getAllBufs(Stmt* s) {
std::unordered_map<Var*, Buf*> varToBuf;
auto bufs = NodeFinder<const Buf>::find(s);
auto bufs = NodeFinder<Buf>::find(s);
for (auto* b : bufs) {
varToBuf[b->base_handle()] = b;
}
return varToBuf;
}
std::unordered_map<const Var*, const Buf*> getAllBufs(Expr* e) {
std::unordered_map<const Var*, const Buf*> varToBuf;
std::unordered_map<Var*, Buf*> getAllBufs(Expr* e) {
std::unordered_map<Var*, Buf*> varToBuf;
auto bufs = NodeFinder<const Buf>::find(e);
auto bufs = NodeFinder<Buf>::find(e);
for (auto* b : bufs) {
varToBuf[b->base_handle()] = b;
}
@ -121,7 +121,7 @@ void printBoundsInfo(const BoundsInfo& v) {
for (auto& pair : v) {
std::cerr << *pair.first << " in [";
bool first = true;
for (const auto& b : pair.second) {
for (auto& b : pair.second) {
if (!first) {
std::cerr << ", ";
}
@ -130,7 +130,7 @@ void printBoundsInfo(const BoundsInfo& v) {
if (b.start.empty()) {
std::cerr << "0";
}
for (const auto& s : b.start) {
for (auto& s : b.start) {
if (i != 0) {
std::cerr << ", ";
}
@ -142,7 +142,7 @@ void printBoundsInfo(const BoundsInfo& v) {
if (b.stop.empty()) {
std::cerr << "0";
}
for (const auto& s : b.stop) {
for (auto& s : b.stop) {
if (i != 0) {
std::cerr << ", ";
}
@ -157,15 +157,15 @@ void printBoundsInfo(const BoundsInfo& v) {
std::cerr << "}\n";
}
std::vector<const Expr*> getBoundExtents(
std::vector<Expr*> getBoundExtents(
const std::vector<TensorAccessBoundsInfo>& infos) {
std::vector<const Expr*> starts;
std::vector<const Expr*> stops;
std::vector<Expr*> starts;
std::vector<Expr*> stops;
// Find the safe size of the temporary buffer by determining the outer
// extents of a union of all bounds.
for (const TensorAccessBoundsInfo& p : infos) {
for (const auto i : c10::irange(p.start.size())) {
for (auto i : c10::irange(p.start.size())) {
if (starts.size() <= i) {
starts.push_back(p.start[i]);
} else {
@ -181,9 +181,9 @@ std::vector<const Expr*> getBoundExtents(
}
}
std::vector<const Expr*> extents;
std::vector<Expr*> extents;
for (size_t i = 0; i < starts.size(); ++i) {
const Expr* dim = IRSimplifier::simplify(
Expr* dim = IRSimplifier::simplify(
new Add(new Sub(stops[i], starts[i]), new IntImm(1)));
extents.push_back(dim);
@ -210,7 +210,7 @@ BoundSet convertBounds(
BoundSet convertBounds(
BoundsInfo& bounds,
const Buf* buf,
Buf* buf,
TensorAccessKind filter = kMutate) {
auto it = bounds.find(buf);
if (it == bounds.end()) {
@ -231,7 +231,7 @@ HazardKind getPotentialHazards(
BoundSet aReads;
for (auto& pair : bBounds) {
const Buf* buf = pair.first;
Buf* buf = pair.first;
if (aBounds.find(buf) == aBounds.end()) {
continue;
}
@ -302,18 +302,17 @@ bool hasConflictingOverlap(
const BoundsInfo& bBounds,
TensorAccessKind aFilter = kMutate,
TensorAccessKind bFilter = kMutate) {
using IndexBoundsInfo =
std::unordered_map<const Buf*, std::vector<IndexBounds>>;
using IndexBoundsInfo = std::unordered_map<Buf*, std::vector<IndexBounds>>;
IndexBoundsInfo aIndexBoundsInfo;
for (const auto& aBound : aBounds) {
for (auto& aBound : aBounds) {
aIndexBoundsInfo[aBound.first] = getIndexBounds(aBound.second, aFilter);
}
IndexBoundsInfo bIndexBoundsInfo;
for (const auto& bBound : bBounds) {
for (auto& bBound : bBounds) {
bIndexBoundsInfo[bBound.first] = getIndexBounds(bBound.second, bFilter);
}
for (const auto& aBound : aBounds) {
for (auto& aBound : aBounds) {
auto bIt = bBounds.find(aBound.first);
if (bIt == bBounds.end()) {
continue;

View File

@ -20,12 +20,12 @@ enum C10_API_ENUM TensorAccessKind { kLoad, kStore, kMutate };
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct TORCH_API TensorAccessBoundsInfo {
TensorAccessKind kind;
std::vector<const Expr*> start;
std::vector<const Expr*> stop;
std::vector<Expr*> start;
std::vector<Expr*> stop;
};
using BoundsInfo =
std::unordered_map<const Buf*, std::vector<TensorAccessBoundsInfo>>;
std::unordered_map<Buf*, std::vector<TensorAccessBoundsInfo>>;
TORCH_API BoundsInfo inferBounds(Stmt* s, bool distinctAccessKinds = true);
@ -42,7 +42,7 @@ TORCH_API BoundsInfo getInferredBounds(
TORCH_API void printBoundsInfo(const BoundsInfo& v);
TORCH_API std::vector<const Expr*> getBoundExtents(
TORCH_API std::vector<Expr*> getBoundExtents(
const std::vector<TensorAccessBoundsInfo>& infos);
// The kind of dependency found, in increasing order of exclusivity.

View File

@ -14,8 +14,8 @@ OverlapKind boundOverlap(Bound a, Bound b) {
return ContainedOrEqual;
}
const Expr* lowDiff = IRSimplifier::simplify(new Sub(a.start, b.end));
const Expr* highDiff = IRSimplifier::simplify(new Sub(b.start, a.end));
Expr* lowDiff = IRSimplifier::simplify(new Sub(a.start, b.end));
Expr* highDiff = IRSimplifier::simplify(new Sub(b.start, a.end));
if (lowDiff->isConstant() && highDiff->isConstant()) {
int low = immediateAs<int>(lowDiff);
@ -26,8 +26,8 @@ OverlapKind boundOverlap(Bound a, Bound b) {
}
}
const Expr* diff_start = IRSimplifier::simplify(new Sub(b.start, a.start));
const Expr* diff_end = IRSimplifier::simplify(new Sub(b.end, a.end));
Expr* diff_start = IRSimplifier::simplify(new Sub(b.start, a.start));
Expr* diff_end = IRSimplifier::simplify(new Sub(b.end, a.end));
// If one side fully encloses the other, they're adjacent.
if (diff_start->isConstant() && diff_end->isConstant()) {
@ -122,8 +122,8 @@ std::vector<Bound> subtractBound(Bound a, Bound b, OverlapKind overlap) {
return {a};
}
const Expr* lowDiff = IRSimplifier::simplify(new Sub(b.start, a.start));
const Expr* highDiff = IRSimplifier::simplify(new Sub(b.end, a.end));
Expr* lowDiff = IRSimplifier::simplify(new Sub(b.start, a.start));
Expr* highDiff = IRSimplifier::simplify(new Sub(b.end, a.end));
// If the diff has only a single var, we can try to guess sign.
if (!lowDiff->isConstant()) {
@ -161,8 +161,7 @@ std::vector<Bound> subtractBound(Bound a, Bound b, OverlapKind overlap) {
}
if (hasTail) {
const Expr* tailStart =
IRSimplifier::simplify(new Add(b.end, new IntImm(1)));
Expr* tailStart = IRSimplifier::simplify(new Add(b.end, new IntImm(1)));
res.emplace_back(tailStart, a.end);
}

View File

@ -13,8 +13,8 @@ namespace analysis {
// A simple class containing the start and end of a range in a single dimension.
struct TORCH_API Bound {
const Expr* start{nullptr};
const Expr* end{nullptr};
Expr* start{nullptr};
Expr* end{nullptr};
// This stores whether or not the start and end of this Bound have previously
// been swapped. This occurs when the bound is in a loop with a negative
@ -22,7 +22,7 @@ struct TORCH_API Bound {
bool swapped{false};
Bound() = default;
Bound(const Expr* s, const Expr* e) : start(s), end(e) {}
Bound(Expr* s, Expr* e) : start(s), end(e) {}
void print() const {
std::cout << "(" << *start << ", " << *end << ")";
@ -44,7 +44,7 @@ struct TORCH_API Bound {
struct BoundHash {
size_t operator()(const Bound& b) const {
return std::hash<const Expr*>()(b.start) ^ std::hash<const Expr*>()(b.end);
return std::hash<Expr*>()(b.start) ^ std::hash<Expr*>()(b.end);
}
};

View File

@ -14,7 +14,7 @@ RegisterCodeGenList::StmtFactoryMethod RegisterCodeGenList::
oss << "Invalid stmt codegen name: " << name << ". ";
oss << "Existing codegen names: [";
int index = 0;
for (const auto& entry : stmt_factory_methods_) {
for (auto& entry : stmt_factory_methods_) {
if (index != 0) {
oss << ", ";
}
@ -44,7 +44,7 @@ std::unique_ptr<CodeGen> CreateCodeGen(
return method(stmt, params, device, kernel_func_name);
}
const Expr* GenericIntrinsicsExpander::mutate(const Intrinsics* v) {
Expr* GenericIntrinsicsExpander::mutate(Intrinsics* v) {
if (v->op_type() == kSigmoid) {
auto x = v->param(0)->accept_mutator(this);
auto one = expr_to_vec(

View File

@ -108,11 +108,11 @@ class CodeGen::BufferArg {
BufferArg(const VarHandle& var) : var_(var.node()), isVar_(true) {}
BufferArg(const BufHandle& buf) : buf_(buf.node()) {}
const Var* var() const {
Var* var() const {
return isVar_ ? var_ : buf_->base_handle();
}
const Buf* buf() const {
Buf* buf() const {
return buf_;
}
@ -125,8 +125,8 @@ class CodeGen::BufferArg {
}
private:
const Var* var_ = nullptr;
const Buf* buf_ = nullptr;
Var* var_ = nullptr;
Buf* buf_ = nullptr;
bool isVar_ = false;
};
@ -226,7 +226,7 @@ TORCH_API std::unique_ptr<CodeGen> CreateCodeGen(
class TORCH_API GenericIntrinsicsExpander : public IRMutator {
protected:
const Expr* mutate(const Intrinsics* v) override;
Expr* mutate(Intrinsics* v) override;
};
} // namespace tensorexpr

View File

@ -4,12 +4,12 @@ namespace torch {
namespace jit {
namespace tensorexpr {
void CppPrinter::visit(const Allocate* alloc) {
void CppPrinter::visit(Allocate* alloc) {
constexpr size_t kAllocOnStackThresholdSize = 512;
size_t size = 1;
for (auto dim : alloc->dims()) {
const IntImm* v = dynamic_cast<const IntImm*>(dim);
IntImm* v = dynamic_cast<IntImm*>(dim);
if (v) {
size *= v->value();
} else {
@ -30,8 +30,8 @@ void CppPrinter::visit(const Allocate* alloc) {
}
}
void CppPrinter::visit(const Free* free) {
const Var* var = free->buffer_var();
void CppPrinter::visit(Free* free) {
Var* var = free->buffer_var();
if (allocated_on_heap_.count(var)) {
emitIndent();
os() << "free(" << name_manager()->get_unique_name(var) << ");"

View File

@ -14,11 +14,11 @@ class TORCH_API CppPrinter : public IRPrinter {
explicit CppPrinter(std::ostream* os) : IRPrinter(*os) {}
using IRPrinter::visit;
void visit(const Allocate*) override;
void visit(const Free*) override;
void visit(Allocate*) override;
void visit(Free*) override;
private:
std::unordered_set<const Var*> allocated_on_heap_;
std::unordered_set<Var*> allocated_on_heap_;
};
} // namespace tensorexpr

View File

@ -21,7 +21,7 @@ namespace tensorexpr {
// TODO: move this to a more shared place.
class ScopedVarName {
public:
ScopedVarName(VarNameMap* mapping, const Var* var, const std::string& name)
ScopedVarName(VarNameMap* mapping, Var* var, const std::string& name)
: mapping_(mapping), var_(var) {
auto iter = mapping->find(var);
if (iter != mapping->end()) {
@ -30,10 +30,7 @@ class ScopedVarName {
mapping->insert(std::make_pair(var, name));
}
ScopedVarName(
UniqueNameManager* manager,
const Var* var,
const std::string& name)
ScopedVarName(UniqueNameManager* manager, Var* var, const std::string& name)
: ScopedVarName(&manager->unique_name_mapping_, var, name) {}
ScopedVarName(const ScopedVarName&) = delete;
@ -45,11 +42,11 @@ class ScopedVarName {
private:
VarNameMap* mapping_ = nullptr;
const Var* var_ = nullptr;
Var* var_ = nullptr;
};
static int as_int(const Expr* expr) {
auto v = dynamic_cast<const IntImm*>(expr);
static int as_int(Expr* expr) {
auto v = dynamic_cast<IntImm*>(expr);
if (!v) {
throw malformed_input(
"cuda_codegen: non Int expr interpreted as int", expr);
@ -58,7 +55,7 @@ static int as_int(const Expr* expr) {
return v->value();
}
static bool is_zero(const Expr* expr) {
static bool is_zero(Expr* expr) {
return as_int(expr) == 0;
}
@ -123,17 +120,17 @@ std::string CudaPrinter::dtypeToCppString(const Dtype& dtype) {
}
}
void CudaAnalysis::visit(const Free* v) {
void CudaAnalysis::visit(Free* v) {
if (thread_local_bufs_.count(v->buffer_var()) == 0 &&
cross_block_bufs_.count(v->buffer_var()) == 0) {
throw std::runtime_error("Global free not supported yet");
}
}
void CudaAnalysis::visit(const Allocate* v) {
void CudaAnalysis::visit(Allocate* v) {
Stmt* p = v->get_parent();
while (p) {
const For* for_v = dynamic_cast<const For*>(p);
For* for_v = dynamic_cast<For*>(p);
if (for_v) {
// NOLINTNEXTLINE(bugprone-branch-clone)
if (for_v->loop_options().is_gpu_block_index()) {
@ -151,7 +148,7 @@ void CudaAnalysis::visit(const Allocate* v) {
throw std::runtime_error("Global alloc not supported yet");
}
void CudaAnalysis::visit(const For* v) {
void CudaAnalysis::visit(For* v) {
// Recurse first.
v->body()->accept(this);
@ -161,7 +158,7 @@ void CudaAnalysis::visit(const For* v) {
if (gpu_block_index >= 3) {
throw std::runtime_error("support only 3D gpu_block_index");
}
const Expr* prev = nullptr;
Expr* prev = nullptr;
// NOLINTNEXTLINE(clang-diagnostic-sign-compare)
// NOLINTNEXTLINE(bugprone-branch-clone)
if (gpu_block_extents_.size() <= gpu_block_index) {
@ -191,7 +188,7 @@ void CudaAnalysis::visit(const For* v) {
if (gpu_thread_index >= 3) {
throw std::runtime_error("support only 3D gpu_thread_index");
}
const Expr* prev = nullptr;
Expr* prev = nullptr;
// NOLINTNEXTLINE(clang-diagnostic-sign-compare)
// NOLINTNEXTLINE(bugprone-branch-clone)
if (gpu_thread_extents_.size() <= gpu_thread_index) {
@ -219,13 +216,13 @@ void CudaAnalysis::visit(const For* v) {
}
}
void CudaPrinter::print_flat_alloc(const Allocate* alloc) {
void CudaPrinter::print_flat_alloc(Allocate* alloc) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::vector<const Expr*> dims = alloc->dims();
std::vector<Expr*> dims = alloc->dims();
// TODO: this should be merged with the storage flattener.
int64_t flat_size = 1;
for (auto dim : dims) {
const IntImm* dim_i = dynamic_cast<const IntImm*>(dim);
IntImm* dim_i = dynamic_cast<IntImm*>(dim);
if (dim_i) {
flat_size *= dim_i->value();
} else {
@ -236,7 +233,7 @@ void CudaPrinter::print_flat_alloc(const Allocate* alloc) {
<< "[" << flat_size << "];" << std::endl;
}
void CudaPrinter::visit(const Allocate* v) {
void CudaPrinter::visit(Allocate* v) {
// TODO: handle dynamic shapes here.
if (cuda_analysis_->cross_block_bufs().count(v->buffer_var()) != 0) {
emitIndent();
@ -254,15 +251,15 @@ void CudaPrinter::visit(const Allocate* v) {
throw std::runtime_error("Encountered Alloc not local to block or thread");
}
void CudaPrinter::visit(const Free* v) {
void CudaPrinter::visit(Free* v) {
// do nothing
}
void CudaPrinter::visit(const For* v) {
void CudaPrinter::visit(For* v) {
IRPrinter::visit(v);
}
void CudaPrinter::visit(const Cast* v) {
void CudaPrinter::visit(Cast* v) {
if (v->dtype().scalar_type() == ScalarType::Half) {
os() << "__float2half(";
v->src_value()->accept(this);
@ -281,7 +278,7 @@ void CudaPrinter::visit(const Cast* v) {
os() << ")";
}
void CudaPrinter::visit(const Intrinsics* v) {
void CudaPrinter::visit(Intrinsics* v) {
if (v->op_type() == IntrinsicsOp::kRand) {
os() << "Uint32ToFloat(" << *rand_func_ << "())";
return;
@ -308,7 +305,7 @@ void CudaPrinter::visit(const Intrinsics* v) {
}
os() << func_name << "(";
for (const auto i : c10::irange(v->nparams())) {
for (auto i : c10::irange(v->nparams())) {
if (i > 0) {
os() << ", ";
}
@ -317,11 +314,11 @@ void CudaPrinter::visit(const Intrinsics* v) {
os() << ")";
}
void CudaPrinter::visit(const ExternalCall* v) {
void CudaPrinter::visit(ExternalCall* v) {
throw unimplemented_lowering(v);
}
void CudaPrinter::visit(const Load* v) {
void CudaPrinter::visit(Load* v) {
// TODO: find a better metric in using ldg or not. Support different dtypes.
// Detects whether the load target is also a store target.
// TODO: this is currently too wide. It detects whether a store-target
@ -348,7 +345,7 @@ void CudaPrinter::visit(const Load* v) {
// TODO: maybe this should be a more shared location?
// TODO: investigate how "Expr*" can be implicitly converted to "ExprHandle" as
// a bool.
static bool CheckEqual(const Expr* lhs, const Expr* rhs) {
static bool CheckEqual(Expr* lhs, Expr* rhs) {
// The fast path. Checks if the pointers are the same.
if (lhs == rhs) {
return true;
@ -362,12 +359,11 @@ class AtomicAddFuser : public IRMutator {
public:
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
AtomicAddFuser(
const std::unordered_set<const Var*>& thread_local_bufs,
const std::unordered_set<Var*>& thread_local_bufs,
const GPUMetaVarRewriter& metavars)
: thread_local_bufs_(thread_local_bufs) {
const std::vector<const Expr*>& block_extents =
metavars.gpu_block_extents();
const std::vector<const Var*>& block_vars = metavars.gpu_block_vars();
const std::vector<Expr*>& block_extents = metavars.gpu_block_extents();
const std::vector<Var*>& block_vars = metavars.gpu_block_vars();
for (size_t i = 0; i < block_extents.size(); ++i) {
MetaVarExtent extent{block_extents[i], false};
if (extent.expr->isConstant() && immediateEquals(extent.expr, 1)) {
@ -378,9 +374,8 @@ class AtomicAddFuser : public IRMutator {
metavars_[block_vars[i]] = extent;
}
const std::vector<const Expr*>& thread_extents =
metavars.gpu_thread_extents();
const std::vector<const Var*>& thread_vars = metavars.gpu_thread_vars();
const std::vector<Expr*>& thread_extents = metavars.gpu_thread_extents();
const std::vector<Var*>& thread_vars = metavars.gpu_thread_vars();
for (size_t i = 0; i < thread_extents.size(); ++i) {
MetaVarExtent extent{thread_extents[i], false};
if (extent.expr->isConstant() && immediateEquals(extent.expr, 1)) {
@ -392,8 +387,8 @@ class AtomicAddFuser : public IRMutator {
}
}
Stmt* mutate(const Store* v) override {
const Buf* buf = v->buf();
Stmt* mutate(Store* v) override {
Buf* buf = v->buf();
Store* orig = const_cast<Store*>(v); // NOLINT
// Thread locals never need to be atomic.
@ -405,11 +400,11 @@ class AtomicAddFuser : public IRMutator {
if (dtype != ScalarType::Float && dtype != ScalarType::Double) {
return orig;
}
const Add* add_v = dynamic_cast<const Add*>(v->value());
Add* add_v = dynamic_cast<Add*>(v->value());
if (!add_v) {
return orig;
}
const Load* load_v = dynamic_cast<const Load*>(add_v->lhs());
Load* load_v = dynamic_cast<Load*>(add_v->lhs());
if (!load_v) {
return orig;
}
@ -427,9 +422,9 @@ class AtomicAddFuser : public IRMutator {
// TODO: this checks that the metavars occur directly as an index, but this
// is pessimistic, blockIdx.x + 1 is fine too if there is no overlapping.
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::unordered_set<const Var*> vars_to_find = nontrivial_metavars_;
for (const Expr* e : v->indices()) {
if (const Var* v = dynamic_cast<const Var*>(e)) {
std::unordered_set<Var*> vars_to_find = nontrivial_metavars_;
for (Expr* e : v->indices()) {
if (Var* v = dynamic_cast<Var*>(e)) {
vars_to_find.erase(v);
}
}
@ -443,16 +438,16 @@ class AtomicAddFuser : public IRMutator {
}
private:
const std::unordered_set<const Var*>& thread_local_bufs_;
const std::unordered_set<Var*>& thread_local_bufs_;
struct MetaVarExtent {
const Expr* expr{nullptr};
Expr* expr{nullptr};
bool trivial{false};
};
std::unordered_map<const Var*, MetaVarExtent> metavars_;
std::unordered_set<const Var*> nontrivial_metavars_;
std::unordered_map<Var*, MetaVarExtent> metavars_;
std::unordered_set<Var*> nontrivial_metavars_;
};
void CudaPrinter::visit(const Store* v) {
void CudaPrinter::visit(Store* v) {
emitIndent();
if (v->indices().empty()) {
os() << *v->base_handle() << " = ";
@ -463,7 +458,7 @@ void CudaPrinter::visit(const Store* v) {
os() << std::endl;
}
void CudaPrinter::visit(const AtomicAdd* v) {
void CudaPrinter::visit(AtomicAdd* v) {
emitIndent();
if (cuda_analysis_->thread_local_bufs().count(v->base_handle()) > 0) {
// atomicAdd only works on global and shared memory
@ -476,7 +471,7 @@ void CudaPrinter::visit(const AtomicAdd* v) {
os() << std::endl;
}
void CudaPrinter::visit(const Max* v) {
void CudaPrinter::visit(Max* v) {
if (v->dtype().is_integral()) {
os() << "max(";
} else {
@ -488,7 +483,7 @@ void CudaPrinter::visit(const Max* v) {
os() << ")";
}
void CudaPrinter::visit(const Min* v) {
void CudaPrinter::visit(Min* v) {
if (v->dtype().is_integral()) {
os() << "min(";
} else {
@ -500,7 +495,7 @@ void CudaPrinter::visit(const Min* v) {
os() << ")";
}
void CudaPrinter::visit(const IfThenElse* v) {
void CudaPrinter::visit(IfThenElse* v) {
os() << "((";
v->condition()->accept(this);
os() << ") ? ";
@ -510,7 +505,7 @@ void CudaPrinter::visit(const IfThenElse* v) {
os() << ")";
}
void CudaPrinter::visit(const Block* v) {
void CudaPrinter::visit(Block* v) {
os() << "{" << std::endl;
indent_++;
@ -523,7 +518,7 @@ void CudaPrinter::visit(const Block* v) {
os() << "}";
}
void CudaPrinter::visit(const Let* v) {
void CudaPrinter::visit(Let* v) {
emitIndent();
os() << dtypeToCppString(v->dtype());
os() << " " << *v->var() << " = ";
@ -534,7 +529,7 @@ void CudaPrinter::visit(const Let* v) {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
class PrioritizeLoad : public IRMutator {
public:
const Expr* mutate(const Load* v) override {
Expr* mutate(Load* v) override {
// Look at the declaration of this variable for more details.
if (nested_if_then_else_ > 0) {
return IRMutator::mutate(v);
@ -569,17 +564,17 @@ class PrioritizeLoad : public IRMutator {
}
MemLoadList& load_list = load_stack_.back();
const Var* load_new_var = new Var("v", v->dtype());
const Expr* new_value = IRMutator::mutate(v);
Var* load_new_var = new Var("v", v->dtype());
Expr* new_value = IRMutator::mutate(v);
load_list.push_back(std::make_pair(load_new_var, new_value));
return load_new_var;
}
const Expr* mutate(const Cast* v) override {
const Load* src_load = dynamic_cast<const Load*>(v->src_value());
const Expr* new_src = v->src_value()->accept_mutator(this);
const Var* new_var = dynamic_cast<const Var*>(new_src);
Expr* mutate(Cast* v) override {
Load* src_load = dynamic_cast<Load*>(v->src_value());
Expr* new_src = v->src_value()->accept_mutator(this);
Var* new_var = dynamic_cast<Var*>(new_src);
if (!src_load || !new_var) {
return new Cast(v->dtype(), new_src);
}
@ -593,27 +588,27 @@ class PrioritizeLoad : public IRMutator {
new_var = new Var("v", v->dtype());
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
const Expr* new_value = new Cast(v->dtype(), pair.second);
Expr* new_value = new Cast(v->dtype(), pair.second);
load_list.push_back(std::make_pair(new_var, new_value));
return new_var;
}
Stmt* mutate(const Store* v) override {
const Store* last = nested_store_;
Stmt* mutate(Store* v) override {
Store* last = nested_store_;
nested_store_ = v;
Stmt* s = IRMutator::mutate(v);
nested_store_ = last;
return s;
}
Stmt* mutate(const Let* v) override {
Stmt* mutate(Let* v) override {
nested_let_ = true;
Stmt* s = IRMutator::mutate(v);
nested_let_ = false;
return s;
}
Stmt* mutate(const Block* v) override {
Stmt* mutate(Block* v) override {
Block* v1 = const_cast<Block*>(v); // NOLINT
assert(v1);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
@ -633,15 +628,15 @@ class PrioritizeLoad : public IRMutator {
return v1;
}
const Expr* mutate(const IfThenElse* v) override {
Expr* mutate(IfThenElse* v) override {
nested_if_then_else_++;
const Expr* new_v = IRMutator::mutate(v);
Expr* new_v = IRMutator::mutate(v);
nested_if_then_else_--;
return new_v;
}
private:
using MemLoadEntry = std::pair<const Var*, const Expr*>;
using MemLoadEntry = std::pair<Var*, Expr*>;
using MemLoadList = std::vector<MemLoadEntry>;
using MemoryLoadStack = std::vector<MemLoadList>;
@ -659,7 +654,7 @@ class PrioritizeLoad : public IRMutator {
return;
}
for (const auto& pair : load_list) {
for (auto& pair : load_list) {
Stmt* news = new Let(pair.first, pair.second);
block->insert_stmt_before(news, last);
}
@ -678,9 +673,9 @@ class PrioritizeLoad : public IRMutator {
// }
// int v2 = v + 2;
int nested_if_then_else_{0};
const Store* nested_store_{nullptr};
Store* nested_store_{nullptr};
bool nested_let_{false};
std::unordered_set<const Var*> thread_local_bufs_;
std::unordered_set<Var*> thread_local_bufs_;
};
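
For context, PrioritizeLoad hoists every Load into a fresh scalar Let binding so the emitted CUDA performs its memory reads before the surrounding arithmetic. A minimal before/after sketch of the generated code (the names v1/v2 are illustrative, not the exact names the pass picks):

// before: loads buried inside the expression
a[i] = b[i] + c[i];

// after: each load hoisted into its own let
float v1 = b[i];
float v2 = c[i];
a[i] = v1 + v2;
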
std::string CudaCodeGen::GetUniqueFuncName(const std::string& func_prefix) {
@ -716,9 +711,9 @@ bool GPUMetaVarRewriter::isFullExtent() {
return true;
}
Stmt* GPUMetaVarRewriter::mutate(const For* v) {
Stmt* GPUMetaVarRewriter::mutate(For* v) {
Stmt* body = v->body();
const Expr* old_reach = nullptr;
Expr* old_reach = nullptr;
const LoopOptions& loop_options = v->loop_options();
if (loop_options.is_gpu_block_index()) {
int gpu_block_index = loop_options.gpu_block_index();
@ -737,7 +732,7 @@ Stmt* GPUMetaVarRewriter::mutate(const For* v) {
}
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
const Var* metaVar = gpu_block_vars_[gpu_block_index];
Var* metaVar = gpu_block_vars_[gpu_block_index];
body = Substitute(Stmt::clone(body), {{v->var(), metaVar}});
} else if (loop_options.is_gpu_thread_index()) {
int gpu_thread_index = loop_options.gpu_thread_index();
@ -756,7 +751,7 @@ Stmt* GPUMetaVarRewriter::mutate(const For* v) {
}
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
const Var* metaVar = gpu_thread_vars_[gpu_thread_index];
Var* metaVar = gpu_thread_vars_[gpu_thread_index];
body = Substitute(Stmt::clone(body), {{v->var(), metaVar}});
}
@ -776,7 +771,7 @@ Stmt* GPUMetaVarRewriter::mutate(const For* v) {
return v->cloneWithNewBody(body);
}
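
Sketching the effect of this rewrite on a loop whose axis was bound to a GPU block index (illustrative source; the guard is only emitted later, by the Block mutation, when the loop extent does not fill the launch extent):

// before: an explicit loop bound to gpu_block_index 0
for (int i = 0; i < N; i++) {
  out[i] = in[i];
}

// after: the loop is gone and i is replaced by the block metavar
if (blockIdx.x < N) {
  out[blockIdx.x] = in[blockIdx.x];
}
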
Stmt* GPUMetaVarRewriter::mutate(const Block* v) {
Stmt* GPUMetaVarRewriter::mutate(Block* v) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::vector<Segment> innerSegments;
Segment current;
@ -891,7 +886,7 @@ Stmt* GPUMetaVarRewriter::mutate(const Block* v) {
static std::ostream& operator<<(
std::ostream& out,
const std::vector<const Expr*>& exprs) {
const std::vector<Expr*>& exprs) {
size_t i = 0;
for (auto expr : exprs) {
if (i++ > 0) {
@ -983,7 +978,7 @@ void CudaCodeGen::Initialize() {
os() << ", ";
}
const BufferArg& buffer_arg = buffer_args[i];
const Var* var = buffer_arg.var();
Var* var = buffer_arg.var();
Dtype dtype = buffer_arg.dtype();
os() << printer_->dtypeToCppString(dtype)
@ -991,9 +986,9 @@ void CudaCodeGen::Initialize() {
<< name_manager()->get_unique_name(var);
}
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
const Var* rand_seed;
Var* rand_seed;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
const Var* rand_offset;
Var* rand_offset;
if (has_random_) {
// TODO: switch to kUint64 when it is available.
rand_seed = new Var("rand_seed", kInt);
@ -1006,11 +1001,11 @@ void CudaCodeGen::Initialize() {
os() << std::endl;
if (has_random_) {
const Var* idx = new Var("idx", kInt);
Var* idx = new Var("idx", kInt);
os() << "int " << *idx << " = blockIdx.x*blockDim.x + threadIdx.x;"
<< std::endl;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
const Var* rand_func = printer_->rand_func();
Var* rand_func = printer_->rand_func();
os() << "Philox " << *rand_func << "(" << *rand_seed << ", " << *idx << ", "
<< *rand_offset << ");" << std::endl;
os() << std::endl;
@ -1041,7 +1036,7 @@ void CudaCodeGen::Initialize() {
os() << "}";
// Check that all block extents have been set.
const std::vector<const Expr*>& gpu_block_extents =
const std::vector<Expr*>& gpu_block_extents =
metavar_rewriter_->gpu_block_extents();
for (size_t i = 0; i < gpu_block_extents.size(); i++) {
if (!gpu_block_extents[i]) {
@ -1067,9 +1062,9 @@ void CudaCodeGen::call_raw(const std::vector<void*>& raw_args) {
auto const& buffer_args = this->buffer_args();
// TODO: move as much of this into the constructors.
const std::vector<const Expr*>& gpu_block_extents =
const std::vector<Expr*>& gpu_block_extents =
metavar_rewriter_->gpu_block_extents();
const std::vector<const Expr*>& gpu_thread_extents =
const std::vector<Expr*>& gpu_thread_extents =
metavar_rewriter_->gpu_thread_extents();
if (gpu_block_extents.size() > 3 || gpu_thread_extents.size() > 3) {
throw malformed_input(
@ -1148,7 +1143,7 @@ void CudaCodeGen::call_raw(const std::vector<void*>& raw_args) {
ptr_to_args[buffer_args.size() + 1] = &rand_offset;
}
const auto prior_device = at::cuda::current_device();
auto prior_device = at::cuda::current_device();
if (prior_device != this->device().index()) {
at::cuda::set_device(this->device().index());
}
@ -1207,7 +1202,7 @@ void CudaCodeGen::CompileToNVRTC(
AT_CUDA_DRIVER_CHECK(nvrtc().cuCtxGetCurrent(&pctx));
// Note: we switch devices manually since at::DeviceGuard was failing to
// work properly in some scenarios
const auto prior_device = at::cuda::current_device();
auto prior_device = at::cuda::current_device();
if (prior_device != this->device().index()) {
at::cuda::set_device(this->device().index());
}
@ -1259,8 +1254,7 @@ void CudaCodeGen::CompileToNVRTC(
"--std=c++14", compute.c_str(), "-default-device"};
#endif
const auto result =
nvrtc().nvrtcCompileProgram(program, args.size(), args.data());
auto result = nvrtc().nvrtcCompileProgram(program, args.size(), args.data());
if (result != NVRTC_SUCCESS) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
size_t logsize;
@ -1284,15 +1278,14 @@ void CudaCodeGen::CompileToNVRTC(
#if CUDA_VERSION >= 11010
// compile_to_sass determines whether we are generating SASS or PTX, hence
// the different API.
const auto getSize = compile_to_sass
auto getSize = compile_to_sass
? at::globalContext().getNVRTC().nvrtcGetCUBINSize
: at::globalContext().getNVRTC().nvrtcGetPTXSize;
const auto getFunc = compile_to_sass
? at::globalContext().getNVRTC().nvrtcGetCUBIN
: at::globalContext().getNVRTC().nvrtcGetPTX;
auto getFunc = compile_to_sass ? at::globalContext().getNVRTC().nvrtcGetCUBIN
: at::globalContext().getNVRTC().nvrtcGetPTX;
#else
const auto getSize = at::globalContext().getNVRTC().nvrtcGetPTXSize;
const auto getFunc = at::globalContext().getNVRTC().nvrtcGetPTX;
auto getSize = at::globalContext().getNVRTC().nvrtcGetPTXSize;
auto getFunc = at::globalContext().getNVRTC().nvrtcGetPTX;
#endif
AT_CUDA_NVRTC_CHECK(getSize(program, &ptx_size));
ptx.resize(ptx_size);
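
For reference, the underlying NVRTC flow behind these wrappers follows the usual create/compile/size/fetch pattern; a standalone sketch against the raw NVRTC API (error checks elided, not the nvrtc() wrapper used above):

#include <nvrtc.h>
#include <vector>

std::vector<char> compile_to_ptx(const char* source) {
  nvrtcProgram prog;
  nvrtcCreateProgram(&prog, source, "kernel.cu", 0, nullptr, nullptr);
  const char* opts[] = {"--std=c++14"};
  nvrtcCompileProgram(prog, 1, opts);
  size_t ptx_size;
  nvrtcGetPTXSize(prog, &ptx_size);   // query the size first
  std::vector<char> ptx(ptx_size);
  nvrtcGetPTX(prog, ptx.data());      // then fetch the PTX
  nvrtcDestroyProgram(&prog);
  return ptx;
}
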


@ -26,41 +26,41 @@ class CudaAnalysis : public IRVisitor {
gpu_block_extents_ = {new IntImm(1), new IntImm(1), new IntImm(1)};
gpu_thread_extents_ = {new IntImm(1), new IntImm(1), new IntImm(1)};
}
bool is_buf_store_target(const Buf* buf) const {
bool is_buf_store_target(Buf* buf) const {
return store_targets_.count(buf) > 0;
}
const std::unordered_set<const Var*>& thread_local_bufs() const {
const std::unordered_set<Var*>& thread_local_bufs() const {
return thread_local_bufs_;
}
const std::unordered_set<const Var*>& cross_block_bufs() const {
const std::unordered_set<Var*>& cross_block_bufs() const {
return cross_block_bufs_;
}
const std::vector<const Expr*>& gpu_block_extents() const {
const std::vector<Expr*>& gpu_block_extents() const {
return gpu_block_extents_;
}
const std::vector<const Expr*>& gpu_thread_extents() const {
const std::vector<Expr*>& gpu_thread_extents() const {
return gpu_thread_extents_;
}
private:
void visit(const Store* v) override {
void visit(Store* v) override {
store_targets_.insert(v->buf());
}
void visit(const Allocate* v) override;
void visit(const Free* v) override;
void visit(const For* v) override;
void visit(Allocate* v) override;
void visit(Free* v) override;
void visit(For* v) override;
std::unordered_set<const Buf*> store_targets_;
std::unordered_set<const Var*> thread_local_bufs_;
std::unordered_set<const Var*> cross_block_bufs_;
std::unordered_set<Buf*> store_targets_;
std::unordered_set<Var*> thread_local_bufs_;
std::unordered_set<Var*> cross_block_bufs_;
std::vector<const Expr*> gpu_block_extents_;
std::vector<const Expr*> gpu_thread_extents_;
std::vector<Expr*> gpu_block_extents_;
std::vector<Expr*> gpu_thread_extents_;
};
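
In use, the analysis is run over the root statement before printing, and the rewriter/printer consume its results; roughly (stmt is assumed to be the kernel's root Stmt*):

// hypothetical driver: run the analysis, then read back the extents
CudaAnalysis analysis;
stmt->accept(&analysis);
const std::vector<Expr*>& block_extents = analysis.gpu_block_extents();
const std::vector<Expr*>& thread_extents = analysis.gpu_thread_extents();
// both vectors always hold three entries, defaulted to IntImm(1)
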
// An IRMutator that replaces binding loop options with Cuda metavars, and masks
@ -87,22 +87,22 @@ class GPUMetaVarRewriter : public IRMutator {
current_thread_reach_ = {new IntImm(1), new IntImm(1), new IntImm(1)};
}
Stmt* mutate(const For* v) override;
Stmt* mutate(const Block* v) override;
Stmt* mutate(For* v) override;
Stmt* mutate(Block* v) override;
const std::vector<const Var*>& gpu_block_vars() const {
const std::vector<Var*>& gpu_block_vars() const {
return gpu_block_vars_;
}
const std::vector<const Var*>& gpu_thread_vars() const {
const std::vector<Var*>& gpu_thread_vars() const {
return gpu_thread_vars_;
}
const std::vector<const Expr*>& gpu_block_extents() const {
const std::vector<Expr*>& gpu_block_extents() const {
return cuda_analysis_->gpu_block_extents();
}
const std::vector<const Expr*>& gpu_thread_extents() const {
const std::vector<Expr*>& gpu_thread_extents() const {
return cuda_analysis_->gpu_thread_extents();
}
@ -136,11 +136,11 @@ class GPUMetaVarRewriter : public IRMutator {
// parameters.
bool isFullExtent();
std::vector<const Var*> gpu_block_vars_;
std::vector<const Var*> gpu_thread_vars_;
std::vector<Var*> gpu_block_vars_;
std::vector<Var*> gpu_thread_vars_;
std::vector<const Expr*> current_block_reach_;
std::vector<const Expr*> current_thread_reach_;
std::vector<Expr*> current_block_reach_;
std::vector<Expr*> current_thread_reach_;
const CudaAnalysis* cuda_analysis_;
};
@ -158,24 +158,24 @@ class CudaPrinter : public IRPrinter {
}
}
void visit(const Cast* v) override;
void visit(const Intrinsics* v) override;
void visit(const For* v) override;
void visit(Cast* v) override;
void visit(Intrinsics* v) override;
void visit(For* v) override;
void visit(const Load* v) override;
void visit(const Store* v) override;
void visit(const AtomicAdd* v) override;
void visit(const Max* v) override;
void visit(const Min* v) override;
void visit(const IfThenElse* v) override;
void visit(const Block* v) override;
void visit(const Allocate* v) override;
void visit(const Free* v) override;
void visit(const Let* v) override;
void visit(Load* v) override;
void visit(Store* v) override;
void visit(AtomicAdd* v) override;
void visit(Max* v) override;
void visit(Min* v) override;
void visit(IfThenElse* v) override;
void visit(Block* v) override;
void visit(Allocate* v) override;
void visit(Free* v) override;
void visit(Let* v) override;
void visit(const ExternalCall* v) override;
void visit(ExternalCall* v) override;
const Var* rand_func() const {
Var* rand_func() const {
return rand_func_;
}
@ -185,10 +185,10 @@ class CudaPrinter : public IRPrinter {
using IRPrinter::visit;
private:
const Var* rand_func_;
Var* rand_func_;
const CudaAnalysis* cuda_analysis_;
void print_flat_alloc(const Allocate* alloc);
void print_flat_alloc(Allocate* alloc);
};
// Construct Cuda C from the buffer and tensor input, and invoke the kernel
@ -233,11 +233,11 @@ class TORCH_CUDA_CU_API CudaCodeGen : public CodeGen {
c10::optional<c10::Device> device_opt,
c10::optional<bool> pin_memory_opt) override;
const std::vector<const Expr*>& gpu_block_extents() const {
const std::vector<Expr*>& gpu_block_extents() const {
return cuda_analysis_->gpu_block_extents();
}
const std::vector<const Expr*>& gpu_thread_extents() const {
const std::vector<Expr*>& gpu_thread_extents() const {
return cuda_analysis_->gpu_thread_extents();
}


@ -59,14 +59,14 @@ class SimpleIREvaluatorImpl : public IRVisitor {
~SimpleIREvaluatorImpl() override = default;
void bindBuf(const Buf* buf, void* ptr) {
void bindBuf(Buf* buf, void* ptr) {
buffer_mapping_[buf] = ptr;
}
void bindVar(const Var* var, const Value& val) {
void bindVar(Var* var, const Value& val) {
eval_context_[var] = val;
}
Value evaluateExpr(const Expr* e) {
Value evaluateExpr(Expr* e) {
e->accept(this);
return value_;
}
@ -81,45 +81,45 @@ class SimpleIREvaluatorImpl : public IRVisitor {
internal_buffers_.clear();
}
TORCH_API void visit(const Add* v) override {
TORCH_API void visit(Add* v) override {
visit_binary_op(v);
}
TORCH_API void visit(const Sub* v) override {
TORCH_API void visit(Sub* v) override {
visit_binary_op(v);
}
TORCH_API void visit(const Mul* v) override {
TORCH_API void visit(Mul* v) override {
visit_binary_op(v);
}
TORCH_API void visit(const Div* v) override {
TORCH_API void visit(Div* v) override {
visit_binary_op(v);
}
TORCH_API void visit(const Mod* v) override {
TORCH_API void visit(Mod* v) override {
visit_binary_op(v);
}
TORCH_API void visit(const Max* v) override {
TORCH_API void visit(Max* v) override {
visit_binary_op(v, v->propagate_nans());
}
TORCH_API void visit(const Min* v) override {
TORCH_API void visit(Min* v) override {
visit_binary_op(v, v->propagate_nans());
}
TORCH_API void visit(const And* v) override {
TORCH_API void visit(And* v) override {
visit_binary_op(v);
}
TORCH_API void visit(const Or* v) override {
TORCH_API void visit(Or* v) override {
visit_binary_op(v);
}
TORCH_API void visit(const Xor* v) override {
TORCH_API void visit(Xor* v) override {
visit_binary_op(v);
}
TORCH_API void visit(const Lshift* v) override {
TORCH_API void visit(Lshift* v) override {
visit_binary_op(v);
}
TORCH_API void visit(const Rshift* v) override {
TORCH_API void visit(Rshift* v) override {
visit_binary_op(v);
}
void visit(const CompareSelect* v) override {
void visit(CompareSelect* v) override {
visit_compare_select_op(v, v->compare_select_op());
}
@ -156,7 +156,7 @@ class SimpleIREvaluatorImpl : public IRVisitor {
std::vector<T> lhs_v = lhs.as_vec<T>();
std::vector<T> rhs_v = rhs.as_vec<T>();
std::vector<T> result_v(lhs_v.size());
for (const auto i : c10::irange(lhs_v.size())) {
for (auto i : c10::irange(lhs_v.size())) {
switch (op_type) {
case IRNodeType::kAdd:
result_v[i] = lhs_v[i] + rhs_v[i];
@ -195,7 +195,7 @@ class SimpleIREvaluatorImpl : public IRVisitor {
std::vector<T> lhs_v = lhs.as_vec<T>();
std::vector<T> rhs_v = rhs.as_vec<T>();
std::vector<T> result_v(lhs_v.size());
for (const auto i : c10::irange(lhs_v.size())) {
for (auto i : c10::irange(lhs_v.size())) {
switch (op_type) {
case IRNodeType::kAnd:
result_v[i] = lhs_v[i] & rhs_v[i];
@ -222,7 +222,7 @@ class SimpleIREvaluatorImpl : public IRVisitor {
std::vector<T> lhs_v = lhs.as_vec<T>();
std::vector<T> rhs_v = rhs.as_vec<T>();
std::vector<T> result_v(lhs_v.size());
for (const auto i : c10::irange(lhs_v.size())) {
for (auto i : c10::irange(lhs_v.size())) {
switch (op_type) {
case IRNodeType::kLshift: {
typename std::make_unsigned<T>::type a =
@ -253,7 +253,7 @@ class SimpleIREvaluatorImpl : public IRVisitor {
std::vector<R> ret_val1_v = retval1.as_vec<R>();
std::vector<R> ret_val2_v = retval2.as_vec<R>();
std::vector<R> result_v(lhs_v.size());
for (const auto i : c10::irange(lhs_v.size())) {
for (auto i : c10::irange(lhs_v.size())) {
switch (cmp_op) {
case CompareSelectOperation::kEQ:
result_v[i] = (lhs_v[i] == rhs_v[i]) ? ret_val1_v[i] : ret_val2_v[i];
@ -282,7 +282,7 @@ class SimpleIREvaluatorImpl : public IRVisitor {
}
template <typename Op>
void visit_binary_op(const BinaryOpNode<Op>* v, bool option = false) {
void visit_binary_op(BinaryOpNode<Op>* v, bool option = false) {
v->lhs()->accept(this);
Value lhs_v = value_;
v->rhs()->accept(this);
@ -365,7 +365,7 @@ class SimpleIREvaluatorImpl : public IRVisitor {
}
void visit_compare_select_op(
const CompareSelect* v,
CompareSelect* v,
CompareSelectOperation cmp_op) {
v->lhs()->accept(this);
Value lhs_v = value_;
@ -401,8 +401,8 @@ class SimpleIREvaluatorImpl : public IRVisitor {
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_VISIT);
#undef IMM_VISIT
TORCH_API void visit(const Block* v) override {
const Block* last = scope_;
TORCH_API void visit(Block* v) override {
Block* last = scope_;
scope_ = v;
for (Stmt* s : v->stmts()) {
s->accept(this);
@ -410,7 +410,7 @@ class SimpleIREvaluatorImpl : public IRVisitor {
auto it = var_by_scope_.find(v);
if (it != var_by_scope_.end()) {
for (const Expr* v : it->second) {
for (Expr* v : it->second) {
eval_context_.erase(v);
}
var_by_scope_.erase(it);
@ -419,7 +419,7 @@ class SimpleIREvaluatorImpl : public IRVisitor {
scope_ = last;
}
TORCH_API void visit(const Var* v) override {
TORCH_API void visit(Var* v) override {
auto iter = eval_context_.find(v);
if (iter == eval_context_.end()) {
throw malformed_input("could not find Var in context", v);
@ -456,8 +456,8 @@ class SimpleIREvaluatorImpl : public IRVisitor {
}
}
TORCH_API void visit(const Cast* v) override {
const Expr* src_value = v->src_value();
TORCH_API void visit(Cast* v) override {
Expr* src_value = v->src_value();
src_value->accept(this);
Dtype dst_dtype = v->dtype();
Dtype src_dtype = src_value->dtype();
@ -507,8 +507,8 @@ class SimpleIREvaluatorImpl : public IRVisitor {
}
}
TORCH_API void visit(const BitCast* v) override {
const Expr* src_value = v->src_value();
TORCH_API void visit(BitCast* v) override {
Expr* src_value = v->src_value();
src_value->accept(this);
Dtype dst_dtype = v->dtype();
Dtype src_dtype = src_value->dtype();
@ -530,8 +530,8 @@ class SimpleIREvaluatorImpl : public IRVisitor {
}
}
TORCH_API void visit(const For* v) override {
const Expr* var_node = v->var();
TORCH_API void visit(For* v) override {
Expr* var_node = v->var();
v->start()->accept(this);
int start = value_.as<int>();
v->stop()->accept(this);
@ -549,7 +549,7 @@ class SimpleIREvaluatorImpl : public IRVisitor {
eval_context_.erase(var_node);
}
TORCH_API void visit(const Ramp* v) override {
TORCH_API void visit(Ramp* v) override {
v->base()->accept(this);
int base = value().as<int>();
v->stride()->accept(this);
@ -557,14 +557,14 @@ class SimpleIREvaluatorImpl : public IRVisitor {
int lanes = v->lanes();
std::vector<int> values(lanes);
for (const auto i : c10::irange(lanes)) {
for (auto i : c10::irange(lanes)) {
values[i] = base + i * stride;
}
value_ = Value(values);
}
TORCH_API void visit(const Broadcast* v) override {
TORCH_API void visit(Broadcast* v) override {
v->value()->accept(this);
Value value = this->value();
int lanes = v->lanes();
@ -581,7 +581,7 @@ class SimpleIREvaluatorImpl : public IRVisitor {
}
}
TORCH_API void visit(const IfThenElse* v) override {
TORCH_API void visit(IfThenElse* v) override {
v->condition()->accept(this);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
bool cond_v;
@ -605,26 +605,26 @@ class SimpleIREvaluatorImpl : public IRVisitor {
}
}
TORCH_API void visit(const Load* v) override {
TORCH_API void visit(Load* v) override {
auto iter = buffer_mapping_.find(v->buf());
if (iter == buffer_mapping_.end()) {
throw malformed_input("could not find base node in Load", v);
}
void* ptr = iter->second;
const Expr* flat_idx = flatten_index(v->buf()->dims(), v->indices());
Expr* flat_idx = flatten_index(v->buf()->dims(), v->indices());
flat_idx->accept(this);
std::vector<int> index = value().as_vec<int>();
ScalarType v_sdtype = v->dtype().scalar_type();
switch (v_sdtype) {
#define TYPE_CASE(Type, Name) \
case ScalarType::Name: { \
Type* ptr##Name = static_cast<Type*>(ptr); \
std::vector<Type> v(index.size()); \
for (const auto i : c10::irange(index.size())) { \
v[i] = ptr##Name[index[i]]; \
} \
value_ = Value(v); \
#define TYPE_CASE(Type, Name) \
case ScalarType::Name: { \
Type* ptr##Name = static_cast<Type*>(ptr); \
std::vector<Type> v(index.size()); \
for (auto i : c10::irange(index.size())) { \
v[i] = ptr##Name[index[i]]; \
} \
value_ = Value(v); \
} break;
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE);
#undef TYPE_CASE
@ -633,7 +633,7 @@ class SimpleIREvaluatorImpl : public IRVisitor {
}
}
TORCH_API void visit(const Store* v) override {
TORCH_API void visit(Store* v) override {
auto iter = buffer_mapping_.find(v->buf());
if (iter == buffer_mapping_.end()) {
throw malformed_input("could not find base node in Store", v);
@ -641,7 +641,7 @@ class SimpleIREvaluatorImpl : public IRVisitor {
void* ptr = iter->second;
const Expr* flat_idx = flatten_index(v->buf()->dims(), v->indices());
Expr* flat_idx = flatten_index(v->buf()->dims(), v->indices());
flat_idx->accept(this);
std::vector<int> index = value().as_vec<int>();
ScalarType v_sdtype = v->value()->dtype().scalar_type();
@ -655,7 +655,7 @@ class SimpleIREvaluatorImpl : public IRVisitor {
throw malformed_input("value size mismatch in Store", v); \
} \
Type* ptr##Name = static_cast<Type*>(ptr); \
for (const auto i : c10::irange(index.size())) { \
for (auto i : c10::irange(index.size())) { \
ptr##Name[index[i]] = value[i]; \
} \
} break;
@ -666,13 +666,13 @@ class SimpleIREvaluatorImpl : public IRVisitor {
}
}
void visit(const ExternalCall* v) override {
void visit(ExternalCall* v) override {
auto& func_registry = getNNCFunctionRegistry();
if (!func_registry.count(v->func_name())) {
throw unimplemented_lowering(v);
}
std::vector<const Buf*> bufs(v->buf_args());
std::vector<Buf*> bufs(v->buf_args());
bufs.insert(bufs.begin(), v->buf());
std::vector<void*> buf_ptrs;
@ -681,7 +681,7 @@ class SimpleIREvaluatorImpl : public IRVisitor {
std::vector<int8_t> buf_dtypes;
std::vector<int64_t> extra_args;
for (const Buf* b : bufs) {
for (Buf* b : bufs) {
auto iter = buffer_mapping_.find(b);
if (iter == buffer_mapping_.end()) {
throw malformed_input("could not find buf", v);
@ -690,12 +690,12 @@ class SimpleIREvaluatorImpl : public IRVisitor {
buf_ptrs.push_back(iter->second);
buf_ranks.push_back(b->dims().size());
buf_dtypes.push_back((int8_t)b->dtype().scalar_type());
for (const Expr* dim_expr : b->dims()) {
for (Expr* dim_expr : b->dims()) {
dim_expr->accept(this);
buf_dims.push_back(value().as<int>());
}
}
for (const Expr* a : v->args()) {
for (Expr* a : v->args()) {
a->accept(this);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int64_t val;
@ -722,9 +722,9 @@ class SimpleIREvaluatorImpl : public IRVisitor {
}
template <typename TReturn, typename TInput>
void visit_intrinsics_helper(const Intrinsics* v) {
void visit_intrinsics_helper(Intrinsics* v) {
std::vector<Value> values(v->nparams());
for (const auto i : c10::irange(v->nparams())) {
for (auto i : c10::irange(v->nparams())) {
v->param(i)->accept(this);
values[i] = this->value();
}
@ -746,18 +746,18 @@ class SimpleIREvaluatorImpl : public IRVisitor {
std::vector<TReturn> result(v1.size(), -1);
if (values.size() == 1ULL) {
for (const auto i : c10::irange(v1.size())) {
for (auto i : c10::irange(v1.size())) {
result[i] = compute_intrinsics<TReturn>(v->op_type(), v1[i]);
}
} else {
for (const auto i : c10::irange(v1.size())) {
for (auto i : c10::irange(v1.size())) {
result[i] = compute_intrinsics<TReturn>(v->op_type(), v1[i], v2[i]);
}
}
value_ = Value(result);
}
TORCH_API void visit(const Intrinsics* v) override {
TORCH_API void visit(Intrinsics* v) override {
auto ty = v->dtype().scalar_type();
if (v->op_type() == kIsNan) {
auto inp_dtype = v->params().at(0)->dtype().scalar_type();
@ -782,15 +782,15 @@ class SimpleIREvaluatorImpl : public IRVisitor {
}
}
void visit(const Allocate* v) override {
const Buf* b = v->buf();
std::vector<const Expr*> dims = b->dims();
void visit(Allocate* v) override {
Buf* b = v->buf();
std::vector<Expr*> dims = b->dims();
int total_byte_size = b->dtype().byte_size();
for (auto& dim : dims) {
dim->accept(this);
total_byte_size *= value_.as<int>();
}
const auto int_count = (total_byte_size + sizeof(int) - 1) / sizeof(int);
auto int_count = (total_byte_size + sizeof(int) - 1) / sizeof(int);
std::unique_ptr<std::vector<int>> buffer(new std::vector<int>(int_count));
auto iter = buffer_mapping_.find(b);
if (iter != buffer_mapping_.end() && iter->second != nullptr) {
@ -802,8 +802,8 @@ class SimpleIREvaluatorImpl : public IRVisitor {
internal_buffers_.insert(std::make_pair(b, std::move(buffer)));
}
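
The int_count rounding above sizes the backing std::vector<int> in whole ints. A worked example for a small buffer (assuming a 4-byte int):

// Allocate(Buf("x", {2, 3}, kFloat))
//   total_byte_size = 4 * 2 * 3            = 24 bytes
//   int_count = (24 + sizeof(int) - 1) / 4 = 6
// so the evaluator backs the buffer with std::vector<int>(6), i.e. 24 bytes.
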
void visit(const Free* v) override {
const Buf* b = v->buf();
void visit(Free* v) override {
Buf* b = v->buf();
int count = internal_buffers_.erase(b);
if (count == 0) {
throw std::runtime_error(
@ -813,12 +813,12 @@ class SimpleIREvaluatorImpl : public IRVisitor {
buffer_mapping_.erase(b);
}
void visit(const Let* v) override {
void visit(Let* v) override {
var_by_scope_[scope_].push_back(v->var());
bindVar(v->var(), evaluateExpr(v->value()));
}
void visit(const Cond* v) override {
void visit(Cond* v) override {
v->condition()->accept(this);
if (value().as<int>()) {
if (v->true_stmt()) {
@ -950,12 +950,11 @@ class SimpleIREvaluatorImpl : public IRVisitor {
}
Value value_;
const Block* scope_;
std::unordered_map<const Expr*, Value> eval_context_;
std::unordered_map<const Block*, std::vector<const Expr*>> var_by_scope_;
std::unordered_map<const Buf*, void*> buffer_mapping_;
std::unordered_map<const Buf*, std::unique_ptr<std::vector<int>>>
internal_buffers_;
Block* scope_;
std::unordered_map<Expr*, Value> eval_context_;
std::unordered_map<Block*, std::vector<Expr*>> var_by_scope_;
std::unordered_map<Buf*, void*> buffer_mapping_;
std::unordered_map<Buf*, std::unique_ptr<std::vector<int>>> internal_buffers_;
};
SimpleIREvaluator::SimpleIREvaluator(
@ -984,7 +983,7 @@ void SimpleIREvaluator::call_raw(const std::vector<void*>& args) {
if (args.size() != buffer_args().size()) {
throw malformed_input("bad args in IREvaluator call");
}
for (const auto i : c10::irange(args.size())) {
for (auto i : c10::irange(args.size())) {
bindArg(buffer_args()[i], args[i]);
}
stmt()->accept(&*impl_);
@ -1012,7 +1011,7 @@ void SimpleIREvaluator::bindArg(const BufferArg& bufArg, void* data) {
}
}
void SimpleIREvaluator::bindVar(const Var* v, const Expr* e) {
void SimpleIREvaluator::bindVar(Var* v, Expr* e) {
impl_->bindVar(v, impl_->evaluateExpr(e));
}


@ -114,7 +114,7 @@ class TORCH_API SimpleIREvaluator : public CodeGen {
call(args);
}
void bindVar(const Var* v, const Expr* e);
void bindVar(Var* v, Expr* e);
Value value() const;
private:
@ -145,8 +145,8 @@ class ExprEval {
std::vector<BufferArg> buffer_args_extended = buffer_args;
Placeholder ret_buf("ret_val", dtype_, {1});
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::vector<const Expr*> indices;
const Expr* zero = new IntImm(0);
std::vector<Expr*> indices;
Expr* zero = new IntImm(0);
for (size_t i = 0; i < ret_buf.data()->ndim(); i++) {
indices.push_back(zero);
}
@ -167,7 +167,7 @@ class ExprEval {
call(call_args);
}
void bindVar(const Var* v, const Expr* e) {
void bindVar(Var* v, Expr* e) {
codegen_->bindVar(v, e);
}
@ -253,7 +253,7 @@ class ExprEval {
Value ret_value_;
};
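
Typical use of ExprEval, evaluating a scalar expression end to end (a sketch; it assumes ExprEval is instantiated with SimpleIREvaluator as its codegen type and uses the value<T>() accessor defined further down in this class):

// hypothetical: evaluate (2 + 3) * 4 down to a scalar
KernelScope kernel_scope;
ExprHandle e = (ExprHandle(2) + ExprHandle(3)) * ExprHandle(4);
ExprEval<SimpleIREvaluator> eval(e);
int result = eval.value<int>(); // 20
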
inline const Expr* Substitute(const Expr* expr, const VarMapping& var_mapping) {
inline Expr* Substitute(Expr* expr, const VarMapping& var_mapping) {
VarSubMutator var_sub(var_mapping);
return expr->accept_mutator(&var_sub);
}


@ -43,9 +43,9 @@ class unimplemented_lowering : public std::runtime_error {
public:
explicit unimplemented_lowering()
: std::runtime_error("UNIMPLEMENTED LOWERING") {}
explicit unimplemented_lowering(const Expr* expr)
explicit unimplemented_lowering(Expr* expr)
: std::runtime_error("UNIMPLEMENTED LOWERING: " + std::to_string(expr)) {}
explicit unimplemented_lowering(const Stmt* stmt)
explicit unimplemented_lowering(Stmt* stmt)
: std::runtime_error("UNIMPLEMENTED LOWERING: " + std::to_string(stmt)) {}
};
@ -54,14 +54,14 @@ class malformed_input : public std::runtime_error {
explicit malformed_input() : std::runtime_error("MALFORMED INPUT") {}
explicit malformed_input(const std::string& err)
: std::runtime_error("MALFORMED INPUT: " + err) {}
explicit malformed_input(const Expr* expr)
explicit malformed_input(Expr* expr)
: std::runtime_error("MALFORMED INPUT: " + std::to_string(expr)) {}
explicit malformed_input(const std::string& err, const Expr* expr)
explicit malformed_input(const std::string& err, Expr* expr)
: std::runtime_error(
"MALFORMED INPUT: " + err + " - " + std::to_string(expr)) {}
explicit malformed_input(const Stmt* stmt)
explicit malformed_input(Stmt* stmt)
: std::runtime_error("MALFORMED INPUT: " + std::to_string(stmt)) {}
explicit malformed_input(const std::string& err, const Stmt* stmt)
explicit malformed_input(const std::string& err, Stmt* stmt)
: std::runtime_error(
"MALFORMED INPUT: " + err + " - " + std::to_string(stmt)) {}
};
@ -71,14 +71,14 @@ class malformed_ir : public std::runtime_error {
explicit malformed_ir() : std::runtime_error("MALFORMED IR") {}
explicit malformed_ir(const std::string& err)
: std::runtime_error("MALFORMED IR: " + err) {}
explicit malformed_ir(const Expr* expr)
explicit malformed_ir(Expr* expr)
: std::runtime_error("MALFORMED IR: " + std::to_string(expr)) {}
explicit malformed_ir(const std::string& err, const Expr* expr)
explicit malformed_ir(const std::string& err, Expr* expr)
: std::runtime_error(
"MALFORMED IR: " + err + " - " + std::to_string(expr)) {}
explicit malformed_ir(const Stmt* stmt)
explicit malformed_ir(Stmt* stmt)
: std::runtime_error("MALFORMED IR: " + std::to_string(stmt)) {}
explicit malformed_ir(const std::string& err, const Stmt* stmt)
explicit malformed_ir(const std::string& err, Stmt* stmt)
: std::runtime_error(
"MALFORMED IR: " + err + " - " + std::to_string(stmt)) {}
};


@ -42,8 +42,8 @@ class TORCH_API Expr : public KernelScopedObject {
Dtype dtype() const {
return dtype_;
}
virtual void accept(IRVisitor* visitor) const = 0;
virtual const Expr* accept_mutator(IRMutator* mutator) const = 0;
virtual void accept(IRVisitor* visitor) = 0;
virtual Expr* accept_mutator(IRMutator* mutator) = 0;
IRNodeType expr_type() const {
return expr_type_;
@ -64,10 +64,10 @@ template <class Op, class Base = Expr>
class ExprNode : public Base {
public:
using ExprNodeBase = ExprNode<Op>;
void accept(IRVisitor* visitor) const override {
visitor->visit(static_cast<const Op*>(this));
void accept(IRVisitor* visitor) override {
visitor->visit(static_cast<Op*>(this));
}
const Expr* accept_mutator(IRMutator* mutator) const override;
Expr* accept_mutator(IRMutator* mutator) override;
// pass the constructor to the base class
using Base::Base;
};
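
Since accept and accept_mutator are no longer const, visitors and mutators now receive mutable nodes. A minimal custom visitor under the new signatures (illustrative only):

// hypothetical: count the Add nodes in an expression tree
class AddCounter : public IRVisitor {
 public:
  void visit(Add* v) override {
    count++;
    IRVisitor::visit(v); // keep recursing into lhs/rhs
  }
  int count = 0;
};

// usage: AddCounter c; e->accept(&c); then read c.count
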
@ -77,7 +77,7 @@ class ExprNode : public Base {
class TORCH_API ExprHandle {
public:
ExprHandle() = default;
explicit ExprHandle(const Expr* node)
explicit ExprHandle(Expr* node)
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
: base_expr_node_(const_cast<Expr*>(node)) {}
@ -85,7 +85,7 @@ class TORCH_API ExprHandle {
return base_expr_node_;
}
const Expr* node() const {
Expr* node() const {
return base_expr_node_;
}
@ -103,7 +103,7 @@ class TORCH_API ExprHandle {
}
template <class Op>
const Op* AsNode() const {
Op* AsNode() const {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
return const_cast<ExprHandle*>(this)->AsNode<Op>();
}
@ -173,7 +173,7 @@ class TORCH_API Buf : public ExprNode<Buf> {
static ExprHandle make(const std::vector<ExprHandle>& dims, Dtype dtype);
// TODO: unique_name
const Var* base_handle() const {
Var* base_handle() const {
return base_handle_;
}
void set_base_handle(Var* base_handle) {
@ -189,16 +189,16 @@ class TORCH_API Buf : public ExprNode<Buf> {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
Buf(const std::string& name_hint,
const std::vector<const Expr*>& dims,
const std::vector<Expr*>& dims,
Dtype dtype,
const Expr* initializer = nullptr)
Expr* initializer = nullptr)
: Buf(new Var(name_hint, kHandle), dims, dtype, initializer) {}
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
Buf(Var* var,
std::vector<const Expr*> dims,
std::vector<Expr*> dims,
Dtype dtype,
const Expr* initializer = nullptr)
Expr* initializer = nullptr)
: ExprNodeBase(dtype, kPrimitive),
base_handle_(var),
dims_(std::move(dims)),
@ -209,20 +209,20 @@ class TORCH_API Buf : public ExprNode<Buf> {
size_t ndim() const {
return dims_.size();
}
const Expr* dim(size_t index) const {
Expr* dim(size_t index) const {
if (index >= ndim()) {
throw out_of_range_index();
}
return dims_[index];
}
std::vector<const Expr*> dims() const {
std::vector<Expr*> dims() const {
return dims_;
}
void set_dims(std::vector<const Expr*> dims) {
void set_dims(std::vector<Expr*> dims) {
dims_ = dims;
};
const Expr* initializer() const {
Expr* initializer() const {
return initializer_;
};
@ -237,8 +237,8 @@ class TORCH_API Buf : public ExprNode<Buf> {
private:
Var* base_handle_;
std::vector<const Expr*> dims_;
const Expr* initializer_;
std::vector<Expr*> dims_;
Expr* initializer_;
};
class TORCH_API BufHandle : public ExprHandle {
@ -254,9 +254,9 @@ class TORCH_API BufHandle : public ExprHandle {
explicit BufHandle(Dtype dtype) : ExprHandle(Buf::make("_", {}, dtype)) {}
explicit BufHandle(const Buf* node) : ExprHandle(node) {}
const Buf* node() const {
return static_cast<const Buf*>(ExprHandle::node());
explicit BufHandle(Buf* node) : ExprHandle(node) {}
Buf* node() const {
return static_cast<Buf*>(ExprHandle::node());
}
Buf* node() {
return static_cast<Buf*>(ExprHandle::node());
@ -303,9 +303,9 @@ class TORCH_API VarHandle : public ExprHandle {
explicit VarHandle(Dtype dtype) : ExprHandle(Var::make(dtype)) {}
VarHandle(const std::string& name_hint, Dtype dtype)
: ExprHandle(Var::make(name_hint, dtype)) {}
explicit VarHandle(const Var* node) : ExprHandle(node) {}
const Var* node() const {
return static_cast<const Var*>(ExprHandle::node());
explicit VarHandle(Var* node) : ExprHandle(node) {}
Var* node() const {
return static_cast<Var*>(ExprHandle::node());
}
bool operator==(const VarHandle& other) const {
return this->node() == other.node();
@ -323,7 +323,7 @@ class TORCH_API VarHandle : public ExprHandle {
};
template <class Op, class Base>
const Expr* ExprNode<Op, Base>::accept_mutator(IRMutator* mutator) const {
Expr* ExprNode<Op, Base>::accept_mutator(IRMutator* mutator) {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
ExprNode* this_mutable = const_cast<ExprNode*>(this);
return mutator->mutate(static_cast<Op*>(this_mutable));


@ -21,18 +21,18 @@ std::vector<at::Tensor> constructTensors(
std::vector<std::vector<int64_t>> buf_dims_vec;
std::vector<c10::ScalarType> buf_dtypes_vec;
int64_t buf_dims_idx = 0;
for (const auto i : c10::irange(bufs_num)) {
for (auto i : c10::irange(bufs_num)) {
buf_data_vec.push_back(buf_data[i]);
buf_dims_vec.emplace_back();
// NOLINTNEXTLINE(clang-diagnostic-unused-variable,clang-analyzer-deadcode.DeadStores)
for (const auto dim : c10::irange(buf_ranks[i])) {
for (auto dim : c10::irange(buf_ranks[i])) {
buf_dims_vec[i].push_back(buf_dims[buf_dims_idx++]);
}
buf_dtypes_vec.push_back(static_cast<c10::ScalarType>(buf_dtypes[i]));
}
std::vector<at::Tensor> tensors;
for (const auto i : c10::irange(buf_data_vec.size())) {
for (auto i : c10::irange(buf_data_vec.size())) {
auto options = at::TensorOptions()
.dtype(buf_dtypes_vec[i])
.layout(at::kStrided)
@ -208,7 +208,7 @@ void nnc_prepacked_linear_clamp_run(
constructTensors(bufs_num - 1, buf_data, buf_ranks, buf_dims, buf_dtypes);
const at::Tensor& x = tensors[1];
const auto context = reinterpret_cast<LinearOpContext*>(buf_data[2]);
auto context = reinterpret_cast<LinearOpContext*>(buf_data[2]);
at::Tensor output = context->run(x);
memcpy(
buf_data[0], output.data_ptr(), output.element_size() * output.numel());
@ -228,7 +228,7 @@ void nnc_prepacked_conv2d_clamp_run(
constructTensors(bufs_num - 1, buf_data, buf_ranks, buf_dims, buf_dtypes);
const at::Tensor& x = tensors[1];
const auto context = reinterpret_cast<Conv2dOpContext*>(buf_data[2]);
auto context = reinterpret_cast<Conv2dOpContext*>(buf_data[2]);
at::Tensor output = context->run(x);
memcpy(
buf_data[0], output.data_ptr(), output.element_size() * output.numel());
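
All of these glue functions share the flat ABI that the evaluator's ExternalCall visitor marshals: one output buffer first, then inputs, with ranks, dims, and dtypes passed as parallel arrays. The registered signature looks roughly like this (parameter names illustrative):

void nnc_example_fn(
    int64_t bufs_num,    // number of buffers, output first
    void** buf_data,     // raw data pointer per buffer
    int64_t* buf_ranks,  // rank of each buffer
    int64_t* buf_dims,   // all dims, concatenated in buffer order
    int8_t* buf_dtypes,  // ScalarType of each buffer
    int64_t args_num,    // number of scalar extra args
    int64_t* extra_args);
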


@ -22,12 +22,12 @@ class HalfChecker : public IRVisitor {
return hasHalf_;
}
void visit(const Load* v) override {
void visit(Load* v) override {
hasHalf_ |= v->dtype().scalar_type() == ScalarType::Half;
IRVisitor::visit(v);
}
void visit(const Store* v) override {
void visit(Store* v) override {
hasHalf_ |= v->buf()->dtype().scalar_type() == ScalarType::Half;
IRVisitor::visit(v);
}
@ -36,7 +36,7 @@ class HalfChecker : public IRVisitor {
hasHalf_ = true;
}
void visit(const Cast* v) override {
void visit(Cast* v) override {
hasHalf_ |= v->dtype().scalar_type() == ScalarType::Half;
IRVisitor::visit(v);
}
@ -47,21 +47,21 @@ class HalfChecker : public IRVisitor {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
class HalfRewriter : public IRMutator {
const Expr* mutate(const Load* v) override {
const Expr* child = IRMutator::mutate(v);
Expr* mutate(Load* v) override {
Expr* child = IRMutator::mutate(v);
if (child->dtype().scalar_type() != ScalarType::Half) {
return child;
}
const Expr* ret =
Expr* ret =
new Cast(child->dtype().cloneWithScalarType(ScalarType::Float), child);
inserted_half_casts_.insert(ret);
return ret;
}
Stmt* mutate(const Store* v) override {
const Expr* new_val = v->value()->accept_mutator(this);
Stmt* mutate(Store* v) override {
Expr* new_val = v->value()->accept_mutator(this);
Dtype newType = v->value()->dtype();
if (newType.scalar_type() == ScalarType::Half) {
@ -73,12 +73,12 @@ class HalfRewriter : public IRMutator {
return new Store(v->buf(), v->indices(), new_val);
}
const Expr* mutate(const HalfImm* v) override {
Expr* mutate(HalfImm* v) override {
return new Cast(kFloat, v);
}
const Expr* mutate(const Cast* v) override {
const Expr* child = v->src_value()->accept_mutator(this);
Expr* mutate(Cast* v) override {
Expr* child = v->src_value()->accept_mutator(this);
// just don't allow half casts we didn't insert.
if (v->dtype().scalar_type() == ScalarType::Half) {
@ -88,7 +88,7 @@ class HalfRewriter : public IRMutator {
}
// Remove Half(Float()) and friends.
const Cast* cast_child = dynamic_cast<const Cast*>(child);
Cast* cast_child = dynamic_cast<Cast*>(child);
if (cast_child) {
if (v->dtype().is_floating_point() &&
cast_child->dtype().is_floating_point()) {
@ -102,10 +102,10 @@ class HalfRewriter : public IRMutator {
return new Cast(v->dtype(), child);
}
Stmt* mutate(const Let* v) override {
Stmt* mutate(Let* v) override {
if (v->dtype().scalar_type() == ScalarType::Half) {
const Var* load_new_var = new Var(v->var()->name_hint(), kFloat);
const Expr* new_value = new Cast(
Var* load_new_var = new Var(v->var()->name_hint(), kFloat);
Expr* new_value = new Cast(
v->dtype().cloneWithScalarType(ScalarType::Float),
v->value()->accept_mutator(this));
var_map[v->var()] = load_new_var;
@ -116,7 +116,7 @@ class HalfRewriter : public IRMutator {
return IRMutator::mutate(v);
}
const Expr* mutate(const Var* v) override {
Expr* mutate(Var* v) override {
auto it = var_map.find(v);
if (it != var_map.end()) {
return it->second;
@ -126,8 +126,8 @@ class HalfRewriter : public IRMutator {
}
private:
std::unordered_set<const Expr*> inserted_half_casts_;
std::unordered_map<const Var*, const Var*> var_map;
std::unordered_set<Expr*> inserted_half_casts_;
std::unordered_map<Var*, Var*> var_map;
};
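
The net effect of HalfRewriter is to keep Half only at the memory boundary and do all arithmetic in float; schematically, on pseudo-IR (all buffers Half):

// before: half arithmetic directly on loaded values
a[i] = b[i] + c[i];

// after: loads widened to float, the store narrows back to half
a[i] = Half(Float(b[i]) + Float(c[i]));
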
} // namespace tensorexpr


@ -28,91 +28,91 @@ bool SimplifierHashType::operator!=(const size_t other) const {
return _h != other;
}
void HashProvider::visit(const Add* v) {
void HashProvider::visit(Add* v) {
CACHE_GUARD();
v->lhs()->accept(this);
v->rhs()->accept(this);
putHash(v, hash_combine(hashOf(v->lhs()), "+", hashOf(v->rhs())));
}
void HashProvider::visit(const Sub* v) {
void HashProvider::visit(Sub* v) {
CACHE_GUARD();
v->lhs()->accept(this);
v->rhs()->accept(this);
putHash(v, hash_combine(hashOf(v->lhs()), "-", hashOf(v->rhs())));
}
void HashProvider::visit(const Mul* v) {
void HashProvider::visit(Mul* v) {
CACHE_GUARD();
v->lhs()->accept(this);
v->rhs()->accept(this);
putHash(v, hash_combine(hashOf(v->lhs()), "*", hashOf(v->rhs())));
}
void HashProvider::visit(const Div* v) {
void HashProvider::visit(Div* v) {
CACHE_GUARD();
v->lhs()->accept(this);
v->rhs()->accept(this);
putHash(v, hash_combine(hashOf(v->lhs()), "/", hashOf(v->rhs())));
}
void HashProvider::visit(const Mod* v) {
void HashProvider::visit(Mod* v) {
CACHE_GUARD();
v->lhs()->accept(this);
v->rhs()->accept(this);
putHash(v, hash_combine(hashOf(v->lhs()), "%", hashOf(v->rhs())));
}
void HashProvider::visit(const Max* v) {
void HashProvider::visit(Max* v) {
CACHE_GUARD();
v->lhs()->accept(this);
v->rhs()->accept(this);
putHash(v, hash_combine(hashOf(v->lhs()), "Mx", hashOf(v->rhs())));
}
void HashProvider::visit(const Min* v) {
void HashProvider::visit(Min* v) {
CACHE_GUARD();
v->lhs()->accept(this);
v->rhs()->accept(this);
putHash(v, hash_combine(hashOf(v->lhs()), "Mn", hashOf(v->rhs())));
}
void HashProvider::visit(const And* v) {
void HashProvider::visit(And* v) {
CACHE_GUARD();
v->lhs()->accept(this);
v->rhs()->accept(this);
putHash(v, hash_combine(hashOf(v->lhs()), "&", hashOf(v->rhs())));
}
void HashProvider::visit(const Or* v) {
void HashProvider::visit(Or* v) {
CACHE_GUARD();
v->lhs()->accept(this);
v->rhs()->accept(this);
putHash(v, hash_combine(hashOf(v->lhs()), "|", hashOf(v->rhs())));
}
void HashProvider::visit(const Xor* v) {
void HashProvider::visit(Xor* v) {
CACHE_GUARD();
v->lhs()->accept(this);
v->rhs()->accept(this);
putHash(v, hash_combine(hashOf(v->lhs()), "^", hashOf(v->rhs())));
}
void HashProvider::visit(const Lshift* v) {
void HashProvider::visit(Lshift* v) {
CACHE_GUARD();
v->lhs()->accept(this);
v->rhs()->accept(this);
putHash(v, hash_combine(hashOf(v->lhs()), "<<", hashOf(v->rhs())));
}
void HashProvider::visit(const Rshift* v) {
void HashProvider::visit(Rshift* v) {
CACHE_GUARD();
v->lhs()->accept(this);
v->rhs()->accept(this);
putHash(v, hash_combine(hashOf(v->lhs()), ">>", hashOf(v->rhs())));
}
void HashProvider::visit(const CompareSelect* v) {
void HashProvider::visit(CompareSelect* v) {
CACHE_GUARD();
v->lhs()->accept(this);
v->rhs()->accept(this);
@ -128,18 +128,18 @@ void HashProvider::visit(const CompareSelect* v) {
hashOf(v->ret_val2())));
}
void HashProvider::visit(const Cast* v) {
void HashProvider::visit(Cast* v) {
CACHE_GUARD();
v->src_value()->accept(this);
putHash(v, hash_combine("cast", v->dtype(), hashOf(v->src_value())));
}
void HashProvider::visit(const Var* v) {
void HashProvider::visit(Var* v) {
CACHE_GUARD();
putHash(v, hash_combine("var", name_manager_.get_unique_name(v)));
}
void HashProvider::visit(const Ramp* v) {
void HashProvider::visit(Ramp* v) {
CACHE_GUARD();
v->base()->accept(this);
v->stride()->accept(this);
@ -148,22 +148,22 @@ void HashProvider::visit(const Ramp* v) {
hash_combine("ramp", hashOf(v->base()), hashOf(v->stride()), v->lanes()));
}
void HashProvider::visit(const Load* v) {
void HashProvider::visit(Load* v) {
CACHE_GUARD();
v->base_handle()->accept(this);
SimplifierHashType indices_hash;
for (const Expr* ind : v->indices()) {
for (Expr* ind : v->indices()) {
ind->accept(this);
indices_hash = hash_combine(indices_hash, hashOf(ind));
}
putHash(v, hash_combine("load", hashOf(v->base_handle()), indices_hash));
}
void HashProvider::visit(const Store* v) {
void HashProvider::visit(Store* v) {
CACHE_GUARD();
v->base_handle()->accept(this);
SimplifierHashType indices_hash;
for (const Expr* ind : v->indices()) {
for (Expr* ind : v->indices()) {
ind->accept(this);
indices_hash = hash_combine(indices_hash, hashOf(ind));
}
@ -174,7 +174,7 @@ void HashProvider::visit(const Store* v) {
"store", hashOf(v->base_handle()), indices_hash, hashOf(v->value())));
}
void HashProvider::visit(const Block* v) {
void HashProvider::visit(Block* v) {
CACHE_GUARD();
SimplifierHashType hash;
@ -185,7 +185,7 @@ void HashProvider::visit(const Block* v) {
putHash(v, hash);
}
void HashProvider::visit(const For* v) {
void HashProvider::visit(For* v) {
CACHE_GUARD();
v->var()->accept(this);
v->start()->accept(this);
@ -202,13 +202,13 @@ void HashProvider::visit(const For* v) {
putHash(v, hash);
}
void HashProvider::visit(const Broadcast* v) {
void HashProvider::visit(Broadcast* v) {
CACHE_GUARD();
v->value()->accept(this);
putHash(v, hash_combine("broadcast", hashOf(v->value()), v->lanes()));
}
void HashProvider::visit(const IfThenElse* v) {
void HashProvider::visit(IfThenElse* v) {
CACHE_GUARD();
v->condition()->accept(this);
v->true_value()->accept(this);
@ -223,7 +223,7 @@ void HashProvider::visit(const IfThenElse* v) {
hashOf(v->false_value())));
}
void HashProvider::visit(const Intrinsics* v) {
void HashProvider::visit(Intrinsics* v) {
CACHE_GUARD();
// calls to rand are not symbolic and have a different value each time; they
// should not hash to anything, and this is the best we can do.
@ -234,7 +234,7 @@ void HashProvider::visit(const Intrinsics* v) {
}
SimplifierHashType hash(te_hash(v->func_name()));
for (const auto i : c10::irange(v->nparams())) {
for (auto i : c10::irange(v->nparams())) {
v->param(i)->accept(this);
hash = hash_combine(hash, hashOf(v->param(i)));
}
@ -242,33 +242,33 @@ void HashProvider::visit(const Intrinsics* v) {
putHash(v, hash);
}
void HashProvider::visit(const Allocate* v) {
void HashProvider::visit(Allocate* v) {
CACHE_GUARD();
const Var* buffer_var = v->buffer_var();
Var* buffer_var = v->buffer_var();
buffer_var->accept(this);
SimplifierHashType hash =
hash_combine("allocate", hashOf(buffer_var), v->dtype());
std::vector<const Expr*> dims = v->dims();
for (const Expr* dim : dims) {
std::vector<Expr*> dims = v->dims();
for (Expr* dim : dims) {
dim->accept(this);
hash = hash_combine(hash, hashOf(dim));
}
putHash(v, hash);
}
void HashProvider::visit(const Free* v) {
void HashProvider::visit(Free* v) {
CACHE_GUARD();
const Var* buffer_var = v->buffer_var();
Var* buffer_var = v->buffer_var();
buffer_var->accept(this);
putHash(v, hash_combine("free", hashOf(buffer_var)));
}
void HashProvider::visit(const Cond* v) {
void HashProvider::visit(Cond* v) {
CACHE_GUARD();
const Expr* condition = v->condition();
Expr* condition = v->condition();
Stmt* true_stmt = v->true_stmt();
Stmt* false_stmt = v->false_stmt();
condition->accept(this);
@ -286,7 +286,7 @@ void HashProvider::visit(const Cond* v) {
putHash(v, hash);
}
void HashProvider::visit(const Term* v) {
void HashProvider::visit(Term* v) {
CACHE_GUARD();
v->scalar()->accept(this);
@ -299,7 +299,7 @@ void HashProvider::visit(const Term* v) {
putHash(v, hash);
}
void HashProvider::visit(const Polynomial* v) {
void HashProvider::visit(Polynomial* v) {
CACHE_GUARD();
v->scalar()->accept(this);
@ -312,7 +312,7 @@ void HashProvider::visit(const Polynomial* v) {
putHash(v, hash);
}
void HashProvider::visit(const MaxTerm* v) {
void HashProvider::visit(MaxTerm* v) {
CACHE_GUARD();
SimplifierHashType hash = hash_combine("maxterm");
if (v->scalar()) {
@ -328,7 +328,7 @@ void HashProvider::visit(const MaxTerm* v) {
putHash(v, hash);
}
void HashProvider::visit(const MinTerm* v) {
void HashProvider::visit(MinTerm* v) {
CACHE_GUARD();
SimplifierHashType hash = hash_combine("minterm");
if (v->scalar()) {


@ -53,7 +53,7 @@ class Polynomial;
class TORCH_API HashProvider : public IRVisitor {
public:
template <class T>
SimplifierHashType hash(const T* e) {
SimplifierHashType hash(T* e) {
// NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
e->accept(this);
return hashOf(e);
@ -67,19 +67,19 @@ class TORCH_API HashProvider : public IRVisitor {
exprToHash_.clear();
}
void visit(const Add* v) override;
void visit(const Sub* v) override;
void visit(const Mul* v) override;
void visit(const Div* v) override;
void visit(const Mod* v) override;
void visit(const Max* v) override;
void visit(const Min* v) override;
void visit(const And* v) override;
void visit(const Or* v) override;
void visit(const Xor* v) override;
void visit(const Lshift* v) override;
void visit(const Rshift* v) override;
void visit(const CompareSelect* v) override;
void visit(Add* v) override;
void visit(Sub* v) override;
void visit(Mul* v) override;
void visit(Div* v) override;
void visit(Mod* v) override;
void visit(Max* v) override;
void visit(Min* v) override;
void visit(And* v) override;
void visit(Or* v) override;
void visit(Xor* v) override;
void visit(Lshift* v) override;
void visit(Rshift* v) override;
void visit(CompareSelect* v) override;
// NOLINTNEXTLINE
#define IMM_VISIT(Type, Name) \
@ -90,23 +90,23 @@ class TORCH_API HashProvider : public IRVisitor {
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_VISIT);
#undef IMM_VISIT
void visit(const Cast* v) override;
void visit(const Var* v) override;
void visit(const Ramp* v) override;
void visit(const Load* v) override;
void visit(const Store* v) override;
void visit(const Block* v) override;
void visit(const For* v) override;
void visit(const Broadcast* v) override;
void visit(const IfThenElse* v) override;
void visit(const Intrinsics* v) override;
void visit(const Allocate* v) override;
void visit(const Free* v) override;
void visit(const Cond* v) override;
void visit(const Term* v) override;
void visit(const Polynomial* v) override;
void visit(const MaxTerm* v) override;
void visit(const MinTerm* v) override;
void visit(Cast* v) override;
void visit(Var* v) override;
void visit(Ramp* v) override;
void visit(Load* v) override;
void visit(Store* v) override;
void visit(Block* v) override;
void visit(For* v) override;
void visit(Broadcast* v) override;
void visit(IfThenElse* v) override;
void visit(Intrinsics* v) override;
void visit(Allocate* v) override;
void visit(Free* v) override;
void visit(Cond* v) override;
void visit(Term* v) override;
void visit(Polynomial* v) override;
void visit(MaxTerm* v) override;
void visit(MinTerm* v) override;
template <typename... Types>
SimplifierHashType hash_combine(const Types&... args) {
@ -116,7 +116,7 @@ class TORCH_API HashProvider : public IRVisitor {
}
private:
SimplifierHashType hashOf(const Expr* e) {
SimplifierHashType hashOf(Expr* e) {
auto it = exprToHash_.find(e);
if (it != exprToHash_.end()) {
return it->second;
@ -132,7 +132,7 @@ class TORCH_API HashProvider : public IRVisitor {
return hash;
}
SimplifierHashType hashOf(const Stmt* s) {
SimplifierHashType hashOf(Stmt* s) {
auto it = exprToHash_.find(s);
if (it != exprToHash_.end()) {
return it->second;
@ -169,7 +169,7 @@ class TORCH_API HashProvider : public IRVisitor {
(seed._h >> 4);
}
void _hash_combine(SimplifierHashType& seed, const Expr* e) {
void _hash_combine(SimplifierHashType& seed, Expr* e) {
_hash_combine(seed, hash(e));
}
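
The provider is what gives the simplifier cheap structural equality: two subtrees that hash alike are treated as identical. A sketch (lhs and rhs are assumed Expr* subtrees):

HashProvider hasher;
if (hasher.hash(lhs) == hasher.hash(rhs)) {
  // fold lhs and rhs as structurally identical subtrees
}
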


@ -12,7 +12,7 @@ static Dtype ChooseDtype(const Dtype& buffer_dtype, const Dtype& index_dtype) {
return Dtype(buffer_dtype, index_dtype.lanes());
}
static Dtype dtypeOfIndices(const std::vector<const Expr*>& indices) {
static Dtype dtypeOfIndices(const std::vector<Expr*>& indices) {
if (!indices.size()) {
// Return something so we can handle scalar buffers.
return kInt;
@ -20,7 +20,7 @@ static Dtype dtypeOfIndices(const std::vector<const Expr*>& indices) {
return indices.at(0)->dtype();
}
void castIndicesToInts(std::vector<const Expr*>& indices) {
void castIndicesToInts(std::vector<Expr*>& indices) {
// Cast all indices to either Int or Long
auto index_dtype = ScalarType::Int;
for (auto& index : indices) {
@ -40,12 +40,12 @@ void castIndicesToInts(std::vector<const Expr*>& indices) {
}
}
Load::Load(Dtype dtype, const Buf* buf, std::vector<const Expr*> indices)
Load::Load(Dtype dtype, Buf* buf, std::vector<Expr*> indices)
: ExprNodeBase(dtype), buf_(buf), indices_(std::move(indices)) {
castIndicesToInts(indices_);
}
Load::Load(const Buf* buf, const std::vector<const Expr*>& indices)
Load::Load(Buf* buf, const std::vector<Expr*>& indices)
: Load(ChooseDtype(buf->dtype(), dtypeOfIndices(indices)), buf, indices) {}
ExprHandle Load::make(
@ -62,10 +62,7 @@ ExprHandle Load::make(
return Load::make(buf.dtype(), buf, indices);
}
Store::Store(
const Buf* buf,
std::vector<const Expr*> indices,
const Expr* value)
Store::Store(Buf* buf, std::vector<Expr*> indices, Expr* value)
: buf_(buf), indices_(std::move(indices)), value_(value) {
castIndicesToInts(indices_);
}
@ -78,9 +75,9 @@ Store* Store::make(
buf.node(), ExprHandleVectorToExprVector(indices), value.node());
}
const Expr* flatten_index(
const std::vector<const Expr*>& dims,
const std::vector<const Expr*>& indices) {
Expr* flatten_index(
const std::vector<Expr*>& dims,
const std::vector<Expr*>& indices) {
// Handle already flattened indices first
if (indices.size() == 1) {
return indices[0];
@ -93,7 +90,7 @@ const Expr* flatten_index(
if (ndim == 0) {
return new IntImm(0);
}
std::vector<const Expr*> strides(ndim);
std::vector<Expr*> strides(ndim);
// stride[i] = stride[i+1]*dims[i+1], i < ndim-1
// stride[i] = 1, i = ndim-1
strides[ndim - 1] = new IntImm(1);
@ -101,8 +98,8 @@ const Expr* flatten_index(
strides[ndim - 1 - i] = new Mul(strides[ndim - i], dims[ndim - i]);
}
const Expr* total_index = new IntImm(0);
for (const auto i : c10::irange(ndim)) {
Expr* total_index = new IntImm(0);
for (auto i : c10::irange(ndim)) {
total_index = new Add(total_index, new Mul(indices[i], strides[i]));
}
return total_index;
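
Concretely, for a row-major 2x3 buffer the strides work out to {3, 1}, so (i, j) flattens to i*3 + j. Tracing the loops above:

// dims = {2, 3}, indices = {i, j}
//   strides[1] = 1
//   strides[0] = strides[1] * dims[1] = 3
//   total_index = 0 + i*3 + j*1
// (modulo the extra Add/Mul nodes, which the simplifier later folds)
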
@ -123,7 +120,7 @@ Dtype Intrinsics::IntrinsicsDtype(IntrinsicsOp op_type, Dtype dt1, Dtype dt2) {
Dtype Intrinsics::IntrinsicsDtype(
IntrinsicsOp op_type,
const std::vector<const Expr*>& params) {
const std::vector<Expr*>& params) {
// TODO: check the op_type and make a real decision
// Doesn't this fail with kRand?
if (params.size() == 0) {
@ -184,7 +181,7 @@ ExternalCall* ExternalCall::make(
const std::string& func_name,
const std::vector<BufHandle>& buf_args,
const std::vector<ExprHandle>& args) {
std::vector<const Buf*> buf_arg_nodes;
std::vector<Buf*> buf_arg_nodes;
buf_arg_nodes.reserve(buf_args.size());
for (const BufHandle& buf_arg : buf_args) {
buf_arg_nodes.push_back(buf_arg.node());
@ -193,37 +190,35 @@ ExternalCall* ExternalCall::make(
buf.node(), func_name, buf_arg_nodes, ExprHandleVectorToExprVector(args));
}
std::vector<const Expr*> ExprHandleVectorToExprVector(
std::vector<Expr*> ExprHandleVectorToExprVector(
const std::vector<ExprHandle>& v) {
std::vector<const Expr*> result(v.size());
for (const auto i : c10::irange(v.size())) {
std::vector<Expr*> result(v.size());
for (auto i : c10::irange(v.size())) {
result[i] = v[i].node();
}
return result;
}
std::vector<ExprHandle> ExprVectorToExprHandleVector(
const std::vector<const Expr*>& v) {
const std::vector<Expr*>& v) {
std::vector<ExprHandle> result(v.size());
for (const auto i : c10::irange(v.size())) {
for (auto i : c10::irange(v.size())) {
result[i] = ExprHandle(v[i]);
}
return result;
}
std::vector<const Var*> VarHandleVectorToVarVector(
const std::vector<VarHandle>& v) {
std::vector<const Var*> result(v.size());
for (const auto i : c10::irange(v.size())) {
std::vector<Var*> VarHandleVectorToVarVector(const std::vector<VarHandle>& v) {
std::vector<Var*> result(v.size());
for (auto i : c10::irange(v.size())) {
result[i] = v[i].node();
}
return result;
}
std::vector<VarHandle> VarVectorToVarHandleVector(
const std::vector<const Var*>& v) {
std::vector<VarHandle> VarVectorToVarHandleVector(const std::vector<Var*>& v) {
std::vector<VarHandle> result(v.size());
for (const auto i : c10::irange(v.size())) {
for (auto i : c10::irange(v.size())) {
result[i] = VarHandle(v[i]);
}
return result;


@ -68,13 +68,13 @@ class Placeholder;
class TORCH_API Cast : public ExprNode<Cast> {
public:
const Expr* src_value() const {
Expr* src_value() const {
return src_value_;
}
static ExprHandle make(Dtype dtype, const ExprHandle& src_value) {
return ExprHandle(new Cast(dtype, src_value.node()));
}
Cast(Dtype dtype, const Expr* src_value)
Cast(Dtype dtype, Expr* src_value)
: ExprNodeBase(dtype, kCast), src_value_(src_value) {}
bool isConstant() const override {
@ -82,7 +82,7 @@ class TORCH_API Cast : public ExprNode<Cast> {
}
private:
const Expr* src_value_;
Expr* src_value_;
};
template <typename T>
@ -93,13 +93,13 @@ ExprHandle cast(const ExprHandle& src_value) {
// This is a bitwise cast, akin to bitcast in LLVM
class TORCH_API BitCast : public ExprNode<BitCast> {
public:
const Expr* src_value() const {
Expr* src_value() const {
return src_value_;
}
static ExprHandle make(Dtype dtype, const ExprHandle& src_value) {
return ExprHandle(new BitCast(dtype, src_value.node()));
}
BitCast(Dtype dtype, const Expr* src_value)
BitCast(Dtype dtype, Expr* src_value)
: ExprNodeBase(dtype, kBitCast), src_value_(src_value) {
TORCH_CHECK(src_value_->dtype().byte_size() == dtype.byte_size());
}
@ -109,7 +109,7 @@ class TORCH_API BitCast : public ExprNode<BitCast> {
}
private:
const Expr* src_value_;
Expr* src_value_;
};
template <typename T>
@ -123,10 +123,10 @@ ExprHandle bitcast(const ExprHandle& src_value) {
template <typename Op>
class BinaryOpNode : public ExprNode<Op> {
public:
const Expr* lhs() const {
Expr* lhs() const {
return this->lhs_;
}
const Expr* rhs() const {
Expr* rhs() const {
return this->rhs_;
}
@ -136,8 +136,8 @@ class BinaryOpNode : public ExprNode<Op> {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
BinaryOpNode(
const Expr* lhs_v,
const Expr* rhs_v,
Expr* lhs_v,
Expr* rhs_v,
IRNodeType expr_type,
ScalarType ret_type = ScalarType::Undefined)
: ExprNode<Op>(
@ -148,51 +148,46 @@ class BinaryOpNode : public ExprNode<Op> {
rhs_(CastIfNeeded(rhs_v, ExprNode<Op>::dtype())) {}
private:
static const Expr* CastIfNeeded(const Expr* expr, Dtype dst_dtype) {
static Expr* CastIfNeeded(Expr* expr, Dtype dst_dtype) {
if (expr->dtype() == dst_dtype) {
return expr;
}
return Cast::make(dst_dtype, ExprHandle(expr)).node();
}
const Expr* lhs_;
const Expr* rhs_;
Expr* lhs_;
Expr* rhs_;
};
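
CastIfNeeded gives every binary op the usual promotion semantics: both operands are cast to the op's promoted dtype before being stored. For instance (a sketch; NNC nodes are arena-allocated, hence the KernelScope):

KernelScope kernel_scope;
Expr* i = new IntImm(2);
Expr* f = new FloatImm(3.0f);
Add* sum = new Add(i, f);
// sum->dtype() is kFloat; lhs() is now Cast(kFloat, IntImm(2)),
// while rhs() is stored unchanged as FloatImm(3.0f)
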
class TORCH_API Add : public BinaryOpNode<Add> {
public:
Add(const Expr* lhs, const Expr* rhs)
: BinaryOpNode(lhs, rhs, IRNodeType::kAdd) {}
Add(Expr* lhs, Expr* rhs) : BinaryOpNode(lhs, rhs, IRNodeType::kAdd) {}
};
class TORCH_API Sub : public BinaryOpNode<Sub> {
public:
Sub(const Expr* lhs, const Expr* rhs)
: BinaryOpNode(lhs, rhs, IRNodeType::kSub) {}
Sub(Expr* lhs, Expr* rhs) : BinaryOpNode(lhs, rhs, IRNodeType::kSub) {}
};
class TORCH_API Mul : public BinaryOpNode<Mul> {
public:
Mul(const Expr* lhs, const Expr* rhs)
: BinaryOpNode(lhs, rhs, IRNodeType::kMul) {}
Mul(Expr* lhs, Expr* rhs) : BinaryOpNode(lhs, rhs, IRNodeType::kMul) {}
};
class TORCH_API Div : public BinaryOpNode<Div> {
public:
Div(const Expr* lhs, const Expr* rhs)
: BinaryOpNode(lhs, rhs, IRNodeType::kDiv) {}
Div(Expr* lhs, Expr* rhs) : BinaryOpNode(lhs, rhs, IRNodeType::kDiv) {}
};
class TORCH_API Mod : public BinaryOpNode<Mod> {
public:
Mod(const Expr* lhs, const Expr* rhs)
: BinaryOpNode(lhs, rhs, IRNodeType::kMod) {}
Mod(Expr* lhs, Expr* rhs) : BinaryOpNode(lhs, rhs, IRNodeType::kMod) {}
};
template <typename Op>
class BitwiseOpNode : public BinaryOpNode<Op> {
public:
BitwiseOpNode(const Expr* lhs, const Expr* rhs, IRNodeType type)
BitwiseOpNode(Expr* lhs, Expr* rhs, IRNodeType type)
: BinaryOpNode<Op>(lhs, rhs, type) {}
static ExprHandle make(const ExprHandle& lhs, const ExprHandle& rhs) {
@ -208,32 +203,27 @@ class BitwiseOpNode : public BinaryOpNode<Op> {
class TORCH_API And : public BitwiseOpNode<And> {
public:
And(const Expr* lhs, const Expr* rhs)
: BitwiseOpNode(lhs, rhs, IRNodeType::kAnd) {}
And(Expr* lhs, Expr* rhs) : BitwiseOpNode(lhs, rhs, IRNodeType::kAnd) {}
};
class TORCH_API Or : public BitwiseOpNode<Or> {
public:
Or(const Expr* lhs, const Expr* rhs)
: BitwiseOpNode(lhs, rhs, IRNodeType::kOr) {}
Or(Expr* lhs, Expr* rhs) : BitwiseOpNode(lhs, rhs, IRNodeType::kOr) {}
};
class TORCH_API Xor : public BitwiseOpNode<Xor> {
public:
Xor(const Expr* lhs, const Expr* rhs)
: BitwiseOpNode(lhs, rhs, IRNodeType::kXor) {}
Xor(Expr* lhs, Expr* rhs) : BitwiseOpNode(lhs, rhs, IRNodeType::kXor) {}
};
class TORCH_API Lshift : public BitwiseOpNode<Lshift> {
public:
Lshift(const Expr* lhs, const Expr* rhs)
: BitwiseOpNode(lhs, rhs, IRNodeType::kLshift) {}
Lshift(Expr* lhs, Expr* rhs) : BitwiseOpNode(lhs, rhs, IRNodeType::kLshift) {}
};
class TORCH_API Rshift : public BitwiseOpNode<Rshift> {
public:
Rshift(const Expr* lhs, const Expr* rhs)
: BitwiseOpNode(lhs, rhs, IRNodeType::kRshift) {}
Rshift(Expr* lhs, Expr* rhs) : BitwiseOpNode(lhs, rhs, IRNodeType::kRshift) {}
};
// TODO: add TORCH_API
@ -243,7 +233,7 @@ class Max : public BinaryOpNode<Max> {
bool propagate_nans_;
public:
Max(const Expr* lhs, const Expr* rhs, bool propagate_nans)
Max(Expr* lhs, Expr* rhs, bool propagate_nans)
: BinaryOpNode(lhs, rhs, IRNodeType::kMax),
propagate_nans_(propagate_nans) {}
@ -267,7 +257,7 @@ class Min : public BinaryOpNode<Min> {
bool propagate_nans_;
public:
Min(const Expr* lhs, const Expr* rhs, bool propagate_nans)
Min(Expr* lhs, Expr* rhs, bool propagate_nans)
: BinaryOpNode(lhs, rhs, IRNodeType::kMin),
propagate_nans_(propagate_nans) {}
@ -328,7 +318,7 @@ Expr* getImmediateByType(Dtype dtype, T initialVal) {
}
template <typename T>
T immediateAs(const Expr* e) {
T immediateAs(Expr* e) {
#define TYPE_CASE(Type, Name) \
if (const Name##Imm* imm = dynamic_cast<const Name##Imm*>(e)) { \
return imm->value(); \
@ -345,7 +335,7 @@ T immediateAs(ExprHandle e) {
}
template <typename T>
bool immediateEquals(const Expr* e, T val) {
bool immediateEquals(Expr* e, T val) {
#define TYPE_CASE(Type, Name) \
if (const Name##Imm* imm = dynamic_cast<const Name##Imm*>(e)) { \
return imm->value() == val; \
@ -371,10 +361,10 @@ bool immediateIsNegative(const T* e) {
// [base, base + 1 * stride, ... , base + (lanes - 1) * stride]
class TORCH_API Ramp : public ExprNode<Ramp> {
public:
const Expr* base() const {
Expr* base() const {
return base_;
}
const Expr* stride() const {
Expr* stride() const {
return stride_;
}
static ExprHandle make(
@ -390,31 +380,31 @@ class TORCH_API Ramp : public ExprNode<Ramp> {
return lanes_;
}
Ramp(const Expr* base, const Expr* stride, int lanes)
Ramp(Expr* base, Expr* stride, int lanes)
: ExprNodeBase(Dtype(base->dtype(), lanes)),
base_(base),
stride_(stride),
lanes_(lanes) {}
private:
const Expr* base_;
const Expr* stride_;
Expr* base_;
Expr* stride_;
int lanes_;
};
class TORCH_API Load : public ExprNode<Load> {
public:
const Var* base_handle() const {
Var* base_handle() const {
return buf_->base_handle();
}
std::vector<const Expr*> indices() const {
std::vector<Expr*> indices() const {
return indices_;
}
const Expr* flat_index() const {
Expr* flat_index() const {
TORCH_CHECK(indices_.size() == 1, "Indices haven't been flattened.");
return indices_[0];
}
const Buf* buf() const {
Buf* buf() const {
return buf_;
}
static ExprHandle make(
@ -425,21 +415,21 @@ class TORCH_API Load : public ExprNode<Load> {
const BufHandle& buf,
const std::vector<ExprHandle>& indices);
Load(Dtype dtype, const Buf* base_handle, std::vector<const Expr*> indices);
Load(const Buf* base_handle, const std::vector<const Expr*>& indices);
Load(Dtype dtype, Buf* base_handle, std::vector<Expr*> indices);
Load(Buf* base_handle, const std::vector<Expr*>& indices);
void set_indices(std::vector<const Expr*> indices) {
void set_indices(std::vector<Expr*> indices) {
indices_ = indices;
};
private:
const Buf* buf_;
std::vector<const Expr*> indices_;
Buf* buf_;
std::vector<Expr*> indices_;
};
class TORCH_API Broadcast : public ExprNode<Broadcast> {
public:
const Expr* value() const {
Expr* value() const {
return value_;
}
int lanes() const {
@ -448,29 +438,29 @@ class TORCH_API Broadcast : public ExprNode<Broadcast> {
static ExprHandle make(const ExprHandle& value, int lanes) {
return ExprHandle(new Broadcast(value.node(), lanes));
}
Broadcast(const Expr* value, int lanes)
Broadcast(Expr* value, int lanes)
: ExprNodeBase(Dtype(value->dtype(), lanes)),
value_(value),
lanes_(lanes) {}
private:
const Expr* value_;
Expr* value_;
int lanes_;
};
class TORCH_API IfThenElse : public ExprNode<IfThenElse> {
public:
const Expr* condition() const {
Expr* condition() const {
return condition_;
}
// Lazily evaluated only if condition is true
const Expr* true_value() const {
Expr* true_value() const {
return true_;
}
// Lazily evaluated only if condition is false
const Expr* false_value() const {
Expr* false_value() const {
return false_;
}
@ -490,13 +480,13 @@ class TORCH_API IfThenElse : public ExprNode<IfThenElse> {
return ExprHandle(new IfThenElse(c.node(), t.node(), f.node()));
}
IfThenElse(const Expr* c, const Expr* t, const Expr* f)
IfThenElse(Expr* c, Expr* t, Expr* f)
: ExprNodeBase(t->dtype()), condition_(c), true_(t), false_(f) {}
private:
const Expr* condition_;
const Expr* true_;
const Expr* false_;
Expr* condition_;
Expr* true_;
Expr* false_;
};
class TORCH_API CompareSelect : public ExprNode<CompareSelect> {
@ -504,16 +494,16 @@ class TORCH_API CompareSelect : public ExprNode<CompareSelect> {
CompareSelectOperation compare_select_op() const {
return compare_op_;
}
const Expr* lhs() const {
Expr* lhs() const {
return this->lhs_;
}
const Expr* rhs() const {
Expr* rhs() const {
return this->rhs_;
}
const Expr* ret_val1() const {
Expr* ret_val1() const {
return this->ret_val1_;
}
const Expr* ret_val2() const {
Expr* ret_val2() const {
return this->ret_val2_;
}
CompareSelectBias bias() const {
@ -557,10 +547,10 @@ class TORCH_API CompareSelect : public ExprNode<CompareSelect> {
}
CompareSelect(
const Expr* lhs,
const Expr* rhs,
const Expr* ret_val1,
const Expr* ret_val2,
Expr* lhs,
Expr* rhs,
Expr* ret_val1,
Expr* ret_val2,
CompareSelectOperation cmp_op,
CompareSelectBias bias = kUnbiased)
: ExprNodeBase(ret_val1->dtype()),
@ -573,8 +563,8 @@ class TORCH_API CompareSelect : public ExprNode<CompareSelect> {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
CompareSelect(
const Expr* lhs,
const Expr* rhs,
Expr* lhs,
Expr* rhs,
CompareSelectOperation cmp_op,
CompareSelectBias bias = kUnbiased)
: ExprNodeBase(kInt),
@ -586,10 +576,10 @@ class TORCH_API CompareSelect : public ExprNode<CompareSelect> {
bias_(bias) {}
private:
const Expr* lhs_;
const Expr* rhs_;
const Expr* ret_val1_;
const Expr* ret_val2_;
Expr* lhs_;
Expr* rhs_;
Expr* ret_val1_;
Expr* ret_val2_;
CompareSelectOperation compare_op_;
CompareSelectBias bias_;
};
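// CompareSelect(lhs, rhs, t, f, op) yields t when `lhs op rhs` holds and f
// otherwise; the two-operand constructor defaults the results to int 1/0.
// A handle-level sketch (hypothetical values, active KernelScope assumed):
VarHandle x("x", kInt);
ExprHandle sel = CompareSelect::make(
    x, ExprHandle(5), ExprHandle(100), ExprHandle(42), kLT); // x < 5 ? 100 : 42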
@ -647,7 +637,7 @@ class TORCH_API Intrinsics : public ExprNode<Intrinsics> {
IntrinsicsOp op_type,
const std::vector<ExprHandle>& params) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::vector<const Expr*> params_nodes(params.size());
std::vector<Expr*> params_nodes(params.size());
for (size_t i = 0; i < params.size(); i++) {
params_nodes[i] = params[i].node();
}
@ -747,7 +737,7 @@ class TORCH_API Intrinsics : public ExprNode<Intrinsics> {
}
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
Intrinsics(IntrinsicsOp op_type, const Expr* v1)
Intrinsics(IntrinsicsOp op_type, Expr* v1)
: ExprNodeBase(IntrinsicsDtype(op_type, v1->dtype())),
params_({v1}),
op_type_(op_type) {
@ -757,7 +747,7 @@ class TORCH_API Intrinsics : public ExprNode<Intrinsics> {
}
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
Intrinsics(IntrinsicsOp op_type, const Expr* v1, const Expr* v2)
Intrinsics(IntrinsicsOp op_type, Expr* v1, Expr* v2)
: ExprNodeBase(IntrinsicsDtype(op_type, v1->dtype(), v2->dtype())),
params_({v1, v2}),
op_type_(op_type) {
@ -767,7 +757,7 @@ class TORCH_API Intrinsics : public ExprNode<Intrinsics> {
}
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
Intrinsics(IntrinsicsOp op_type, const std::vector<const Expr*>& params)
Intrinsics(IntrinsicsOp op_type, const std::vector<Expr*>& params)
: ExprNodeBase(IntrinsicsDtype(op_type, params)),
params_(params),
op_type_(op_type) {
@ -784,10 +774,10 @@ class TORCH_API Intrinsics : public ExprNode<Intrinsics> {
return params_.size();
}
const Expr* param(int index) const {
Expr* param(int index) const {
return params_[index];
}
const std::vector<const Expr*>& params() const {
const std::vector<Expr*>& params() const {
return params_;
}
@ -797,9 +787,9 @@ class TORCH_API Intrinsics : public ExprNode<Intrinsics> {
static Dtype IntrinsicsDtype(IntrinsicsOp op_type, Dtype dt1, Dtype dt2);
static Dtype IntrinsicsDtype(
IntrinsicsOp op_type,
const std::vector<const Expr*>& params);
const std::vector<Expr*>& params);
std::vector<const Expr*> params_;
std::vector<Expr*> params_;
IntrinsicsOp op_type_;
};
@ -808,17 +798,17 @@ class Term;
class MaxTerm;
class MinTerm;
TORCH_API std::vector<const Expr*> ExprHandleVectorToExprVector(
TORCH_API std::vector<Expr*> ExprHandleVectorToExprVector(
const std::vector<ExprHandle>&);
TORCH_API std::vector<ExprHandle> ExprVectorToExprHandleVector(
const std::vector<const Expr*>&);
TORCH_API std::vector<const Var*> VarHandleVectorToVarVector(
const std::vector<Expr*>&);
TORCH_API std::vector<Var*> VarHandleVectorToVarVector(
const std::vector<VarHandle>&);
TORCH_API std::vector<VarHandle> VarVectorToVarHandleVector(
const std::vector<const Var*>&);
TORCH_API const Expr* flatten_index(
const std::vector<const Expr*>& dims,
const std::vector<const Expr*>& indices);
const std::vector<Var*>&);
TORCH_API Expr* flatten_index(
const std::vector<Expr*>& dims,
const std::vector<Expr*>& indices);
} // namespace tensorexpr
} // namespace jit
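// flatten_index folds a multi-dimensional index into a single offset, which
// is what Load/Store expect once indices have been flattened. A sketch with
// hypothetical immediates (row-major layout, active KernelScope assumed):
std::vector<Expr*> dims = {new IntImm(4), new IntImm(5)};
std::vector<Expr*> indices = {new IntImm(2), new IntImm(3)};
Expr* flat = flatten_index(dims, indices); // simplifies to 2 * 5 + 3 == 13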


@ -12,14 +12,14 @@ namespace jit {
namespace tensorexpr {
template <typename Op>
static const Expr* mutate_binary_op(
const BinaryOpNode<Op>* v,
static Expr* mutate_binary_op(
BinaryOpNode<Op>* v,
IRMutator* mutator,
bool option = false) {
const Expr* lhs = v->lhs();
const Expr* rhs = v->rhs();
const Expr* lhs_new = lhs->accept_mutator(mutator);
const Expr* rhs_new = rhs->accept_mutator(mutator);
Expr* lhs = v->lhs();
Expr* rhs = v->rhs();
Expr* lhs_new = lhs->accept_mutator(mutator);
Expr* rhs_new = rhs->accept_mutator(mutator);
if (lhs == lhs_new && rhs == rhs_new) {
return v;
}
@ -54,63 +54,63 @@ static const Expr* mutate_binary_op(
}
}
const Expr* IRMutator::mutate(const Add* v) {
Expr* IRMutator::mutate(Add* v) {
return mutate_binary_op(v, this);
}
const Expr* IRMutator::mutate(const Sub* v) {
Expr* IRMutator::mutate(Sub* v) {
return mutate_binary_op(v, this);
}
const Expr* IRMutator::mutate(const Mul* v) {
Expr* IRMutator::mutate(Mul* v) {
return mutate_binary_op(v, this);
}
const Expr* IRMutator::mutate(const Div* v) {
Expr* IRMutator::mutate(Div* v) {
return mutate_binary_op(v, this);
}
const Expr* IRMutator::mutate(const Mod* v) {
Expr* IRMutator::mutate(Mod* v) {
return mutate_binary_op(v, this);
}
const Expr* IRMutator::mutate(const And* v) {
Expr* IRMutator::mutate(And* v) {
return mutate_binary_op(v, this);
}
const Expr* IRMutator::mutate(const Or* v) {
Expr* IRMutator::mutate(Or* v) {
return mutate_binary_op(v, this);
}
const Expr* IRMutator::mutate(const Xor* v) {
Expr* IRMutator::mutate(Xor* v) {
return mutate_binary_op(v, this);
}
const Expr* IRMutator::mutate(const Lshift* v) {
Expr* IRMutator::mutate(Lshift* v) {
return mutate_binary_op(v, this);
}
const Expr* IRMutator::mutate(const Rshift* v) {
Expr* IRMutator::mutate(Rshift* v) {
return mutate_binary_op(v, this);
}
const Expr* IRMutator::mutate(const Max* v) {
Expr* IRMutator::mutate(Max* v) {
return mutate_binary_op(v, this, v->propagate_nans());
}
const Expr* IRMutator::mutate(const Min* v) {
Expr* IRMutator::mutate(Min* v) {
return mutate_binary_op(v, this, v->propagate_nans());
}
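// Downstream mutators follow the same pattern under the new signatures:
// overrides take and return mutable pointers. A hypothetical subclass,
// shown for illustration only:
class SwapAddForSub : public IRMutator {
 public:
  Expr* mutate(Add* v) override {
    // Recurse into the children first, then rebuild with the mutated nodes.
    Expr* lhs_new = v->lhs()->accept_mutator(this);
    Expr* rhs_new = v->rhs()->accept_mutator(this);
    return new Sub(lhs_new, rhs_new);
  }
};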
const Expr* IRMutator::mutate(const CompareSelect* v) {
const Expr* lhs = v->lhs();
const Expr* rhs = v->rhs();
const Expr* retval1 = v->ret_val1();
const Expr* retval2 = v->ret_val2();
const Expr* lhs_new = lhs->accept_mutator(this);
const Expr* rhs_new = rhs->accept_mutator(this);
const Expr* retval1_new = retval1->accept_mutator(this);
const Expr* retval2_new = retval2->accept_mutator(this);
Expr* IRMutator::mutate(CompareSelect* v) {
Expr* lhs = v->lhs();
Expr* rhs = v->rhs();
Expr* retval1 = v->ret_val1();
Expr* retval2 = v->ret_val2();
Expr* lhs_new = lhs->accept_mutator(this);
Expr* rhs_new = rhs->accept_mutator(this);
Expr* retval1_new = retval1->accept_mutator(this);
Expr* retval2_new = retval2->accept_mutator(this);
if (lhs == lhs_new && rhs == rhs_new && retval1 == retval1_new &&
retval2 == retval2_new) {
return v;
@ -126,68 +126,68 @@ const Expr* IRMutator::mutate(const CompareSelect* v) {
}
// NOLINTNEXTLINE
#define IMM_MUTATE_DEFINE(_1, Name) \
const Expr* IRMutator::mutate(const Name##Imm* v) { \
return v; \
#define IMM_MUTATE_DEFINE(_1, Name) \
Expr* IRMutator::mutate(Name##Imm* v) { \
return v; \
}
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_MUTATE_DEFINE);
#undef IMM_MUTATE_DEFINE
const Expr* IRMutator::mutate(const Cast* v) {
const Expr* src_value = v->src_value();
const Expr* src_value_new = src_value->accept_mutator(this);
Expr* IRMutator::mutate(Cast* v) {
Expr* src_value = v->src_value();
Expr* src_value_new = src_value->accept_mutator(this);
if (src_value_new == v->src_value()) {
return v;
}
return new Cast(v->dtype(), src_value_new);
}
const Expr* IRMutator::mutate(const BitCast* v) {
const Expr* src_value = v->src_value();
const Expr* src_value_new = src_value->accept_mutator(this);
Expr* IRMutator::mutate(BitCast* v) {
Expr* src_value = v->src_value();
Expr* src_value_new = src_value->accept_mutator(this);
if (src_value_new == v->src_value()) {
return v;
}
return new BitCast(v->dtype(), src_value_new);
}
const Expr* IRMutator::mutate(const Var* v) {
Expr* IRMutator::mutate(Var* v) {
return v;
}
const Expr* IRMutator::mutate(const Ramp* v) {
const Expr* base = v->base();
const Expr* stride = v->stride();
const Expr* base_new = base->accept_mutator(this);
const Expr* stride_new = stride->accept_mutator(this);
Expr* IRMutator::mutate(Ramp* v) {
Expr* base = v->base();
Expr* stride = v->stride();
Expr* base_new = base->accept_mutator(this);
Expr* stride_new = stride->accept_mutator(this);
if (base == base_new && stride == stride_new) {
return v;
}
return new Ramp(base_new, stride_new, v->lanes());
}
const Expr* IRMutator::mutate(const Load* v) {
Expr* IRMutator::mutate(Load* v) {
Dtype dtype = v->dtype();
const Buf* buf = v->buf();
Buf* buf = v->buf();
bool any_index_changed = false;
std::vector<const Expr*> indices_new;
for (const Expr* ind : v->indices()) {
const Expr* new_ind = ind->accept_mutator(this);
std::vector<Expr*> indices_new;
for (Expr* ind : v->indices()) {
Expr* new_ind = ind->accept_mutator(this);
if (new_ind != ind) {
any_index_changed = true;
}
indices_new.push_back(new_ind);
}
const Buf* buf_new = dynamic_cast<const Buf*>(buf->accept_mutator(this));
Buf* buf_new = dynamic_cast<Buf*>(buf->accept_mutator(this));
if (buf == buf_new && !any_index_changed) {
return v;
}
return new Load(dtype, buf_new, indices_new);
}
const Expr* IRMutator::mutate(Buf* v) {
const Var* var = v->base_handle();
Expr* IRMutator::mutate(Buf* v) {
Var* var = v->base_handle();
Var* var_new =
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
dynamic_cast<Var*>(const_cast<Expr*>(var->accept_mutator(this)));
@ -196,9 +196,9 @@ const Expr* IRMutator::mutate(Buf* v) {
}
bool any_change = var_new != var;
std::vector<const Expr*> dims_old = v->dims();
std::vector<const Expr*> dims_new(dims_old.size());
for (const auto i : c10::irange(dims_old.size())) {
std::vector<Expr*> dims_old = v->dims();
std::vector<Expr*> dims_new(dims_old.size());
for (auto i : c10::irange(dims_old.size())) {
dims_new[i] = dims_old[i]->accept_mutator(this);
any_change |= (dims_new[i] != dims_old[i]);
}
@ -212,23 +212,23 @@ const Expr* IRMutator::mutate(Buf* v) {
return v;
}
const Expr* IRMutator::mutate(const Broadcast* v) {
const Expr* value = v->value();
Expr* IRMutator::mutate(Broadcast* v) {
Expr* value = v->value();
int lanes = v->lanes();
const Expr* value_new = value->accept_mutator(this);
Expr* value_new = value->accept_mutator(this);
if (value == value_new) {
return v;
}
return new Broadcast(value_new, lanes);
}
const Expr* IRMutator::mutate(const IfThenElse* v) {
const Expr* condition = v->condition();
const Expr* true_value = v->true_value();
const Expr* false_value = v->false_value();
const Expr* condition_new = condition->accept_mutator(this);
const Expr* true_value_new = true_value->accept_mutator(this);
const Expr* false_value_new = false_value->accept_mutator(this);
Expr* IRMutator::mutate(IfThenElse* v) {
Expr* condition = v->condition();
Expr* true_value = v->true_value();
Expr* false_value = v->false_value();
Expr* condition_new = condition->accept_mutator(this);
Expr* true_value_new = true_value->accept_mutator(this);
Expr* false_value_new = false_value->accept_mutator(this);
if (condition == condition_new && true_value == true_value_new &&
false_value == false_value_new) {
@ -238,12 +238,12 @@ const Expr* IRMutator::mutate(const IfThenElse* v) {
return new IfThenElse(condition_new, true_value_new, false_value_new);
}
const Expr* IRMutator::mutate(const Intrinsics* v) {
std::vector<const Expr*> params(v->nparams());
Expr* IRMutator::mutate(Intrinsics* v) {
std::vector<Expr*> params(v->nparams());
bool any_change = false;
for (int i = 0; i < v->nparams(); i++) {
const Expr* value = v->param(i);
const Expr* value_new = value->accept_mutator(this);
Expr* value = v->param(i);
Expr* value_new = value->accept_mutator(this);
if (value != value_new) {
any_change = true;
}
@ -255,78 +255,78 @@ const Expr* IRMutator::mutate(const Intrinsics* v) {
return new Intrinsics(v->op_type(), params);
}
const Expr* IRMutator::mutate(const Term* v) {
const Expr* newScalar = v->scalar()->accept_mutator(this);
Expr* IRMutator::mutate(Term* v) {
Expr* newScalar = v->scalar()->accept_mutator(this);
std::vector<const Expr*> variables;
for (const auto* t : v->variables()) {
std::vector<Expr*> variables;
for (auto* t : v->variables()) {
variables.push_back(t->accept_mutator(this));
}
return new Term(v->hasher(), newScalar, variables);
}
const Expr* IRMutator::mutate(const Polynomial* v) {
const Expr* newScalar = v->scalar()->accept_mutator(this);
Expr* IRMutator::mutate(Polynomial* v) {
Expr* newScalar = v->scalar()->accept_mutator(this);
std::vector<const Term*> variables;
for (const auto* t : v->variables()) {
variables.push_back(static_cast<const Term*>(t->accept_mutator(this)));
std::vector<Term*> variables;
for (auto* t : v->variables()) {
variables.push_back(static_cast<Term*>(t->accept_mutator(this)));
}
return new Polynomial(v->hasher(), newScalar, variables);
}
const Expr* IRMutator::mutate(const RoundOff* v) {
Expr* IRMutator::mutate(RoundOff* v) {
return new RoundOff(
v->lhs()->accept_mutator(this), v->rhs()->accept_mutator(this));
}
const Expr* IRMutator::mutate(const MaxTerm* v) {
const Expr* newScalar = nullptr;
Expr* IRMutator::mutate(MaxTerm* v) {
Expr* newScalar = nullptr;
if (v->scalar()) {
newScalar = v->scalar()->accept_mutator(this);
}
std::vector<const Expr*> variables;
for (const auto* t : v->variables()) {
std::vector<Expr*> variables;
for (auto* t : v->variables()) {
variables.push_back(t->accept_mutator(this));
}
return new MaxTerm(v->hasher(), newScalar, v->propagate_nans(), variables);
}
const Expr* IRMutator::mutate(const MinTerm* v) {
const Expr* newScalar = nullptr;
Expr* IRMutator::mutate(MinTerm* v) {
Expr* newScalar = nullptr;
if (v->scalar()) {
newScalar = v->scalar()->accept_mutator(this);
}
std::vector<const Expr*> variables;
for (const auto* t : v->variables()) {
std::vector<Expr*> variables;
for (auto* t : v->variables()) {
variables.push_back(t->accept_mutator(this));
}
return new MinTerm(v->hasher(), newScalar, v->propagate_nans(), variables);
}
const Expr* IRMutator::mutate(const ReduceOp* v) {
const Expr* body_new = v->body()->accept_mutator(this);
Expr* IRMutator::mutate(ReduceOp* v) {
Expr* body_new = v->body()->accept_mutator(this);
std::vector<const Var*> new_reduce_args;
std::vector<Var*> new_reduce_args;
for (auto* r : v->reduce_args()) {
new_reduce_args.push_back(static_cast<const Var*>(r->accept_mutator(this)));
new_reduce_args.push_back(static_cast<Var*>(r->accept_mutator(this)));
}
return new ReduceOp(body_new, new_reduce_args, v->reducer());
}
Stmt* IRMutator::mutate(const For* v) {
const Expr* var = v->var();
const Expr* start = v->start();
const Expr* stop = v->stop();
Stmt* IRMutator::mutate(For* v) {
Expr* var = v->var();
Expr* start = v->start();
Expr* stop = v->stop();
Stmt* body = v->body();
LoopOptions loop_options = v->loop_options();
const Expr* var_new_expr = var->accept_mutator(this);
const Var* var_new = dynamic_cast<const Var*>(var_new_expr);
const Expr* start_new = start->accept_mutator(this);
const Expr* stop_new = stop->accept_mutator(this);
Expr* var_new_expr = var->accept_mutator(this);
Var* var_new = dynamic_cast<Var*>(var_new_expr);
Expr* start_new = start->accept_mutator(this);
Expr* stop_new = stop->accept_mutator(this);
Stmt* body_new = body->accept_mutator(this);
if (!body_new) {
return nullptr;
@ -341,7 +341,7 @@ Stmt* IRMutator::mutate(const For* v) {
return new For(var_new, start_new, stop_new, body_new, loop_options);
}
Stmt* IRMutator::mutate(const Block* v) {
Stmt* IRMutator::mutate(Block* v) {
bool any_change = false;
std::vector<Stmt*> stmts;
@ -362,69 +362,68 @@ Stmt* IRMutator::mutate(const Block* v) {
return Block::make(stmts);
}
Stmt* IRMutator::mutate(const Store* v) {
const Buf* buf = v->buf();
Stmt* IRMutator::mutate(Store* v) {
Buf* buf = v->buf();
bool any_index_changed = false;
std::vector<const Expr*> indices_new;
for (const Expr* ind : v->indices()) {
const Expr* new_ind = ind->accept_mutator(this);
std::vector<Expr*> indices_new;
for (Expr* ind : v->indices()) {
Expr* new_ind = ind->accept_mutator(this);
if (new_ind != ind) {
any_index_changed = true;
}
indices_new.push_back(new_ind);
}
const Expr* value = v->value();
const Buf* buf_new = dynamic_cast<const Buf*>(buf->accept_mutator(this));
const Expr* value_new = value->accept_mutator(this);
Expr* value = v->value();
Buf* buf_new = dynamic_cast<Buf*>(buf->accept_mutator(this));
Expr* value_new = value->accept_mutator(this);
if (buf == buf_new && !any_index_changed && value == value_new) {
return (Stmt*)v;
}
return new Store(buf_new, indices_new, value_new);
}
Stmt* IRMutator::mutate(const AtomicAdd* v) {
const Buf* buf = v->buf();
Stmt* IRMutator::mutate(AtomicAdd* v) {
Buf* buf = v->buf();
bool any_index_changed = false;
std::vector<const Expr*> indices_new;
for (const Expr* ind : v->indices()) {
const Expr* new_ind = ind->accept_mutator(this);
std::vector<Expr*> indices_new;
for (Expr* ind : v->indices()) {
Expr* new_ind = ind->accept_mutator(this);
if (new_ind != ind) {
any_index_changed = true;
}
indices_new.push_back(new_ind);
}
const Expr* value = v->value();
const Buf* buf_new = dynamic_cast<const Buf*>(buf->accept_mutator(this));
const Expr* value_new = value->accept_mutator(this);
Expr* value = v->value();
Buf* buf_new = dynamic_cast<Buf*>(buf->accept_mutator(this));
Expr* value_new = value->accept_mutator(this);
if (buf == buf_new && !any_index_changed && value == value_new) {
return (Stmt*)v;
}
return new AtomicAdd(buf_new, indices_new, value_new);
}
Stmt* IRMutator::mutate(const SyncThreads* v) {
Stmt* IRMutator::mutate(SyncThreads* v) {
return new SyncThreads();
}
Stmt* IRMutator::mutate(const ExternalCall* v) {
Stmt* IRMutator::mutate(ExternalCall* v) {
bool changed = false;
const Buf* new_buf = dynamic_cast<const Buf*>(v->buf()->accept_mutator(this));
Buf* new_buf = dynamic_cast<Buf*>(v->buf()->accept_mutator(this));
TORCH_INTERNAL_ASSERT(new_buf);
changed |= new_buf != v->buf();
std::vector<const Buf*> new_buf_args;
for (const Buf* buf_arg : v->buf_args()) {
const Buf* new_buf_arg =
dynamic_cast<const Buf*>(buf_arg->accept_mutator(this));
std::vector<Buf*> new_buf_args;
for (Buf* buf_arg : v->buf_args()) {
Buf* new_buf_arg = dynamic_cast<Buf*>(buf_arg->accept_mutator(this));
TORCH_INTERNAL_ASSERT(new_buf_arg);
new_buf_args.push_back(new_buf_arg);
changed |= new_buf_arg != buf_arg;
}
std::vector<const Expr*> new_args;
for (const Expr* arg : v->args()) {
const Expr* new_arg = arg->accept_mutator(this);
std::vector<Expr*> new_args;
for (Expr* arg : v->args()) {
Expr* new_arg = arg->accept_mutator(this);
new_args.push_back(new_arg);
changed |= new_arg != arg;
}
@ -433,9 +432,9 @@ Stmt* IRMutator::mutate(const ExternalCall* v) {
: (Stmt*)v;
}
Stmt* IRMutator::mutate(const Allocate* v) {
const Buf* buf = v->buf();
const Buf* buf_new = dynamic_cast<const Buf*>(buf->accept_mutator(this));
Stmt* IRMutator::mutate(Allocate* v) {
Buf* buf = v->buf();
Buf* buf_new = dynamic_cast<Buf*>(buf->accept_mutator(this));
TORCH_INTERNAL_ASSERT(buf_new);
if (buf_new == buf) {
return (Stmt*)v;
@ -443,9 +442,9 @@ Stmt* IRMutator::mutate(const Allocate* v) {
return new Allocate(buf_new);
}
Stmt* IRMutator::mutate(const Free* v) {
const Buf* buf = v->buf();
const Buf* buf_new = dynamic_cast<const Buf*>(buf->accept_mutator(this));
Stmt* IRMutator::mutate(Free* v) {
Buf* buf = v->buf();
Buf* buf_new = dynamic_cast<Buf*>(buf->accept_mutator(this));
TORCH_INTERNAL_ASSERT(buf_new);
if (buf_new == buf) {
return (Stmt*)v;
@ -454,12 +453,12 @@ Stmt* IRMutator::mutate(const Free* v) {
return new Free(buf_new);
}
Stmt* IRMutator::mutate(const Let* v) {
const Var* var_old = v->var();
const Var* var_new = dynamic_cast<const Var*>(var_old->accept_mutator(this));
Stmt* IRMutator::mutate(Let* v) {
Var* var_old = v->var();
Var* var_new = dynamic_cast<Var*>(var_old->accept_mutator(this));
const Expr* val_old = v->value();
const Expr* val_new = val_old->accept_mutator(this);
Expr* val_old = v->value();
Expr* val_new = val_old->accept_mutator(this);
if (var_new == var_old && val_old == val_new) {
return (Stmt*)v;
@ -468,12 +467,12 @@ Stmt* IRMutator::mutate(const Let* v) {
return new Let(var_new, val_new);
}
Stmt* IRMutator::mutate(const Cond* v) {
const Expr* cond_old = v->condition();
Stmt* IRMutator::mutate(Cond* v) {
Expr* cond_old = v->condition();
Stmt* true_old = v->true_stmt();
Stmt* false_old = v->false_stmt();
const Expr* cond_new = cond_old->accept_mutator(this);
Expr* cond_new = cond_old->accept_mutator(this);
Stmt* true_new = true_old ? true_old->accept_mutator(this) : true_old;
Stmt* false_new = false_old ? false_old->accept_mutator(this) : false_old;
@ -493,24 +492,24 @@ Stmt* IRMutator::mutate(const Cond* v) {
class StmtClone : public IRMutator {
public:
Stmt* mutate(const For* v) override;
Stmt* mutate(const Block* v) override;
Stmt* mutate(const Store* v) override;
Stmt* mutate(const Allocate* v) override;
Stmt* mutate(const Free* v) override;
Stmt* mutate(const Let* v) override;
Stmt* mutate(const Cond* v) override;
Stmt* mutate(const AtomicAdd* v) override;
Stmt* mutate(For* v) override;
Stmt* mutate(Block* v) override;
Stmt* mutate(Store* v) override;
Stmt* mutate(Allocate* v) override;
Stmt* mutate(Free* v) override;
Stmt* mutate(Let* v) override;
Stmt* mutate(Cond* v) override;
Stmt* mutate(AtomicAdd* v) override;
};
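// StmtClone duplicates only statements; the expressions inside are shared
// between the original and the copy. A sketch of the intended entry point
// (assuming the Stmt::clone wrapper exercised by the tests):
Stmt* copy = Stmt::clone(original); // 'original' is a hypothetical Stmt*
// Appending to a Block inside 'copy' leaves 'original' untouched, while the
// Expr nodes remain shared by both trees.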
Stmt* StmtClone::mutate(const For* v) {
Stmt* StmtClone::mutate(For* v) {
// Only body needs to be cloned as only statements are mutable
Stmt* body_new = v->body()->accept_mutator(this);
return new For(v->var(), v->start(), v->stop(), body_new, v->loop_options());
}
Stmt* StmtClone::mutate(const Block* v) {
Stmt* StmtClone::mutate(Block* v) {
std::vector<Stmt*> stmts;
for (Stmt* stmt : *v) {
stmts.push_back(stmt->accept_mutator(this));
@ -518,27 +517,27 @@ Stmt* StmtClone::mutate(const Block* v) {
return new Block(stmts);
}
Stmt* StmtClone::mutate(const Store* v) {
Stmt* StmtClone::mutate(Store* v) {
return new Store(v->buf(), v->indices(), v->value());
}
Stmt* StmtClone::mutate(const AtomicAdd* v) {
Stmt* StmtClone::mutate(AtomicAdd* v) {
return new AtomicAdd(v->buf(), v->indices(), v->value());
}
Stmt* StmtClone::mutate(const Allocate* v) {
Stmt* StmtClone::mutate(Allocate* v) {
return new Allocate(v->buf());
}
Stmt* StmtClone::mutate(const Free* v) {
Stmt* StmtClone::mutate(Free* v) {
return new Free(v->buf());
}
Stmt* StmtClone::mutate(const Let* v) {
Stmt* StmtClone::mutate(Let* v) {
return new Let(v->var(), v->value());
}
Stmt* StmtClone::mutate(const Cond* v) {
Stmt* StmtClone::mutate(Cond* v) {
Stmt* true_old = v->true_stmt();
Stmt* false_old = v->false_stmt();


@ -57,52 +57,51 @@ class ExternalCall;
class TORCH_API IRMutator {
public:
virtual ~IRMutator() = default;
virtual const Expr* mutate(const Add* v);
virtual const Expr* mutate(const Sub* v);
virtual const Expr* mutate(const Mul* v);
virtual const Expr* mutate(const Div* v);
virtual const Expr* mutate(const Mod* v);
virtual const Expr* mutate(const Max* v);
virtual const Expr* mutate(const Min* v);
virtual const Expr* mutate(const And* v);
virtual const Expr* mutate(const Or* v);
virtual const Expr* mutate(const Xor* v);
virtual const Expr* mutate(const Lshift* v);
virtual const Expr* mutate(const Rshift* v);
virtual const Expr* mutate(const CompareSelect* v);
#define IMM_MUTATE_DECLARE(Type, Name) \
virtual const Expr* mutate(const Name##Imm* v);
virtual Expr* mutate(Add* v);
virtual Expr* mutate(Sub* v);
virtual Expr* mutate(Mul* v);
virtual Expr* mutate(Div* v);
virtual Expr* mutate(Mod* v);
virtual Expr* mutate(Max* v);
virtual Expr* mutate(Min* v);
virtual Expr* mutate(And* v);
virtual Expr* mutate(Or* v);
virtual Expr* mutate(Xor* v);
virtual Expr* mutate(Lshift* v);
virtual Expr* mutate(Rshift* v);
virtual Expr* mutate(CompareSelect* v);
#define IMM_MUTATE_DECLARE(Type, Name) virtual Expr* mutate(Name##Imm* v);
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_MUTATE_DECLARE);
#undef IMM_MUTATE_DECLARE
virtual const Expr* mutate(const Cast* v);
virtual const Expr* mutate(const BitCast* v);
virtual const Expr* mutate(const Var* v);
virtual const Expr* mutate(Buf* v);
virtual const Expr* mutate(const Ramp* v);
virtual const Expr* mutate(const Load* v);
virtual const Expr* mutate(const Broadcast* v);
virtual const Expr* mutate(const IfThenElse* v);
virtual const Expr* mutate(const Intrinsics* v);
virtual Expr* mutate(Cast* v);
virtual Expr* mutate(BitCast* v);
virtual Expr* mutate(Var* v);
virtual Expr* mutate(Buf* v);
virtual Expr* mutate(Ramp* v);
virtual Expr* mutate(Load* v);
virtual Expr* mutate(Broadcast* v);
virtual Expr* mutate(IfThenElse* v);
virtual Expr* mutate(Intrinsics* v);
virtual const Expr* mutate(const Term* v);
virtual const Expr* mutate(const Polynomial* v);
virtual const Expr* mutate(const RoundOff* v);
virtual const Expr* mutate(const MaxTerm* v);
virtual const Expr* mutate(const MinTerm* v);
virtual Expr* mutate(Term* v);
virtual Expr* mutate(Polynomial* v);
virtual Expr* mutate(RoundOff* v);
virtual Expr* mutate(MaxTerm* v);
virtual Expr* mutate(MinTerm* v);
virtual const Expr* mutate(const ReduceOp* v);
virtual Expr* mutate(ReduceOp* v);
virtual Stmt* mutate(const For* v);
virtual Stmt* mutate(const Block* v);
virtual Stmt* mutate(const Store* v);
virtual Stmt* mutate(const AtomicAdd* v);
virtual Stmt* mutate(const SyncThreads* v);
virtual Stmt* mutate(const ExternalCall* v);
virtual Stmt* mutate(For* v);
virtual Stmt* mutate(Block* v);
virtual Stmt* mutate(Store* v);
virtual Stmt* mutate(AtomicAdd* v);
virtual Stmt* mutate(SyncThreads* v);
virtual Stmt* mutate(ExternalCall* v);
virtual Stmt* mutate(const Allocate* v);
virtual Stmt* mutate(const Free* v);
virtual Stmt* mutate(const Let* v);
virtual Stmt* mutate(const Cond* v);
virtual Stmt* mutate(Allocate* v);
virtual Stmt* mutate(Free* v);
virtual Stmt* mutate(Let* v);
virtual Stmt* mutate(Cond* v);
};
} // namespace tensorexpr

View File

@ -18,11 +18,11 @@ void IRPrinter::print(ExprHandle expr) {
expr.node()->accept(this);
}
void IRPrinter::print(const Expr& expr) {
void IRPrinter::print(Expr& expr) {
expr.accept(this);
}
void IRPrinter::print(const Stmt& stmt) {
void IRPrinter::print(Stmt& stmt) {
stmt.accept(this);
}
@ -30,7 +30,7 @@ void IRPrinter::print(const Stmt& stmt) {
// we need to look at the operator precedence to make the output simpler.
template <typename Op>
void visitBinaryOp(
const BinaryOpNode<Op>* v,
BinaryOpNode<Op>* v,
const std::string& op_str,
IRPrinter* printer,
bool parens = true) {
@ -58,43 +58,43 @@ void visitBinaryOp(
}
}
void IRPrinter::visit(const Add* v) {
void IRPrinter::visit(Add* v) {
visitBinaryOp(v, "+", this);
}
void IRPrinter::visit(const Sub* v) {
void IRPrinter::visit(Sub* v) {
visitBinaryOp(v, "-", this);
}
void IRPrinter::visit(const Mul* v) {
void IRPrinter::visit(Mul* v) {
visitBinaryOp(v, "*", this);
}
void IRPrinter::visit(const Div* v) {
void IRPrinter::visit(Div* v) {
visitBinaryOp(v, "/", this);
}
void IRPrinter::visit(const And* v) {
void IRPrinter::visit(And* v) {
visitBinaryOp(v, "&", this);
}
void IRPrinter::visit(const Or* v) {
void IRPrinter::visit(Or* v) {
visitBinaryOp(v, "|", this);
}
void IRPrinter::visit(const Xor* v) {
void IRPrinter::visit(Xor* v) {
visitBinaryOp(v, "^", this);
}
void IRPrinter::visit(const Lshift* v) {
void IRPrinter::visit(Lshift* v) {
visitBinaryOp(v, "<<", this);
}
void IRPrinter::visit(const Rshift* v) {
void IRPrinter::visit(Rshift* v) {
visitBinaryOp(v, ">>", this);
}
void IRPrinter::visit(const Mod* v) {
void IRPrinter::visit(Mod* v) {
if (v->dtype().is_integral()) {
visitBinaryOp(v, "%", this);
} else if (v->dtype().is_floating_point()) {
@ -104,7 +104,7 @@ void IRPrinter::visit(const Mod* v) {
}
}
void IRPrinter::visit(const Max* v) {
void IRPrinter::visit(Max* v) {
os() << "Max(";
v->lhs()->accept(this);
os() << ", ";
@ -112,7 +112,7 @@ void IRPrinter::visit(const Max* v) {
os() << ", " << (unsigned int)v->propagate_nans() << ")";
}
void IRPrinter::visit(const Min* v) {
void IRPrinter::visit(Min* v) {
os() << "Min(";
v->lhs()->accept(this);
os() << ", ";
@ -120,7 +120,7 @@ void IRPrinter::visit(const Min* v) {
os() << ", " << (unsigned int)v->propagate_nans() << ")";
}
void IRPrinter::visit(const CompareSelect* v) {
void IRPrinter::visit(CompareSelect* v) {
CompareSelectOperation cmp_op = v->compare_select_op();
int self_prec = getPrecedence(v->expr_type());
int lhs_prec = getPrecedence(v->lhs()->expr_type());
@ -165,7 +165,7 @@ void IRPrinter::visit(const CompareSelect* v) {
}
os() << " ? ";
auto withParens = [&](const Expr* e) {
auto withParens = [&](Expr* e) {
auto prec = getPrecedence(e->expr_type());
if (prec >= self_prec) {
os() << "(";
@ -219,30 +219,30 @@ static void formatImm(std::ostream& os, T v) {
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_PRINT_VISIT);
#undef IMM_PRINT_VISIT
void IRPrinter::visit(const Cast* v) {
void IRPrinter::visit(Cast* v) {
auto dtype = v->dtype();
os() << dtypeToCppString(dtype) << "(";
v->src_value()->accept(this);
os() << ")";
}
void IRPrinter::visit(const Var* v) {
void IRPrinter::visit(Var* v) {
os() << name_manager_.get_unique_name(v);
}
void IRPrinter::visit(const Ramp* v) {
void IRPrinter::visit(Ramp* v) {
os() << "Ramp(" << *v->base() << ", " << *v->stride() << ", " << v->lanes()
<< ")";
}
void IRPrinter::visit(const Load* v) {
void IRPrinter::visit(Load* v) {
// TODO: support the mask case
if (v->indices().size() == 0) {
os() << *v->base_handle();
} else {
os() << *v->base_handle() << "[";
size_t i = 0;
for (const Expr* ind : v->indices()) {
for (Expr* ind : v->indices()) {
if (i++) {
os() << ", ";
}
@ -255,18 +255,18 @@ void IRPrinter::visit(const Load* v) {
}
}
void IRPrinter::visit(const Broadcast* v) {
void IRPrinter::visit(Broadcast* v) {
os() << "Broadcast(" << *v->value() << ", " << v->lanes() << ")";
}
void IRPrinter::visit(const IfThenElse* v) {
void IRPrinter::visit(IfThenElse* v) {
os() << "IfThenElse(" << *v->condition() << ", " << *v->true_value() << ", "
<< *v->false_value() << ")";
}
void IRPrinter::visit(const Intrinsics* v) {
void IRPrinter::visit(Intrinsics* v) {
os() << v->func_name() << "(";
for (const auto i : c10::irange(v->nparams())) {
for (auto i : c10::irange(v->nparams())) {
if (i > 0) {
os() << ", ";
}
@ -275,7 +275,7 @@ void IRPrinter::visit(const Intrinsics* v) {
os() << ")";
}
void IRPrinter::visit(const Term* v) {
void IRPrinter::visit(Term* v) {
os() << "Term(";
v->scalar()->accept(this);
for (auto* t : v->variables()) {
@ -285,7 +285,7 @@ void IRPrinter::visit(const Term* v) {
os() << ")";
}
void IRPrinter::visit(const Polynomial* v) {
void IRPrinter::visit(Polynomial* v) {
bool first = true;
os() << "Polynomial(";
for (auto* t : v->variables()) {
@ -303,7 +303,7 @@ void IRPrinter::visit(const Polynomial* v) {
os() << ")";
}
void IRPrinter::visit(const RoundOff* v) {
void IRPrinter::visit(RoundOff* v) {
os() << "RoundOff(";
v->lhs()->accept(this);
os() << ", ";
@ -311,7 +311,7 @@ void IRPrinter::visit(const RoundOff* v) {
os() << ")";
}
void IRPrinter::visit(const MaxTerm* v) {
void IRPrinter::visit(MaxTerm* v) {
os() << "MaxTerm(";
if (v->scalar()) {
v->scalar()->accept(this);
@ -326,7 +326,7 @@ void IRPrinter::visit(const MaxTerm* v) {
os() << ")";
}
void IRPrinter::visit(const MinTerm* v) {
void IRPrinter::visit(MinTerm* v) {
os() << "MinTerm(";
if (v->scalar()) {
v->scalar()->accept(this);
@ -341,7 +341,7 @@ void IRPrinter::visit(const MinTerm* v) {
os() << ")";
}
void IRPrinter::visit(const ReduceOp* v) {
void IRPrinter::visit(ReduceOp* v) {
os() << "ReduceOp(";
os() << *v->body() << ", ";
@ -363,7 +363,7 @@ void IRPrinter::visit(const ReduceOp* v) {
// each statement in a `Block` the printer will insert indentation before
// the statement and a newline after the statement.
void IRPrinter::visit(const Store* v) {
void IRPrinter::visit(Store* v) {
// TODO: handle the mask
if (v->indices().size() == 0) {
os() << *v->base_handle() << " = " << *v->value() << ";";
@ -372,7 +372,7 @@ void IRPrinter::visit(const Store* v) {
os() << *v->base_handle() << "[";
size_t i = 0;
for (const Expr* ind : v->indices()) {
for (Expr* ind : v->indices()) {
if (i++) {
os() << ", ";
}
@ -384,8 +384,8 @@ void IRPrinter::visit(const Store* v) {
os() << "] = " << *v->value() << ";";
}
void IRPrinter::visit(const For* v) {
const Var* var = v->var();
void IRPrinter::visit(For* v) {
Var* var = v->var();
VarHandle vv(var);
os() << "for (" << dtypeToCppString(var->dtype()) << " " << vv << " = "
<< ExprHandle(v->start()) << "; " << vv << " < " << ExprHandle(v->stop())
@ -401,7 +401,7 @@ void IRPrinter::visit(const For* v) {
}
}
void IRPrinter::visit(const Block* v) {
void IRPrinter::visit(Block* v) {
os() << "{\n";
indent_++;
@ -414,12 +414,12 @@ void IRPrinter::visit(const Block* v) {
os() << "}";
}
void IRPrinter::visit(const Allocate* v) {
void IRPrinter::visit(Allocate* v) {
os() << "Allocate(" << *v->buffer_var()
<< "); // dtype=" << dtypeToCppString(v->dtype());
os() << ", dims=[";
const std::vector<const Expr*>& dims = v->dims();
for (const auto i : c10::irange(dims.size())) {
const std::vector<Expr*>& dims = v->dims();
for (auto i : c10::irange(dims.size())) {
if (i != 0) {
os() << ", ";
}
@ -428,18 +428,18 @@ void IRPrinter::visit(const Allocate* v) {
os() << "]";
}
void IRPrinter::visit(const Free* v) {
void IRPrinter::visit(Free* v) {
os() << "Free(" << *v->buffer_var() << ");";
}
void IRPrinter::visit(const Let* v) {
void IRPrinter::visit(Let* v) {
os() << dtypeToCppString(v->dtype()) << " " << *v->var();
os() << " = " << *v->value();
os() << ";";
}
void IRPrinter::visit(const Cond* v) {
const Expr* cond = v->condition();
void IRPrinter::visit(Cond* v) {
Expr* cond = v->condition();
Stmt* true_stmt = v->true_stmt();
Stmt* false_stmt = v->false_stmt();
if (!true_stmt) {
@ -455,10 +455,10 @@ void IRPrinter::visit(const Cond* v) {
}
}
void IRPrinter::visit(const AtomicAdd* v) {
void IRPrinter::visit(AtomicAdd* v) {
os() << "atomicAdd(&" << *v->base_handle() << "[";
size_t i = 0;
for (const Expr* ind : v->indices()) {
for (Expr* ind : v->indices()) {
if (i++) {
os() << ", ";
}
@ -470,16 +470,16 @@ void IRPrinter::visit(const AtomicAdd* v) {
os() << "], " << *v->value() << ");";
}
void IRPrinter::visit(const SyncThreads* v) {
void IRPrinter::visit(SyncThreads* v) {
os() << "__syncthreads();";
}
void IRPrinter::visit(const ExternalCall* v) {
void IRPrinter::visit(ExternalCall* v) {
os() << *v->buf() << " = " << v->func_name() << "(";
os() << "buf_args={";
int i = 0;
for (const Buf* buf_arg : v->buf_args()) {
for (Buf* buf_arg : v->buf_args()) {
if (i++ > 0) {
os() << ", ";
}
@ -488,7 +488,7 @@ void IRPrinter::visit(const ExternalCall* v) {
os() << "}, args={";
i = 0;
for (const Expr* arg : v->args()) {
for (Expr* arg : v->args()) {
if (i++ > 0) {
os() << ", ";
}
@ -504,11 +504,12 @@ void IRPrinter::emitIndent() {
std::ostream& operator<<(std::ostream& stream, const ExprHandle& expr) {
IRPrinter::PrinterStream* printer_stream =
dynamic_cast<IRPrinter::PrinterStream*>(&stream);
ExprHandle& mutable_expr = const_cast<ExprHandle&>(expr);
if (printer_stream != nullptr) {
expr.node()->accept(printer_stream->printer());
mutable_expr.node()->accept(printer_stream->printer());
} else {
IRPrinter p(stream);
p.print(expr);
p.print(mutable_expr);
}
return stream;
}
@ -516,11 +517,12 @@ std::ostream& operator<<(std::ostream& stream, const ExprHandle& expr) {
std::ostream& operator<<(std::ostream& stream, const Expr& expr) {
IRPrinter::PrinterStream* printer_stream =
dynamic_cast<IRPrinter::PrinterStream*>(&stream);
Expr& mutable_expr = const_cast<Expr&>(expr);
if (printer_stream != nullptr) {
expr.accept(printer_stream->printer());
mutable_expr.accept(printer_stream->printer());
} else {
IRPrinter p(stream);
p.print(expr);
p.print(mutable_expr);
}
return stream;
}
@ -528,11 +530,12 @@ std::ostream& operator<<(std::ostream& stream, const Expr& expr) {
std::ostream& operator<<(std::ostream& stream, const Stmt& stmt) {
IRPrinter::PrinterStream* printer_stream =
dynamic_cast<IRPrinter::PrinterStream*>(&stream);
Stmt& mutable_stmt = const_cast<Stmt&>(stmt);
if (printer_stream != nullptr) {
stmt.accept(printer_stream->printer());
mutable_stmt.accept(printer_stream->printer());
} else {
IRPrinter p(stream);
p.print(stmt);
p.print(mutable_stmt);
}
return stream;
}
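// The const_casts above keep the public streaming operators const-correct
// for callers even though IRPrinter now works on mutable references. Usage
// is unchanged; a sketch (active KernelScope assumed):
ExprHandle e = ExprHandle(new IntImm(1)) + ExprHandle(new IntImm(2));
std::cout << e; // routes through print(Expr&) via the const_cast above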
@ -544,8 +547,9 @@ std::ostream& operator<<(std::ostream& stream, const Tensor& t) {
void print(const Expr* expr) {
if (expr) {
Expr* mutable_expr = const_cast<Expr*>(expr);
IRPrinter p(std::cout);
p.print(*expr);
p.print(*mutable_expr);
} else {
std::cout << "(null expr)";
}
@ -554,8 +558,9 @@ void print(const Expr* expr) {
void print(const Stmt* stmt) {
if (stmt) {
Stmt* mutable_stmt = const_cast<Stmt*>(stmt);
IRPrinter p(std::cout);
p.print(*stmt);
p.print(*mutable_stmt);
} else {
std::cout << "(null stmt)\n";
}
@ -589,7 +594,7 @@ std::string to_string(const Tensor* t) {
std::ostringstream oss;
// TODO: move this to Buf printer
oss << "Tensor " << t->buf()->name_hint() << "[";
for (const auto i : c10::irange(t->buf()->ndim())) {
for (auto i : c10::irange(t->buf()->ndim())) {
if (i != 0) {
oss << ", ";
}


@ -17,48 +17,48 @@ class TORCH_API IRPrinter : public IRVisitor {
explicit IRPrinter(std::ostream& os) : printer_os_(this, os) {}
void print(ExprHandle);
void print(const Expr&);
void print(const Stmt&);
void visit(const Add* v) override;
void visit(const Sub* v) override;
void visit(const Mul* v) override;
void visit(const Div* v) override;
void visit(const Mod* v) override;
void visit(const Max* v) override;
void visit(const Min* v) override;
void visit(const And* v) override;
void visit(const Or* v) override;
void visit(const Xor* v) override;
void visit(const Lshift* v) override;
void visit(const Rshift* v) override;
void visit(const CompareSelect* v) override;
void print(Expr&);
void print(Stmt&);
void visit(Add* v) override;
void visit(Sub* v) override;
void visit(Mul* v) override;
void visit(Div* v) override;
void visit(Mod* v) override;
void visit(Max* v) override;
void visit(Min* v) override;
void visit(And* v) override;
void visit(Or* v) override;
void visit(Xor* v) override;
void visit(Lshift* v) override;
void visit(Rshift* v) override;
void visit(CompareSelect* v) override;
#define IMM_PRINT_VISIT(Type, Name) void visit(const Name##Imm* v) override;
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_PRINT_VISIT);
#undef IMM_PRINT_VISIT
void visit(const Cast* v) override;
void visit(const Var* v) override;
void visit(const Ramp* v) override;
void visit(const Load* v) override;
void visit(const Broadcast* v) override;
void visit(const IfThenElse* v) override;
void visit(const Intrinsics* v) override;
void visit(const Term* v) override;
void visit(const Polynomial* v) override;
void visit(const RoundOff* v) override;
void visit(const MaxTerm* v) override;
void visit(const MinTerm* v) override;
void visit(const ReduceOp* v) override;
void visit(Cast* v) override;
void visit(Var* v) override;
void visit(Ramp* v) override;
void visit(Load* v) override;
void visit(Broadcast* v) override;
void visit(IfThenElse* v) override;
void visit(Intrinsics* v) override;
void visit(Term* v) override;
void visit(Polynomial* v) override;
void visit(RoundOff* v) override;
void visit(MaxTerm* v) override;
void visit(MinTerm* v) override;
void visit(ReduceOp* v) override;
void visit(const AtomicAdd* v) override;
void visit(const SyncThreads* v) override;
void visit(const ExternalCall* v) override;
void visit(const Store* v) override;
void visit(const For* v) override;
void visit(const Cond* v) override;
void visit(const Block* v) override;
void visit(const Allocate* v) override;
void visit(const Free* v) override;
void visit(const Let* v) override;
void visit(AtomicAdd* v) override;
void visit(SyncThreads* v) override;
void visit(ExternalCall* v) override;
void visit(Store* v) override;
void visit(For* v) override;
void visit(Cond* v) override;
void visit(Block* v) override;
void visit(Allocate* v) override;
void visit(Free* v) override;
void visit(Let* v) override;
// A child class may have a different rule for generating dtype
// string, e.g. CUDA needs int64_t to be generated as long long.

File diff suppressed because it is too large.


@ -25,7 +25,7 @@ namespace tensorexpr {
// A bunch of helpers for determining the Dtype of the output of a multi-argument
// Term or Polynomial.
template <class ExprType>
Dtype promoteTypesVec(const Expr* s, std::vector<const ExprType*>& v) {
Dtype promoteTypesVec(Expr* s, std::vector<ExprType*>& v) {
Dtype t = s->dtype();
bool first = true;
@ -40,7 +40,7 @@ Dtype promoteTypesVec(const Expr* s, std::vector<const ExprType*>& v) {
}
template <class ExprType>
Dtype promoteTypesVec(std::vector<const ExprType*>& v) {
Dtype promoteTypesVec(std::vector<ExprType*>& v) {
if (v.empty()) {
throw malformed_input("empty list of types");
}
@ -54,8 +54,8 @@ Dtype promoteTypesVec(std::vector<const ExprType*>& v) {
template <class ExprType>
Dtype promoteTypesMap(
const Expr* s,
std::unordered_map<SimplifierHashType, const ExprType*>& m) {
Expr* s,
std::unordered_map<SimplifierHashType, ExprType*>& m) {
Dtype t = s->dtype();
bool first = true;
for (auto& e : m) {
@ -69,12 +69,12 @@ Dtype promoteTypesMap(
}
template <class ExprType>
Dtype promoteTypesVar(const ExprType* e) {
Dtype promoteTypesVar(ExprType* e) {
return e->dtype();
}
template <class ExprType, class... Args>
Dtype promoteTypesVar(const ExprType* e, Args... es) {
Dtype promoteTypesVar(ExprType* e, Args... es) {
Dtype lhs = e->dtype();
Dtype rhs = promoteTypesVar(es...);
if (e->isConstant()) {
@ -85,10 +85,10 @@ Dtype promoteTypesVar(const ExprType* e, Args... es) {
}
// Creates a new Expr of the given type with the provided lhs and rhs.
inline const Expr* newBinaryOpOfType(
inline Expr* newBinaryOpOfType(
IRNodeType expr_type,
const Expr* lhs,
const Expr* rhs,
Expr* lhs,
Expr* rhs,
bool option) {
switch (expr_type) {
// NOLINTNEXTLINE(bugprone-branch-clone)
@ -123,7 +123,7 @@ inline const Expr* newBinaryOpOfType(
// Uses the evaluator to fold an Expression with constant terms.
// E.g. evaluateOp(Add(3, 4)) => 7.
// Expr v must not have any unbound Vars.
inline Expr* evaluateOp(const Expr* v) {
inline Expr* evaluateOp(Expr* v) {
ExprHandle handle(v);
ExprEval<SimpleIREvaluator> eval(handle);
@ -148,7 +148,7 @@ class Term : public ExprNode<Term> {
public:
template <class... Args>
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
Term(HashProvider& hasher, const Expr* s, Args... ts)
Term(HashProvider& hasher, Expr* s, Args... ts)
: ExprNodeBase(promoteTypesVar(s, ts...)), scalar_(s), hasher_(hasher) {
CHECK(s->isConstant());
addComponent(ts...);
@ -156,7 +156,7 @@ class Term : public ExprNode<Term> {
}
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
Term(HashProvider& hasher, const Expr* s, std::vector<const Expr*> v)
Term(HashProvider& hasher, Expr* s, std::vector<Expr*> v)
: ExprNodeBase(promoteTypesVec(s, v)),
variables_(std::move(v)),
scalar_(s),
@ -168,8 +168,8 @@ class Term : public ExprNode<Term> {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
Term(
HashProvider& hasher,
const Expr* s,
std::unordered_map<SimplifierHashType, const Expr*> varmap)
Expr* s,
std::unordered_map<SimplifierHashType, Expr*> varmap)
: ExprNodeBase(promoteTypesMap(s, varmap)), scalar_(s), hasher_(hasher) {
for (auto& p : varmap) {
addComponent(p.second);
@ -177,10 +177,10 @@ class Term : public ExprNode<Term> {
sort();
}
const Expr* scalar() const {
Expr* scalar() const {
return scalar_;
}
const std::vector<const Expr*>& variables() const {
const std::vector<Expr*>& variables() const {
return variables_;
}
HashProvider& hasher() const {
@ -192,16 +192,16 @@ class Term : public ExprNode<Term> {
SimplifierHashType hashVars() const;
private:
std::vector<const Expr*> variables_;
const Expr* scalar_;
std::vector<Expr*> variables_;
Expr* scalar_;
HashProvider& hasher_;
void addComponent() {}
void addComponent(const Expr* e) {
void addComponent(Expr* e) {
variables_.push_back(e);
}
template <class... Es>
void addComponent(const Expr* e, Es... es) {
void addComponent(Expr* e, Es... es) {
addComponent(e);
addComponent(es...);
}
@ -217,7 +217,7 @@ class Polynomial : public ExprNode<Polynomial> {
public:
template <class... Args>
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
Polynomial(HashProvider& hasher, const Expr* s, Args... ts)
Polynomial(HashProvider& hasher, Expr* s, Args... ts)
: ExprNodeBase(promoteTypesVar(s, ts...)), scalar_(s), hasher_(hasher) {
CHECK(s->isConstant());
addTerm(ts...);
@ -225,7 +225,7 @@ class Polynomial : public ExprNode<Polynomial> {
}
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
Polynomial(HashProvider& hasher, const Expr* s, std::vector<const Term*> v)
Polynomial(HashProvider& hasher, Expr* s, std::vector<Term*> v)
: ExprNodeBase(promoteTypesVec(s, v)),
variables_(std::move(v)),
scalar_(s),
@ -235,7 +235,7 @@ class Polynomial : public ExprNode<Polynomial> {
// Helper constructor for list of terms with no scalar component.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
Polynomial(HashProvider& hasher, std::vector<const Term*> terms)
Polynomial(HashProvider& hasher, std::vector<Term*> terms)
: ExprNodeBase(promoteTypesVec(terms)),
variables_(std::move(terms)),
scalar_(getImmediateByType(dtype(), 0)),
@ -248,8 +248,8 @@ class Polynomial : public ExprNode<Polynomial> {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
Polynomial(
HashProvider& hasher,
const Expr* s,
std::unordered_map<SimplifierHashType, const Term*> varmap)
Expr* s,
std::unordered_map<SimplifierHashType, Term*> varmap)
: ExprNodeBase(promoteTypesMap(s, varmap)), scalar_(s), hasher_(hasher) {
for (auto& p : varmap) {
addTerm(p.second);
@ -257,10 +257,10 @@ class Polynomial : public ExprNode<Polynomial> {
sort();
}
const Expr* scalar() const {
Expr* scalar() const {
return scalar_;
}
const std::vector<const Term*>& variables() const {
const std::vector<Term*>& variables() const {
return variables_;
}
HashProvider& hasher() const {
@ -270,15 +270,15 @@ class Polynomial : public ExprNode<Polynomial> {
SimplifierHashType hashVars() const;
private:
std::vector<const Term*> variables_;
const Expr* scalar_;
std::vector<Term*> variables_;
Expr* scalar_;
HashProvider& hasher_;
void addTerm(const Term* t) {
void addTerm(Term* t) {
variables_.push_back(t);
}
template <class... Ts>
void addTerm(const Term* t, Ts... ts) {
void addTerm(Term* t, Ts... ts) {
addTerm(t);
addTerm(ts...);
}
@ -289,15 +289,14 @@ class Polynomial : public ExprNode<Polynomial> {
class RoundOff : public BinaryOpNode<RoundOff> {
public:
RoundOff(const Expr* lhs, const Expr* rhs)
: BinaryOpNode(lhs, rhs, IRNodeType::kOther) {}
RoundOff(Expr* lhs, Expr* rhs) : BinaryOpNode(lhs, rhs, IRNodeType::kOther) {}
};
class MaxTerm : public ExprNode<MaxTerm> {
public:
template <class... Args>
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
MaxTerm(HashProvider& hasher, const Expr* s, bool p, Args... ts)
MaxTerm(HashProvider& hasher, Expr* s, bool p, Args... ts)
: ExprNodeBase(s ? promoteTypesVar(s, ts...) : promoteTypesVar(ts...)),
scalar_(s),
hasher_(hasher),
@ -307,11 +306,7 @@ class MaxTerm : public ExprNode<MaxTerm> {
}
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
MaxTerm(
HashProvider& hasher,
const Expr* s,
bool p,
std::vector<const Expr*> v)
MaxTerm(HashProvider& hasher, Expr* s, bool p, std::vector<Expr*> v)
: ExprNodeBase(s ? promoteTypesVec(s, v) : promoteTypesVec(v)),
variables_(std::move(v)),
scalar_(s),
@ -324,10 +319,10 @@ class MaxTerm : public ExprNode<MaxTerm> {
return propagate_nans_;
}
const Expr* scalar() const {
Expr* scalar() const {
return scalar_;
}
const std::vector<const Expr*>& variables() const {
const std::vector<Expr*>& variables() const {
return variables_;
}
HashProvider& hasher() const {
@ -335,17 +330,17 @@ class MaxTerm : public ExprNode<MaxTerm> {
}
private:
std::vector<const Expr*> variables_;
const Expr* scalar_;
std::vector<Expr*> variables_;
Expr* scalar_;
HashProvider& hasher_;
bool propagate_nans_;
void addComponent() {}
void addComponent(const Expr* e) {
void addComponent(Expr* e) {
variables_.push_back(e);
}
template <class... Es>
void addComponent(const Expr* e, Es... es) {
void addComponent(Expr* e, Es... es) {
addComponent(e);
addComponent(es...);
}
@ -358,7 +353,7 @@ class MinTerm : public ExprNode<MinTerm> {
public:
template <class... Args>
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
MinTerm(HashProvider& hasher, const Expr* s, bool p, Args... ts)
MinTerm(HashProvider& hasher, Expr* s, bool p, Args... ts)
: ExprNodeBase(s ? promoteTypesVar(s, ts...) : promoteTypesVar(ts...)),
scalar_(s),
hasher_(hasher),
@ -368,11 +363,7 @@ class MinTerm : public ExprNode<MinTerm> {
}
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
MinTerm(
HashProvider& hasher,
const Expr* s,
bool p,
std::vector<const Expr*> v)
MinTerm(HashProvider& hasher, Expr* s, bool p, std::vector<Expr*> v)
: ExprNodeBase(s ? promoteTypesVec(s, v) : promoteTypesVec(v)),
variables_(std::move(v)),
scalar_(s),
@ -385,10 +376,10 @@ class MinTerm : public ExprNode<MinTerm> {
return propagate_nans_;
}
const Expr* scalar() const {
Expr* scalar() const {
return scalar_;
}
const std::vector<const Expr*>& variables() const {
const std::vector<Expr*>& variables() const {
return variables_;
}
HashProvider& hasher() const {
@ -396,17 +387,17 @@ class MinTerm : public ExprNode<MinTerm> {
}
private:
std::vector<const Expr*> variables_;
const Expr* scalar_;
std::vector<Expr*> variables_;
Expr* scalar_;
HashProvider& hasher_;
bool propagate_nans_;
void addComponent() {}
void addComponent(const Expr* e) {
void addComponent(Expr* e) {
variables_.push_back(e);
}
template <class... Es>
void addComponent(const Expr* e, Es... es) {
void addComponent(Expr* e, Es... es) {
addComponent(e);
addComponent(es...);
}
@ -416,16 +407,15 @@ class MinTerm : public ExprNode<MinTerm> {
};
// Context-sensitive IR simplification
using VarBoundInfo =
std::unordered_map<const Var*, std::pair<const Expr*, const Expr*>>;
using VarBoundInfo = std::unordered_map<Var*, std::pair<Expr*, Expr*>>;
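
VarBoundInfo records, per index variable, the (start, stop) bounds of its loop; the class below consults it for rewrites that only hold under such bounds. A self-contained illustration of the key identity (plain C++, not the NNC API; the literal bounds stand in for recorded entries):

#include <cassert>

int main() {
  const int n = 8;
  // Given 0 <= i < n, the simplifier may fold (i + n*j) / n to j
  // and (i + n*j) % n to i.
  for (int j = 0; j < 4; ++j) {
    for (int i = 0; i < n; ++i) {
      assert((i + n * j) / n == j);
      assert((i + n * j) % n == i);
    }
  }
  return 0;
}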
class TORCH_API SimplifierUnderContext : public IRMutator {
public:
~SimplifierUnderContext() override = default;
// Add boundary info for index variables in for-loops
Stmt* mutate(const For* v) override;
Stmt* mutate(For* v) override;
const Expr* mutate(const Div* v) override;
const Expr* mutate(const Mod* v) override;
Expr* mutate(Div* v) override;
Expr* mutate(Mod* v) override;
protected:
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
@ -438,14 +428,14 @@ class TORCH_API PolynomialBase : public IRMutator {
public:
~PolynomialBase() override = default;
Stmt* mutate(const Block* v) override;
Stmt* mutate(Block* v) override;
Stmt* mutate(const Cond* v) override;
Stmt* mutate(Cond* v) override;
Stmt* mutate(const For* v) override;
Stmt* mutate(For* v) override;
// Trivially factorize terms by GCD of scalar components.
const Term* factorizePolynomial(const Polynomial* poly);
Term* factorizePolynomial(Polynomial* poly);
HashProvider& hasher() {
return hasher_;
@ -463,89 +453,89 @@ class TORCH_API PolynomialTransformer : public PolynomialBase {
// Inserts a term into the provided map; on a hash collision, combines the
// term with the existing one and updates the map.
void addOrUpdateTerm(
std::unordered_map<SimplifierHashType, const Term*>& varmap,
const Term* term);
std::unordered_map<SimplifierHashType, Term*>& varmap,
Term* term);
// Add Polynomial expressions, combining Terms representing the same
// variables.
const Expr* addPolynomials(const Polynomial* lhs, const Polynomial* rhs);
Expr* addPolynomials(Polynomial* lhs, Polynomial* rhs);
// Insert a new Term into the provided polynomial. If the new term has
// variables in common with an existing term, the two are combined.
const Expr* insertTerm(const Polynomial* poly, const Term* term);
Expr* insertTerm(Polynomial* poly, Term* term);
// Merge and simplify addition.
const Expr* mutate(const Add* v) override;
Expr* mutate(Add* v) override;
// Subtract one term from another, cancelling if necessary.
const Expr* subTerms(const Term* lhs, const Term* rhs, bool negated);
Expr* subTerms(Term* lhs, Term* rhs, bool negated);
// Subtract the RHS Polynomial from the LHS Polynomial, cancelling out where
// possible.
const Expr* subPolynomials(const Polynomial* lhs, const Polynomial* rhs);
Expr* subPolynomials(Polynomial* lhs, Polynomial* rhs);
// Merge and simplify subtraction.
const Expr* mutate(const Sub* v) override;
Expr* mutate(Sub* v) override;
// Multiply two terms together, usually creating a new term with the variable
// lists concatenated.
const Term* mulTerms(const Term* lhs, const Term* rhs);
Term* mulTerms(Term* lhs, Term* rhs);
// Multiply a Polynomial by a Term.
const Expr* polyByTerm(const Polynomial* poly, const Term* term);
Expr* polyByTerm(Polynomial* poly, Term* term);
// Match a rounding pattern and create a RoundOff if found.
const Expr* isRoundOff(const Expr* lhs, const Expr* rhs);
Expr* isRoundOff(Expr* lhs, Expr* rhs);
// Inserts a new component into a term, simplifying if possible.
const Expr* insertIntoTerm(const Term* term, const Expr* expr);
Expr* insertIntoTerm(Term* term, Expr* expr);
// Merge and simplify multiplication.
const Expr* mutate(const Mul* v) override;
Expr* mutate(Mul* v) override;
const Expr* mutate(const Div* v) override;
Expr* mutate(Div* v) override;
const Expr* mutate(const Mod* v) override;
Expr* mutate(Mod* v) override;
const Expr* mutate(const And* v) override {
Expr* mutate(And* v) override {
return mutateBinaryOp(v, this);
}
const Expr* mutate(const Xor* v) override {
Expr* mutate(Xor* v) override {
return mutateBinaryOp(v, this);
}
const Expr* mutate(const Lshift* v) override {
Expr* mutate(Lshift* v) override {
return mutateBinaryOp(v, this);
}
const Expr* mutate(const Rshift* v) override {
Expr* mutate(Rshift* v) override {
return mutateBinaryOp(v, this);
}
const Expr* mutate(const Max* v) override;
Expr* mutate(Max* v) override;
const Expr* mutate(const Min* v) override;
Expr* mutate(Min* v) override;
const Expr* mutate(const CompareSelect* v) override;
Expr* mutate(CompareSelect* v) override;
const Expr* mutate(const Intrinsics* v) override;
Expr* mutate(Intrinsics* v) override;
const Expr* mutate(const Cast* v) override;
Expr* mutate(Cast* v) override;
const Expr* mutate(const IfThenElse* v) override;
Expr* mutate(IfThenElse* v) override;
template <typename Op>
static const Expr* mutateBinaryOp(
const BinaryOpNode<Op>* v,
static Expr* mutateBinaryOp(
BinaryOpNode<Op>* v,
IRMutator* mutator,
bool option = false) {
const Expr* lhs = v->lhs();
const Expr* rhs = v->rhs();
const Expr* lhs_new = lhs->accept_mutator(mutator);
const Expr* rhs_new = rhs->accept_mutator(mutator);
Expr* lhs = v->lhs();
Expr* rhs = v->rhs();
Expr* lhs_new = lhs->accept_mutator(mutator);
Expr* rhs_new = rhs->accept_mutator(mutator);
const Expr* node = v;
Expr* node = v;
if (lhs != lhs_new || rhs != rhs_new) {
node = newBinaryOpOfType(v->expr_type(), lhs_new, rhs_new, option);
@ -559,7 +549,7 @@ class TORCH_API PolynomialTransformer : public PolynomialBase {
return evaluateOp(node);
}
static const Expr* simplify(const Expr* e);
static Expr* simplify(Expr* e);
static ExprHandle simplify(const ExprHandle& e);
static Stmt* simplify(Stmt* e);
};
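
To make the Term/Polynomial plumbing above concrete: sums are canonicalized into a Polynomial whose Terms are keyed by a hash of their variable lists, so like terms merge. A hedged sketch in the style of the tests elsewhere in this diff (the test name is hypothetical, and the expected result is an assumption about the simplifier's output form):

TEST(Simplify, CombineLikeTermsSketch) {
  KernelScope kernel_scope;
  Var* x = new Var("x", kInt);
  // (2*x + 3) + (4*x + 1): both products hash to the same variable set,
  // so addPolynomials folds them into a single term.
  Expr* e = new Add(
      new Add(new Mul(new IntImm(2), x), new IntImm(3)),
      new Add(new Mul(new IntImm(4), x), new IntImm(1)));
  Expr* simplified = IRSimplifier::simplify(e); // expected: 6 * x + 4
}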
@ -568,7 +558,7 @@ class TORCH_API PolynomialTransformer : public PolynomialBase {
// Does some simple factorization and reordering.
class TORCH_API TermExpander : public PolynomialBase {
PolynomialTransformer* simplifier_;
std::set<const Var*> eliminated_allocations_;
std::set<Var*> eliminated_allocations_;
public:
using PolynomialBase::mutate;
@ -579,33 +569,33 @@ class TORCH_API TermExpander : public PolynomialBase {
}
// Expand Terms out to a series of Muls.
const Expr* mutate(const Term* v) override;
Expr* mutate(Term* v) override;
// Expand Polynomials out to a series of Adds.
const Expr* mutate(const Polynomial* v) override;
Expr* mutate(Polynomial* v) override;
// Expand MaxTerms to a series of Max ops.
const Expr* mutate(const MaxTerm* v) override;
Expr* mutate(MaxTerm* v) override;
// Expand MinTerms to a series of Min ops.
const Expr* mutate(const MinTerm* v) override;
Expr* mutate(MinTerm* v) override;
// Expand RoundOff to its component: Mul(Div(lhs, rhs), rhs).
const Expr* mutate(const RoundOff* v) override;
Expr* mutate(RoundOff* v) override;
// Eliminate zero length allocations.
Stmt* mutate(const Allocate* v) override;
Stmt* mutate(const Free* v) override;
Stmt* mutate(Allocate* v) override;
Stmt* mutate(Free* v) override;
// Override to enable condition fusing.
Block* fuseConditions(Block* v);
Stmt* fuseSyncThreads(Block* block);
Stmt* mutate(const Block* v) override;
Stmt* mutate(Block* v) override;
};
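
The RoundOff expansion noted above is just the round-down-to-a-multiple identity. A stand-alone check of that identity (illustrative, plain C++):

#include <cassert>

int main() {
  const int d = 4;
  for (int x = 0; x < 32; ++x) {
    // RoundOff(x, d) stands for Mul(Div(x, d), d): x rounded down to a
    // multiple of d under integer division.
    assert((x / d) * d == x - x % d);
  }
  return 0;
}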
class TORCH_API IRSimplifier {
public:
static const Expr* simplify(const Expr* e) {
static Expr* simplify(Expr* e) {
SimplifierUnderContext ctxsimplifier;
e = e->accept_mutator(&ctxsimplifier);
@ -649,9 +639,9 @@ class TORCH_API IRSimplifier {
};
// Flattens the buf and runs the simplifier on the flattened dims.
const Expr* buf_flat_size(const Buf* v);
Expr* buf_flat_size(Buf* v);
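
Assuming the usual semantics (the flat size is the product of the dims, run through the simplifier), constant dims collapse to a single immediate. A hypothetical fragment, mirroring the Buf construction style used in the tests:

KernelScope kernel_scope;
std::vector<Expr*> dims = {new IntImm(2), new IntImm(3), new IntImm(5)};
Buf* b = new Buf("t", dims, kFloat);
Expr* flat = buf_flat_size(b); // expected to simplify to IntImm(30)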
// Returns true if expressions A and B can be simplified to an equal expression.
TORCH_API bool exprEquals(const Expr* A, const Expr* B);
TORCH_API bool exprEquals(Expr* A, Expr* B);
} // namespace tensorexpr
} // namespace jit


@ -19,39 +19,39 @@ void verifyBitwiseOp(const BitwiseOpNode<Op>* v, IRVerifier* verifier) {
}
}
void IRVerifier::visit(const And* v) {
void IRVerifier::visit(And* v) {
verifyBitwiseOp(v, this);
IRVisitor::visit(v);
}
void IRVerifier::visit(const Or* v) {
void IRVerifier::visit(Or* v) {
verifyBitwiseOp(v, this);
IRVisitor::visit(v);
}
void IRVerifier::visit(const Xor* v) {
void IRVerifier::visit(Xor* v) {
verifyBitwiseOp(v, this);
IRVisitor::visit(v);
}
void IRVerifier::visit(const Lshift* v) {
void IRVerifier::visit(Lshift* v) {
verifyBitwiseOp(v, this);
IRVisitor::visit(v);
}
void IRVerifier::visit(const Rshift* v) {
void IRVerifier::visit(Rshift* v) {
verifyBitwiseOp(v, this);
IRVisitor::visit(v);
}
void IRVerifier::visit(const Mod* v) {
void IRVerifier::visit(Mod* v) {
if (!v->dtype().is_integral() && !v->dtype().is_floating_point()) {
throw std::runtime_error("invalid dtype: " + std::to_string(v->dtype()));
}
IRVisitor::visit(v);
}
void IRVerifier::visit(const CompareSelect* v) {
void IRVerifier::visit(CompareSelect* v) {
if (v->ret_val1()->dtype() != v->ret_val2()->dtype()) {
throw malformed_ir("bad dtype in CompareSelect");
}
@ -61,15 +61,15 @@ void IRVerifier::visit(const CompareSelect* v) {
IRVisitor::visit(v);
}
void IRVerifier::visit(const Ramp* v) {
void IRVerifier::visit(Ramp* v) {
if (v->stride()->dtype() != v->base()->dtype()) {
throw malformed_ir("Bad stride in Ramp");
}
IRVisitor::visit(v);
}
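
For context on the check above: Ramp(base, stride, lanes) denotes the vector base, base + stride, ..., base + (lanes - 1) * stride, which is why base and stride must agree on dtype. A scalar model (illustrative only):

#include <vector>

std::vector<int> ramp_values(int base, int stride, int lanes) {
  std::vector<int> v(static_cast<size_t>(lanes));
  for (int i = 0; i < lanes; ++i) {
    // Lane i of Ramp(base, stride, lanes).
    v[static_cast<size_t>(i)] = base + i * stride;
  }
  return v;
}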
void IRVerifier::visit(const Load* v) {
const auto indices = v->indices();
void IRVerifier::visit(Load* v) {
auto indices = v->indices();
if (indices.size() > 0 && v->buf()->base_handle()->dtype() != kHandle) {
throw malformed_ir(
"Load base handle dtype must be Handle", v->buf()->base_handle());
@ -94,7 +94,7 @@ void IRVerifier::visit(const Load* v) {
IRVisitor::visit(v);
}
void IRVerifier::visit(const IfThenElse* v) {
void IRVerifier::visit(IfThenElse* v) {
if (!v->condition()->dtype().is_integral()) {
throw unsupported_dtype();
}
@ -107,13 +107,13 @@ void IRVerifier::visit(const IfThenElse* v) {
IRVisitor::visit(v);
}
void IRVerifier::visit(const Intrinsics* v) {
void IRVerifier::visit(Intrinsics* v) {
// TODO: add a check for OpArgCount and op_type
IRVisitor::visit(v);
}
void IRVerifier::visit(const Store* v) {
const auto indices = v->indices();
void IRVerifier::visit(Store* v) {
auto indices = v->indices();
if (indices.size() > 0 && v->buf()->base_handle()->dtype() != kHandle) {
throw malformed_ir(
"Store base handle dtype must be Handle", v->buf()->base_handle());
@ -141,7 +141,7 @@ void IRVerifier::visit(const Store* v) {
IRVisitor::visit(v);
}
void IRVerifier::visit(const For* v) {
void IRVerifier::visit(For* v) {
if (!v->var()) {
throw malformed_ir("nullptr Var in For loop");
} else if (!v->start()) {
@ -154,7 +154,7 @@ void IRVerifier::visit(const For* v) {
IRVisitor::visit(v);
}
void IRVerifier::visit(const Block* v) {
void IRVerifier::visit(Block* v) {
for (Stmt* s : v->stmts()) {
if (s->get_parent() != v) {
throw malformed_ir("Broken child-parent link inside a Block");
@ -163,7 +163,7 @@ void IRVerifier::visit(const Block* v) {
IRVisitor::visit(v);
}
void IRVerifier::visit(const ExternalCall* v) {
void IRVerifier::visit(ExternalCall* v) {
IRVisitor::visit(v);
}
@ -172,7 +172,7 @@ void verify(Stmt* s) {
s->accept(&verifier);
}
void verify(const Expr* e) {
void verify(Expr* e) {
IRVerifier verifier;
e->accept(&verifier);
}


@ -32,26 +32,26 @@ class TORCH_API IRVerifier : public IRVisitor {
public:
IRVerifier() = default;
void visit(const Mod* v) override;
void visit(const And* v) override;
void visit(const Or* v) override;
void visit(const Xor* v) override;
void visit(const Lshift* v) override;
void visit(const Rshift* v) override;
void visit(const CompareSelect* v) override;
void visit(const Ramp* v) override;
void visit(const Load* v) override;
void visit(const IfThenElse* v) override;
void visit(const Intrinsics* v) override;
void visit(Mod* v) override;
void visit(And* v) override;
void visit(Or* v) override;
void visit(Xor* v) override;
void visit(Lshift* v) override;
void visit(Rshift* v) override;
void visit(CompareSelect* v) override;
void visit(Ramp* v) override;
void visit(Load* v) override;
void visit(IfThenElse* v) override;
void visit(Intrinsics* v) override;
void visit(const ExternalCall* v) override;
void visit(const Store* v) override;
void visit(const For* v) override;
void visit(const Block* v) override;
void visit(ExternalCall* v) override;
void visit(Store* v) override;
void visit(For* v) override;
void visit(Block* v) override;
};
TORCH_API void verify(Stmt*);
TORCH_API void verify(const Expr*);
TORCH_API void verify(Expr*);
TORCH_API void verify(ExprHandle);
} // namespace tensorexpr


@ -12,60 +12,60 @@ namespace jit {
namespace tensorexpr {
template <typename Op>
static void visit_binary_op(const BinaryOpNode<Op>* v, IRVisitor* visitor) {
static void visit_binary_op(BinaryOpNode<Op>* v, IRVisitor* visitor) {
v->lhs()->accept(visitor);
v->rhs()->accept(visitor);
}
void IRVisitor::visit(const Add* v) {
void IRVisitor::visit(Add* v) {
visit_binary_op(v, this);
}
void IRVisitor::visit(const Sub* v) {
void IRVisitor::visit(Sub* v) {
visit_binary_op(v, this);
}
void IRVisitor::visit(const Mul* v) {
void IRVisitor::visit(Mul* v) {
visit_binary_op(v, this);
}
void IRVisitor::visit(const Div* v) {
void IRVisitor::visit(Div* v) {
visit_binary_op(v, this);
}
void IRVisitor::visit(const Mod* v) {
void IRVisitor::visit(Mod* v) {
visit_binary_op(v, this);
}
void IRVisitor::visit(const Max* v) {
void IRVisitor::visit(Max* v) {
visit_binary_op(v, this);
}
void IRVisitor::visit(const Min* v) {
void IRVisitor::visit(Min* v) {
visit_binary_op(v, this);
}
void IRVisitor::visit(const And* v) {
void IRVisitor::visit(And* v) {
visit_binary_op(v, this);
}
void IRVisitor::visit(const Or* v) {
void IRVisitor::visit(Or* v) {
visit_binary_op(v, this);
}
void IRVisitor::visit(const Xor* v) {
void IRVisitor::visit(Xor* v) {
visit_binary_op(v, this);
}
void IRVisitor::visit(const Lshift* v) {
void IRVisitor::visit(Lshift* v) {
visit_binary_op(v, this);
}
void IRVisitor::visit(const Rshift* v) {
void IRVisitor::visit(Rshift* v) {
visit_binary_op(v, this);
}
void IRVisitor::visit(const CompareSelect* v) {
void IRVisitor::visit(CompareSelect* v) {
v->lhs()->accept(this);
v->rhs()->accept(this);
v->ret_val1()->accept(this);
@ -78,65 +78,65 @@ void IRVisitor::visit(const CompareSelect* v) {
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_VISIT);
#undef IMM_VISIT
void IRVisitor::visit(const Cast* v) {
void IRVisitor::visit(Cast* v) {
v->src_value()->accept(this);
}
void IRVisitor::visit(const BitCast* v) {
void IRVisitor::visit(BitCast* v) {
v->src_value()->accept(this);
}
void IRVisitor::visit(const Var* v) {}
void IRVisitor::visit(Var* v) {}
void IRVisitor::visit(const Ramp* v) {
void IRVisitor::visit(Ramp* v) {
v->base()->accept(this);
v->stride()->accept(this);
}
void IRVisitor::visit(const Load* v) {
void IRVisitor::visit(Load* v) {
v->buf()->accept(this);
for (const Expr* ind : v->indices()) {
for (Expr* ind : v->indices()) {
ind->accept(this);
}
}
void IRVisitor::visit(const Buf* v) {
void IRVisitor::visit(Buf* v) {
v->base_handle()->accept(this);
}
void IRVisitor::visit(const Store* v) {
void IRVisitor::visit(Store* v) {
v->buf()->accept(this);
for (const Expr* ind : v->indices()) {
for (Expr* ind : v->indices()) {
ind->accept(this);
}
v->value()->accept(this);
}
void IRVisitor::visit(const AtomicAdd* v) {
void IRVisitor::visit(AtomicAdd* v) {
v->buf()->accept(this);
for (const Expr* ind : v->indices()) {
for (Expr* ind : v->indices()) {
ind->accept(this);
}
v->value()->accept(this);
}
void IRVisitor::visit(const SyncThreads* v) {}
void IRVisitor::visit(SyncThreads* v) {}
void IRVisitor::visit(const ExternalCall* v) {
void IRVisitor::visit(ExternalCall* v) {
v->buf()->accept(this);
for (const Buf* buf_arg : v->buf_args()) {
for (Buf* buf_arg : v->buf_args()) {
buf_arg->accept(this);
}
for (const Expr* arg : v->args()) {
for (Expr* arg : v->args()) {
arg->accept(this);
}
}
void IRVisitor::visit(const Block* v) {
void IRVisitor::visit(Block* v) {
for (Stmt* s : *v) {
s->accept(this);
}
}
void IRVisitor::visit(const For* v) {
void IRVisitor::visit(For* v) {
v->var()->accept(this);
v->start()->accept(this);
v->stop()->accept(this);
@ -145,41 +145,41 @@ void IRVisitor::visit(const For* v) {
}
}
void IRVisitor::visit(const Broadcast* v) {
void IRVisitor::visit(Broadcast* v) {
v->value()->accept(this);
}
void IRVisitor::visit(const IfThenElse* v) {
void IRVisitor::visit(IfThenElse* v) {
v->condition()->accept(this);
v->true_value()->accept(this);
v->false_value()->accept(this);
}
void IRVisitor::visit(const Intrinsics* v) {
for (const auto i : c10::irange(v->nparams())) {
void IRVisitor::visit(Intrinsics* v) {
for (auto i : c10::irange(v->nparams())) {
v->param(i)->accept(this);
}
}
void IRVisitor::visit(const Allocate* v) {
void IRVisitor::visit(Allocate* v) {
v->buffer_var()->accept(this);
std::vector<const Expr*> dims = v->dims();
for (const Expr* dim : dims) {
std::vector<Expr*> dims = v->dims();
for (Expr* dim : dims) {
dim->accept(this);
}
}
void IRVisitor::visit(const Free* v) {
void IRVisitor::visit(Free* v) {
v->buffer_var()->accept(this);
}
void IRVisitor::visit(const Let* v) {
void IRVisitor::visit(Let* v) {
v->var()->accept(this);
v->value()->accept(this);
}
void IRVisitor::visit(const Cond* v) {
const Expr* condition = v->condition();
void IRVisitor::visit(Cond* v) {
Expr* condition = v->condition();
Stmt* true_stmt = v->true_stmt();
Stmt* false_stmt = v->false_stmt();
condition->accept(this);
@ -191,26 +191,26 @@ void IRVisitor::visit(const Cond* v) {
}
}
void IRVisitor::visit(const Term* v) {
void IRVisitor::visit(Term* v) {
v->scalar()->accept(this);
for (auto* t : v->variables()) {
t->accept(this);
}
}
void IRVisitor::visit(const Polynomial* v) {
void IRVisitor::visit(Polynomial* v) {
v->scalar()->accept(this);
for (auto* t : v->variables()) {
t->accept(this);
}
}
void IRVisitor::visit(const RoundOff* v) {
void IRVisitor::visit(RoundOff* v) {
v->lhs()->accept(this);
v->rhs()->accept(this);
}
void IRVisitor::visit(const MaxTerm* v) {
void IRVisitor::visit(MaxTerm* v) {
if (v->scalar()) {
v->scalar()->accept(this);
}
@ -219,7 +219,7 @@ void IRVisitor::visit(const MaxTerm* v) {
}
}
void IRVisitor::visit(const MinTerm* v) {
void IRVisitor::visit(MinTerm* v) {
if (v->scalar()) {
v->scalar()->accept(this);
}
@ -228,7 +228,7 @@ void IRVisitor::visit(const MinTerm* v) {
}
}
void IRVisitor::visit(const ReduceOp* v) {
void IRVisitor::visit(ReduceOp* v) {
v->body()->accept(this);
for (auto* r : v->reduce_args()) {


@ -54,50 +54,50 @@ class ExternalCall;
class TORCH_API IRVisitor {
public:
virtual ~IRVisitor() = default;
virtual void visit(const Add* v);
virtual void visit(const Sub* v);
virtual void visit(const Mul* v);
virtual void visit(const Div* v);
virtual void visit(const Mod* v);
virtual void visit(const Max* v);
virtual void visit(const Min* v);
virtual void visit(const And* v);
virtual void visit(const Or* v);
virtual void visit(const Xor* v);
virtual void visit(const Lshift* v);
virtual void visit(const Rshift* v);
virtual void visit(const CompareSelect* v);
virtual void visit(Add* v);
virtual void visit(Sub* v);
virtual void visit(Mul* v);
virtual void visit(Div* v);
virtual void visit(Mod* v);
virtual void visit(Max* v);
virtual void visit(Min* v);
virtual void visit(And* v);
virtual void visit(Or* v);
virtual void visit(Xor* v);
virtual void visit(Lshift* v);
virtual void visit(Rshift* v);
virtual void visit(CompareSelect* v);
#define IMM_PRINT_VISIT(Type, Name) virtual void visit(const Name##Imm* v);
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_PRINT_VISIT)
#undef IMM_PRINT_VISIT
virtual void visit(const Cast* v);
virtual void visit(const BitCast* v);
virtual void visit(const Var* v);
virtual void visit(const Buf* v);
virtual void visit(const Ramp* v);
virtual void visit(const Load* v);
virtual void visit(const For* v);
virtual void visit(const Block* v);
virtual void visit(const Store* v);
virtual void visit(const Broadcast* v);
virtual void visit(const IfThenElse* v);
virtual void visit(const Intrinsics* v);
virtual void visit(const Allocate* v);
virtual void visit(const Free* v);
virtual void visit(const Let* v);
virtual void visit(const Cond* v);
virtual void visit(const Term* v);
virtual void visit(const Polynomial* v);
virtual void visit(const RoundOff* v);
virtual void visit(const MaxTerm* v);
virtual void visit(const MinTerm* v);
virtual void visit(const ReduceOp* v);
virtual void visit(const AtomicAdd* v);
virtual void visit(const SyncThreads* v);
virtual void visit(const ExternalCall* v);
virtual void visit(Cast* v);
virtual void visit(BitCast* v);
virtual void visit(Var* v);
virtual void visit(Buf* v);
virtual void visit(Ramp* v);
virtual void visit(Load* v);
virtual void visit(For* v);
virtual void visit(Block* v);
virtual void visit(Store* v);
virtual void visit(Broadcast* v);
virtual void visit(IfThenElse* v);
virtual void visit(Intrinsics* v);
virtual void visit(Allocate* v);
virtual void visit(Free* v);
virtual void visit(Let* v);
virtual void visit(Cond* v);
virtual void visit(Term* v);
virtual void visit(Polynomial* v);
virtual void visit(RoundOff* v);
virtual void visit(MaxTerm* v);
virtual void visit(MinTerm* v);
virtual void visit(ReduceOp* v);
virtual void visit(AtomicAdd* v);
virtual void visit(SyncThreads* v);
virtual void visit(ExternalCall* v);
};
} // namespace tensorexpr


@ -201,7 +201,7 @@ c10::optional<TensorInfo> getTensorInfoJit(torch::jit::Value* v) {
c10::optional<TensorInfo> getTensorInfo(BufHandle b) {
std::vector<int64_t> dims;
for (auto dim : b.dims()) {
auto val = dynamic_cast<const IntImm*>(dim.node());
auto val = dynamic_cast<IntImm*>(dim.node());
if (!val) {
return c10::nullopt;
}
@ -460,7 +460,7 @@ void promoteInputs(std::vector<ExprHandle>& inputs, const int typeConstraints) {
// Find the highest type among the inputs.
ScalarType highType = inputs[0].dtype().scalar_type();
for (const auto input : inputs) {
for (auto input : inputs) {
highType = promoteTypes(highType, input.dtype().scalar_type());
}
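
The promotion above folds pairwise over all inputs, so e.g. mixing Int and Float operands promotes everything to Float. A small stand-alone analogue (assuming c10::promoteTypes from c10/core/ScalarType.h, which this code already relies on):

#include <initializer_list>

#include <c10/core/ScalarType.h>

c10::ScalarType highestType(std::initializer_list<c10::ScalarType> types) {
  c10::ScalarType high = *types.begin();
  for (auto t : types) {
    high = c10::promoteTypes(high, t); // e.g. (Int, Float) -> Float
  }
  return high;
}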
@ -503,20 +503,20 @@ ExprHandle demoteOutput(
} // namespace jit
} // namespace torch
static at::ScalarType tensorType(const Buf* b) {
static at::ScalarType tensorType(Buf* b) {
return static_cast<at::ScalarType>(b->dtype().scalar_type());
}
std::vector<int64_t> bufferSizes(const Buf* b) {
std::vector<int64_t> bufferSizes(Buf* b) {
std::vector<int64_t> sizes;
for (size_t i = 0; i < b->ndim(); i++) {
sizes.push_back(dynamic_cast<const IntImm*>(b->dim(i))->value());
sizes.push_back(dynamic_cast<IntImm*>(b->dim(i))->value());
}
return sizes;
}
ExprHandle TensorExprKernel::chunk(
const Buf* b,
Buf* b,
size_t chunkIdx,
int64_t dim,
int64_t chunks,
@ -539,7 +539,7 @@ ExprHandle TensorExprKernel::chunk(
ExprHandle TensorExprKernel::constant(const torch::jit::Value* v) {
if (v->node()->kind() == prim::Constant) {
const auto val = toIValue(v).value();
auto val = toIValue(v).value();
if (val.isDouble()) {
return DoubleImm::make(val.toDouble());
} else if (val.isInt()) {
@ -598,7 +598,7 @@ ArgValue TensorExprKernel::toArg(const torch::jit::Value* v) const {
throw unsupported_dtype();
}
if (v->node()->kind() == prim::Constant) {
const auto val = toIValue(v).value();
auto val = toIValue(v).value();
if (val.isDouble()) {
return val.toDouble();
} else if (val.isInt()) {
@ -626,7 +626,7 @@ ArgValue TensorExprKernel::toArg(const torch::jit::Value* v) const {
std::vector<ExprHandle> TensorExprKernel::sizesFromVaryingShape(
const c10::VaryingShape<int64_t>& shape) {
std::vector<ExprHandle> dims;
for (const auto i : c10::irange(*shape.size())) {
for (auto i : c10::irange(*shape.size())) {
dims.push_back(IntImm::make(*shape[i]));
}
return dims;
@ -728,7 +728,7 @@ std::vector<ExprHandle> TensorExprKernel::inferSizesForValue(
case aten::remainder:
case aten::atan2: {
std::vector<std::vector<ExprHandle>> shapes;
for (const auto idx : c10::irange(2)) {
for (auto idx : c10::irange(2)) {
torch::jit::Value* inp = v->node()->input(idx);
shapes.push_back(sizesForValue(inp));
}
@ -739,7 +739,7 @@ std::vector<ExprHandle> TensorExprKernel::inferSizesForValue(
case aten::threshold:
case aten::where: {
std::vector<std::vector<ExprHandle>> shapes;
for (const auto idx : c10::irange(3)) {
for (auto idx : c10::irange(3)) {
torch::jit::Value* inp = v->node()->input(idx);
shapes.push_back(sizesForValue(inp));
}
@ -748,7 +748,7 @@ std::vector<ExprHandle> TensorExprKernel::inferSizesForValue(
case aten::addcmul: {
std::vector<std::vector<ExprHandle>> shapes;
for (const auto idx : c10::irange(4)) {
for (auto idx : c10::irange(4)) {
torch::jit::Value* inp = v->node()->input(idx);
shapes.push_back(sizesForValue(inp));
}
@ -1129,7 +1129,7 @@ std::pair<ScalarType, std::vector<BufHandle>> processCatList(
nonEmptyInputs.push_back(buf);
}
ScalarType highType = bufInputs[0].dtype().scalar_type();
for (const auto input : bufInputs) {
for (auto input : bufInputs) {
auto maybe_dtype = input.dtype().scalar_type();
highType = promoteTypes(highType, maybe_dtype);
}
@ -1172,11 +1172,11 @@ Tensor* computeCatWoConditionals(
auto gen_code_for_input = [&](const BufHandle& inp,
size_t inp_pos,
const Expr* concat_dim_size,
Expr* concat_dim_size,
const std::vector<ExprHandle>& dims) {
std::vector<Var*> for_vars(dims.size());
std::vector<const Expr*> load_indices(dims.size());
std::vector<const Expr*> store_indices(dims.size());
std::vector<Expr*> load_indices(dims.size());
std::vector<Expr*> store_indices(dims.size());
for (size_t i = 0; i < dims.size(); ++i) {
for_vars[i] = new Var(
"i" + c10::to_string(inp_pos) + "_" + c10::to_string(i), kInt);
@ -1256,8 +1256,7 @@ Tensor* computeCat(
ExprHandle load = promoteToDtype(
tensorOrConstant(nonEmptyInputs[0], newAxes), highType);
size_t offset =
dynamic_cast<const IntImm*>(nonEmptyInputs[0].node()->dim(dim))
->value();
dynamic_cast<IntImm*>(nonEmptyInputs[0].node()->dim(dim))->value();
newAxes[dim] = newAxes[dim] - IntImm::make(offset);
for (size_t ii = 1; ii < nonEmptyInputs.size(); ++ii) {
@ -1267,8 +1266,7 @@ Tensor* computeCat(
load,
promoteToDtype(tensorOrConstant(input, newAxes), highType));
offset +=
dynamic_cast<const IntImm*>(input.node()->dim(dim))->value();
offset += dynamic_cast<IntImm*>(input.node()->dim(dim))->value();
newAxes[dim] = axes[dim] - IntImm::make(offset);
}
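
The offset bookkeeping above is the core of cat: each non-empty input owns a contiguous band of the concat dimension, and the load index is the store index minus the running offset. A self-contained model of that mapping (names hypothetical):

#include <cstdint>
#include <utility>
#include <vector>

// Maps an output index along the concat dim to (input number, local index).
std::pair<size_t, int64_t> catSource(
    const std::vector<int64_t>& dimSizes,
    int64_t outIdx) {
  int64_t offset = 0;
  for (size_t k = 0; k < dimSizes.size(); ++k) {
    if (outIdx < offset + dimSizes[k]) {
      return {k, outIdx - offset};
    }
    offset += dimSizes[k];
  }
  return {dimSizes.size(), -1}; // out of range
}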
@ -2317,7 +2315,7 @@ Tensor* tensorexpr::computeOperandValue(
*/
// NOLINTNEXTLINE(clang-diagnostic-unused-variable)
ExprHandle cur_stride = 1;
std::vector<const Expr*> dims, indices;
std::vector<Expr*> dims, indices;
for (size_t idx = 0; idx < view_dims.size(); idx++) {
dims.push_back(new IntImm(view_dims[idx]));
indices.push_back(axes[idx].node());
@ -2431,7 +2429,7 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) {
}
// Return the (lower, upper) loop bounds if they are constants, else nullopt.
c10::optional<std::pair<int64_t, int64_t>> loopBounds(const For* loop) {
c10::optional<std::pair<int64_t, int64_t>> loopBounds(For* loop) {
auto start = IRSimplifier::simplify(loop->start());
auto stop = IRSimplifier::simplify(loop->stop());
if (!start->isConstant() || !stop->isConstant()) {
@ -2803,7 +2801,7 @@ bool denseAndNonOverlapping(
Tensor* TensorExprKernel::convertOutputToCorrectStrides(torch::jit::Value* v) {
const TensorTypePtr& tt = v->type()->expect<TensorType>();
TORCH_INTERNAL_ASSERT(bufs_.count(v));
const Buf* buf = bufs_.at(v);
Buf* buf = bufs_.at(v);
// No shape info is present in the graph
if (!tt->sizes().concrete_sizes()) {
@ -2813,7 +2811,7 @@ Tensor* TensorExprKernel::convertOutputToCorrectStrides(torch::jit::Value* v) {
}
TORCH_INTERNAL_ASSERT(tt->sizes().concrete_sizes());
const auto sizes = *tt->sizes().concrete_sizes();
auto sizes = *tt->sizes().concrete_sizes();
std::vector<int64_t> default_strides = TensorType::contiguousStridesOf(sizes);
if (!tt->strides().concrete_sizes()) {
return new Tensor(buf, nullptr);
@ -2887,14 +2885,14 @@ void TensorExprKernel::bindConstant(const torch::jit::Value* v) {
auto const_tensor = toIValue(v)->toTensor();
const auto& tt = v->type()->expect<TensorType>();
const auto sizes = *tt->sizes().concrete_sizes();
auto sizes = *tt->sizes().concrete_sizes();
std::vector<ExprHandle> te_sizes;
te_sizes.reserve(sizes.size());
for (auto s : sizes) {
te_sizes.push_back(IntImm::make(s));
}
const Buf* buf = new Buf(
Buf* buf = new Buf(
"const_" + v->debugName(),
ExprHandleVectorToExprVector(te_sizes),
ToDtype(static_cast<ScalarType>(*tt->scalarType())));
@ -2951,7 +2949,7 @@ void TensorExprKernel::compile() {
}
// Move output operands from `bufs_` to `bufOutputs_`
for (const auto& output : graph_->outputs()) {
for (auto& output : graph_->outputs()) {
if (!bufs_.count(output)) {
throw malformed_input("cannot find output Tensor");
}
@ -3046,7 +3044,7 @@ std::vector<CodeGen::CallArg> TensorExprKernel::prepareRunArgs(
std::vector<CodeGen::CallArg> runArgs;
runArgs.reserve(inputs.size() + bufOutputs_.size());
for (const auto& input : inputs) {
for (auto& input : inputs) {
if (input.isInt()) {
runArgs.emplace_back(input.toInt());
} else if (input.isDouble()) {

View File

@ -19,7 +19,7 @@ template <typename T>
inline std::vector<int64_t> bufferSizes(const T& t) {
std::vector<int64_t> sizes;
for (size_t i = 0; i < t->ndim(); i++) {
sizes.push_back(dynamic_cast<const IntImm*>(t->dim(i))->value());
sizes.push_back(dynamic_cast<IntImm*>(t->dim(i))->value());
}
return sizes;
}
@ -105,7 +105,7 @@ inline std::string getArgValueName(const ArgValue& a) {
template <class T>
std::vector<T> convertVecArgValue(const std::vector<ArgValue>& v) {
std::vector<T> res;
for (const auto& x : v) {
for (auto& x : v) {
auto val = c10::get_if<T>(&x);
if (val) {
res.push_back(*val);
@ -132,7 +132,7 @@ TORCH_API Tensor* computeOperandValue(
class TORCH_API TensorExprKernel {
struct ConstantDescr {
const Buf* buf;
Buf* buf;
void* ptr;
};
@ -196,7 +196,7 @@ class TORCH_API TensorExprKernel {
std::vector<std::vector<ExprHandle>> shapes);
ExprHandle chunk(
const Buf* b,
Buf* b,
size_t chunkIdx,
int64_t dim,
int64_t chunks,
@ -260,8 +260,8 @@ class TORCH_API TensorExprKernel {
std::vector<std::vector<int64_t>> tensorOutputSizes_;
std::vector<std::vector<int64_t>> tensorOutputStrides_;
std::vector<UnpackedTensorOptions> tensorOutputTensorOptions_;
std::unordered_set<const Buf*> bufOutputs_;
std::unordered_map<const torch::jit::Value*, const Buf*> bufs_;
std::unordered_set<Buf*> bufOutputs_;
std::unordered_map<const torch::jit::Value*, Buf*> bufs_;
std::unordered_map<const torch::jit::Value*, VarHandle> scalars_;
std::unordered_map<const torch::jit::Value*, std::string> input_name_map_;
std::unique_ptr<CodeGen> codegen_;


@ -169,8 +169,8 @@ class LLVMCodeGenImpl : public IRVisitor {
std::unordered_map<const Var*, int> varToArg_;
std::unordered_map<const Var*, llvm::Value*> varToVal_;
std::unordered_map<const Block*, std::vector<const Var*>> scopeToVar_;
const Block* scope_;
std::unordered_map<Block*, std::vector<Var*>> scopeToVar_;
Block* scope_;
std::string llvmCode_;
std::string asmCode_;
@ -195,13 +195,13 @@ class LLVMCodeGenImpl : public IRVisitor {
Arity arity,
int lanes);
llvm::Value* varToValue(const Var* var);
llvm::Value* varToValue(Var* var);
void replaceVarMapping(
const std::vector<const Var*>& vars,
const std::vector<Var*>& vars,
const std::vector<llvm::Value*>& vals);
llvm::Value* packFuncArgs(const std::vector<llvm::Value*>& func_args);
std::vector<llvm::Value*> unpackFuncArgs(llvm::Value* packed, int arg_count);
void processParallelFor(const For* v);
void processParallelFor(For* v);
public:
LLVMCodeGenImpl(
@ -216,42 +216,42 @@ class LLVMCodeGenImpl : public IRVisitor {
llvm::JITTargetAddress getKernelAddress() const;
void visit(const Add* v) override;
void visit(const Sub* v) override;
void visit(const Mul* v) override;
void visit(const Div* v) override;
void visit(const Mod* v) override;
void visit(const Max* v) override;
void visit(const Min* v) override;
void visit(const And* v) override;
void visit(const Or* v) override;
void visit(const Xor* v) override;
void visit(const Lshift* v) override;
void visit(const Rshift* v) override;
void visit(const CompareSelect* v) override;
void visit(Add* v) override;
void visit(Sub* v) override;
void visit(Mul* v) override;
void visit(Div* v) override;
void visit(Mod* v) override;
void visit(Max* v) override;
void visit(Min* v) override;
void visit(And* v) override;
void visit(Or* v) override;
void visit(Xor* v) override;
void visit(Lshift* v) override;
void visit(Rshift* v) override;
void visit(CompareSelect* v) override;
#define IMM_VISIT_DECLARE(_1, Name) void visit(const Name##Imm* v) override;
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_VISIT_DECLARE);
#undef IMM_VISIT_DECLARE
void visit(const Cast* v) override;
void visit(const BitCast* v) override;
void visit(const Var* v) override;
void visit(const Ramp* v) override;
void visit(const Load* v) override;
void visit(const For* v) override;
void visit(const Block* v) override;
void visit(const Store* v) override;
void visit(const Broadcast* v) override;
void visit(const IfThenElse* v) override;
void visit(const Intrinsics* v) override;
void visit(const Allocate* v) override;
void visit(const Free* v) override;
void visit(const Let* v) override;
void visit(const Cond* v) override;
void visit(const ExternalCall* v) override;
void visit(Cast* v) override;
void visit(BitCast* v) override;
void visit(Var* v) override;
void visit(Ramp* v) override;
void visit(Load* v) override;
void visit(For* v) override;
void visit(Block* v) override;
void visit(Store* v) override;
void visit(Broadcast* v) override;
void visit(IfThenElse* v) override;
void visit(Intrinsics* v) override;
void visit(Allocate* v) override;
void visit(Free* v) override;
void visit(Let* v) override;
void visit(Cond* v) override;
void visit(ExternalCall* v) override;
void emitIsNan(const Intrinsics* v);
void emitIsNan(Intrinsics* v);
llvm::Value* emitUnmaskedLoad(llvm::Value* addr, llvm::Value* idx);
llvm::Value* emitMaskedLoad(
@ -312,7 +312,7 @@ void LLVMCodeGen::call_raw(const std::vector<void*>& args) {
}
void LLVMCodeGen::call(const std::vector<CallArg>& args) {
const auto& buf_args = buffer_args();
auto& buf_args = buffer_args();
if (args.size() != buf_args.size()) {
throw malformed_input("wrong number of args in call");
}
@ -403,7 +403,7 @@ LLVMCodeGenImpl::LLVMCodeGenImpl(
// Emit prototype and bind argument Vars to parameter indices.
llvm::Type* retTy = dtypeToLLVM(dtype);
std::vector<llvm::Type*> params;
for (const auto i : c10::irange(args.size())) {
for (auto i : c10::irange(args.size())) {
auto const& arg = args[i];
if (arg.isVar()) {
params.push_back(dtypeToLLVM(arg.dtype()));
@ -418,7 +418,7 @@ LLVMCodeGenImpl::LLVMCodeGenImpl(
fn_->addAttribute(
llvm::AttributeList::AttrIndex::FunctionIndex,
llvm::Attribute::AlwaysInline);
for (const auto i : c10::irange(args.size())) {
for (auto i : c10::irange(args.size())) {
if (!args[i].isVar()) {
fn_->addParamAttr(i, llvm::Attribute::NoAlias);
}
@ -465,7 +465,7 @@ void LLVMCodeGenImpl::emitWrapper(const std::vector<llvm::Type*>& params) {
auto wrapBB = llvm::BasicBlock::Create(getContext(), "wrapBB", wrapper);
irb_.SetInsertPoint(wrapBB);
llvm::SmallVector<llvm::Value*, 6> wrappedArgs;
for (const auto i : c10::irange(params.size())) {
for (auto i : c10::irange(params.size())) {
auto argp = irb_.CreateGEP(
wrapper->arg_begin(), llvm::ConstantInt::getSigned(IntTy_, i));
if (params[i]->isPointerTy()) {
@ -484,7 +484,7 @@ void LLVMCodeGenImpl::emitWrapper(const std::vector<llvm::Type*>& params) {
class LLVMIntrinsicsExpander : public GenericIntrinsicsExpander {
private:
const Expr* mutate(const Intrinsics* v) {
Expr* mutate(Intrinsics* v) {
if (v->op_type() == kTanh) {
ScalarType stype = v->dtype().scalar_type();
if (stype == ScalarType::Float) {
@ -570,7 +570,7 @@ void LLVMCodeGenImpl::emitKernel(
// TODO: The binary ops are copypasta.
void LLVMCodeGenImpl::visit(const Add* v) {
void LLVMCodeGenImpl::visit(Add* v) {
v->lhs()->accept(this);
auto lhs = this->value_;
bool lfp = lhs->getType()->isFPOrFPVectorTy();
@ -588,7 +588,7 @@ void LLVMCodeGenImpl::visit(const Add* v) {
}
}
void LLVMCodeGenImpl::visit(const Sub* v) {
void LLVMCodeGenImpl::visit(Sub* v) {
v->lhs()->accept(this);
auto lhs = this->value_;
bool lfp = lhs->getType()->isFPOrFPVectorTy();
@ -606,7 +606,7 @@ void LLVMCodeGenImpl::visit(const Sub* v) {
}
}
void LLVMCodeGenImpl::visit(const Mul* v) {
void LLVMCodeGenImpl::visit(Mul* v) {
v->lhs()->accept(this);
auto lhs = this->value_;
bool lfp = lhs->getType()->isFPOrFPVectorTy();
@ -624,7 +624,7 @@ void LLVMCodeGenImpl::visit(const Mul* v) {
}
}
void LLVMCodeGenImpl::visit(const Div* v) {
void LLVMCodeGenImpl::visit(Div* v) {
v->lhs()->accept(this);
auto lhs = this->value_;
bool lfp = lhs->getType()->isFPOrFPVectorTy();
@ -642,7 +642,7 @@ void LLVMCodeGenImpl::visit(const Div* v) {
}
}
void LLVMCodeGenImpl::visit(const And* v) {
void LLVMCodeGenImpl::visit(And* v) {
v->lhs()->accept(this);
auto lhs = this->value_;
bool lfp = lhs->getType()->isFPOrFPVectorTy();
@ -657,7 +657,7 @@ void LLVMCodeGenImpl::visit(const And* v) {
}
}
void LLVMCodeGenImpl::visit(const Or* v) {
void LLVMCodeGenImpl::visit(Or* v) {
v->lhs()->accept(this);
auto lhs = this->value_;
bool lfp = lhs->getType()->isFPOrFPVectorTy();
@ -672,7 +672,7 @@ void LLVMCodeGenImpl::visit(const Or* v) {
}
}
void LLVMCodeGenImpl::visit(const Xor* v) {
void LLVMCodeGenImpl::visit(Xor* v) {
v->lhs()->accept(this);
auto lhs = this->value_;
bool lfp = lhs->getType()->isFPOrFPVectorTy();
@ -687,7 +687,7 @@ void LLVMCodeGenImpl::visit(const Xor* v) {
}
}
void LLVMCodeGenImpl::visit(const Lshift* v) {
void LLVMCodeGenImpl::visit(Lshift* v) {
v->lhs()->accept(this);
auto lhs = this->value_;
bool lfp = lhs->getType()->isFPOrFPVectorTy();
@ -702,7 +702,7 @@ void LLVMCodeGenImpl::visit(const Lshift* v) {
}
}
void LLVMCodeGenImpl::visit(const Rshift* v) {
void LLVMCodeGenImpl::visit(Rshift* v) {
v->lhs()->accept(this);
auto lhs = this->value_;
bool lfp = lhs->getType()->isFPOrFPVectorTy();
@ -721,7 +721,7 @@ void LLVMCodeGenImpl::visit(const Rshift* v) {
}
}
void LLVMCodeGenImpl::visit(const Mod* v) {
void LLVMCodeGenImpl::visit(Mod* v) {
v->lhs()->accept(this);
auto lhs = this->value_;
bool lfp = lhs->getType()->isFPOrFPVectorTy();
@ -736,7 +736,7 @@ void LLVMCodeGenImpl::visit(const Mod* v) {
}
}
void LLVMCodeGenImpl::visit(const Max* v) {
void LLVMCodeGenImpl::visit(Max* v) {
v->lhs()->accept(this);
auto lhs = this->value_;
v->rhs()->accept(this);
@ -759,7 +759,7 @@ void LLVMCodeGenImpl::visit(const Max* v) {
irb_.CreateFCmp(llvm::FCmpInst::FCMP_OGT, lhs, rhs), lhs, rhs));
}
void LLVMCodeGenImpl::visit(const Min* v) {
void LLVMCodeGenImpl::visit(Min* v) {
v->lhs()->accept(this);
auto lhs = this->value_;
v->rhs()->accept(this);
@ -781,7 +781,7 @@ void LLVMCodeGenImpl::visit(const Min* v) {
irb_.CreateFCmp(llvm::FCmpInst::FCMP_OLT, lhs, rhs), lhs, rhs));
}
void LLVMCodeGenImpl::visit(const CompareSelect* v) {
void LLVMCodeGenImpl::visit(CompareSelect* v) {
auto genUnbiased = [this, v]() -> llvm::Value* {
v->lhs()->accept(this);
auto lhs = this->value_;
@ -906,7 +906,7 @@ llvm::Type* llvmTypeToVec(llvm::Type* type, int lanes) {
}
}
void LLVMCodeGenImpl::visit(const Cast* v) {
void LLVMCodeGenImpl::visit(Cast* v) {
v->src_value()->accept(this);
llvm::Type* dstType =
@ -978,7 +978,7 @@ void LLVMCodeGenImpl::visit(const Cast* v) {
}
}
void LLVMCodeGenImpl::visit(const BitCast* v) {
void LLVMCodeGenImpl::visit(BitCast* v) {
v->src_value()->accept(this);
llvm::Type* dstType = dtypeToLLVM(v->dtype());
@ -997,11 +997,11 @@ void LLVMCodeGenImpl::visit(const BitCast* v) {
value_ = irb_.CreateBitOrPointerCast(value_, dstType);
}
void LLVMCodeGenImpl::visit(const Var* v) {
void LLVMCodeGenImpl::visit(Var* v) {
value_ = varToValue(v);
}
llvm::Value* LLVMCodeGenImpl::varToValue(const Var* v) {
llvm::Value* LLVMCodeGenImpl::varToValue(Var* v) {
// It is possible for v to be in both varToVal_ and varToArg_.
// In that case, varToVal_ takes precedence.
if (varToVal_.count(v)) {
@ -1015,11 +1015,11 @@ llvm::Value* LLVMCodeGenImpl::varToValue(const Var* v) {
}
void LLVMCodeGenImpl::replaceVarMapping(
const std::vector<const Var*>& vars,
const std::vector<Var*>& vars,
const std::vector<llvm::Value*>& vals) {
TORCH_CHECK(vars.size() == vals.size());
for (const auto i : c10::irange(vars.size())) {
const Var* var = vars[i];
for (auto i : c10::irange(vars.size())) {
Var* var = vars[i];
llvm::Value* val = vals[i];
if (val) {
varToVal_[var] = val;
@ -1029,7 +1029,7 @@ void LLVMCodeGenImpl::replaceVarMapping(
}
}
void LLVMCodeGenImpl::visit(const Ramp* v) {
void LLVMCodeGenImpl::visit(Ramp* v) {
v->base()->accept(this);
auto base = this->value_;
v->stride()->accept(this);
@ -1105,7 +1105,7 @@ llvm::Value* LLVMCodeGenImpl::emitMaskedLoad(
return phi;
}
void LLVMCodeGenImpl::visit(const Load* v) {
void LLVMCodeGenImpl::visit(Load* v) {
if (v->dtype().lanes() == 1) {
v->base_handle()->accept(this);
auto base = this->value_;
@ -1134,9 +1134,9 @@ void LLVMCodeGenImpl::visit(const Load* v) {
bool unmasked_load = true;
// Handle the case where the load is contiguous and unmasked efficiently
auto* idx_ramp = dynamic_cast<const Ramp*>(v->flat_index());
auto* idx_ramp = dynamic_cast<Ramp*>(v->flat_index());
if (idx_ramp) {
auto* stride_imm = dynamic_cast<const IntImm*>(idx_ramp->stride());
auto* stride_imm = dynamic_cast<IntImm*>(idx_ramp->stride());
if (stride_imm && stride_imm->value() == 1) {
v->base_handle()->accept(this);
auto base = this->value_;
@ -1181,14 +1181,14 @@ llvm::Value* LLVMCodeGenImpl::packFuncArgs(
return NullPtr;
}
std::vector<llvm::Type*> arg_types(func_args.size());
for (const auto i : c10::irange(func_args.size())) {
for (auto i : c10::irange(func_args.size())) {
arg_types[i] = func_args[i]->getType();
}
llvm::StructType* packed_type = llvm::StructType::create(arg_types);
llvm::Value* zero = llvm::ConstantInt::get(IntTy_, 0);
llvm::Value* one = llvm::ConstantInt::get(IntTy_, 1);
llvm::Value* packed = irb_.CreateAlloca(packed_type, one);
for (const auto i : c10::irange(func_args.size())) {
for (auto i : c10::irange(func_args.size())) {
llvm::Value* dst_ptr = irb_.CreateInBoundsGEP(
packed, {zero, llvm::ConstantInt::get(IntTy_, i)});
irb_.CreateStore(func_args[i], dst_ptr);
@ -1203,7 +1203,7 @@ std::vector<llvm::Value*> LLVMCodeGenImpl::unpackFuncArgs(
// TODO: extract arg_count from packed.
std::vector<llvm::Value*> func_args(arg_count);
llvm::Value* zero = llvm::ConstantInt::get(IntTy_, 0);
for (const auto i : c10::irange(arg_count)) {
for (auto i : c10::irange(arg_count)) {
llvm::Value* dst_ptr = irb_.CreateInBoundsGEP(
packed, {zero, llvm::ConstantInt::get(IntTy_, i)});
func_args[i] = irb_.CreateLoad(dst_ptr);
@ -1215,7 +1215,7 @@ std::vector<llvm::Value*> LLVMCodeGenImpl::unpackFuncArgs(
// * Move the body into its own closure.
// * Identify vars that cross the boundary, turn them into arguments, and forward them.
// * Send the closure and range to the dispatcher for execution.
void LLVMCodeGenImpl::processParallelFor(const For* v) {
void LLVMCodeGenImpl::processParallelFor(For* v) {
// Create "start" and "stop" values.
v->start()->accept(this);
auto start = this->value_;
@ -1223,7 +1223,7 @@ void LLVMCodeGenImpl::processParallelFor(const For* v) {
auto stop = this->value_;
// The Vars that need to be forwarded in the body closure.
std::vector<const Var*> body_arg_vars;
std::vector<Var*> body_arg_vars;
// Corresponding Value* that was used in the old body for the caller.
std::vector<llvm::Value*> body_caller_vals;
// Corresponding Value* that will be used in the new body closure.
@ -1232,7 +1232,7 @@ void LLVMCodeGenImpl::processParallelFor(const For* v) {
// Identify the Vars used in the body that are generated outside it.
VarFinder var_finder;
v->body()->accept(&var_finder);
const auto& vars = var_finder.vars();
auto& vars = var_finder.vars();
for (auto& var : vars) {
if (llvm::Value* value = varToValue(var)) {
body_arg_vars.push_back(var);
@ -1292,7 +1292,7 @@ void LLVMCodeGenImpl::processParallelFor(const For* v) {
value_ = llvm::ConstantInt::get(IntTy_, 0);
}
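
The outlining steps listed above boil down to: find every Var the body uses but does not define, pack their values into arguments, and hand the resulting closure plus the index range to the runtime. A rough sequential model of the dispatch contract (names illustrative, not the real runtime API):

#include <cstdint>
#include <functional>

// Stand-in for the parallel dispatcher: runs `body` over [start, stop).
void dispatchParallel(
    int64_t start,
    int64_t stop,
    const std::function<void(int64_t)>& body) {
  for (int64_t i = start; i < stop; ++i) {
    body(i); // the outlined loop body, with forwarded vars bound
  }
}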
void LLVMCodeGenImpl::visit(const For* v) {
void LLVMCodeGenImpl::visit(For* v) {
if (v->is_parallel()) {
processParallelFor(v);
return;
@ -1347,8 +1347,8 @@ void LLVMCodeGenImpl::visit(const For* v) {
value_ = llvm::ConstantInt::get(IntTy_, 0);
}
void LLVMCodeGenImpl::visit(const Block* v) {
const Block* last = scope_;
void LLVMCodeGenImpl::visit(Block* v) {
Block* last = scope_;
scope_ = v;
for (Stmt* s : *v) {
@ -1359,7 +1359,7 @@ void LLVMCodeGenImpl::visit(const Block* v) {
auto it = scopeToVar_.find(v);
if (it != scopeToVar_.end()) {
for (const Var* e : it->second) {
for (Var* e : it->second) {
if (varToVal_.erase(e) != 1) {
throw std::runtime_error("erasing var that doesn't exist");
}
@ -1398,7 +1398,7 @@ void LLVMCodeGenImpl::emitMaskedStore(
irb_.SetInsertPoint(tailblock);
}
void LLVMCodeGenImpl::visit(const Store* v) {
void LLVMCodeGenImpl::visit(Store* v) {
if (v->value()->dtype().lanes() == 1) {
v->base_handle()->accept(this);
auto base = this->value_;
@ -1419,9 +1419,9 @@ void LLVMCodeGenImpl::visit(const Store* v) {
auto val = this->value_;
// Handle the case where the store is contiguous and unmasked efficiently
auto* idx_ramp = dynamic_cast<const Ramp*>(v->flat_index());
auto* idx_ramp = dynamic_cast<Ramp*>(v->flat_index());
if (idx_ramp) {
auto* stride_imm = dynamic_cast<const IntImm*>(idx_ramp->stride());
auto* stride_imm = dynamic_cast<IntImm*>(idx_ramp->stride());
if (stride_imm && stride_imm->value() == 1) {
idx_ramp->base()->accept(this);
auto first_idx = value_;
@ -1453,13 +1453,13 @@ void LLVMCodeGenImpl::visit(const Store* v) {
value_ = llvm::ConstantInt::get(IntTy_, 0);
}
void LLVMCodeGenImpl::visit(const Broadcast* v) {
void LLVMCodeGenImpl::visit(Broadcast* v) {
v->value()->accept(this);
int lanes = v->lanes();
value_ = irb_.CreateVectorSplat(lanes, value_);
}
void LLVMCodeGenImpl::visit(const IfThenElse* v) {
void LLVMCodeGenImpl::visit(IfThenElse* v) {
v->condition()->accept(this);
llvm::Value* condition = value_;
llvm::Value* c = irb_.CreateICmpNE(
@ -1509,7 +1509,7 @@ llvm::Value* LLVMCodeGenImpl::toVec(llvm::Value* v, int lanes) {
}
}
void LLVMCodeGenImpl::emitIsNan(const Intrinsics* v) {
void LLVMCodeGenImpl::emitIsNan(Intrinsics* v) {
v->param(0)->accept(this);
llvm::Type* dstType = dtypeToLLVM(v->dtype());
if (!v->param(0)->dtype().is_floating_point()) {
@ -1583,7 +1583,7 @@ LLVMCodeGenImpl::SimdCallee LLVMCodeGenImpl::getSimdFunction(
return SimdCallee{callee.getFunctionType(), callee.getCallee(), useSimd};
}
void LLVMCodeGenImpl::visit(const Intrinsics* v) {
void LLVMCodeGenImpl::visit(Intrinsics* v) {
llvm::FunctionType* call_ty = nullptr;
llvm::Value* call_fn = nullptr;
bool call_simd_sleef = false;
@ -1772,7 +1772,7 @@ void LLVMCodeGenImpl::visit(const Intrinsics* v) {
}
}
void LLVMCodeGenImpl::visit(const ExternalCall* v) {
void LLVMCodeGenImpl::visit(ExternalCall* v) {
constexpr int max_buffers = 10;
constexpr int max_dimensions = 40;
@ -1783,7 +1783,7 @@ void LLVMCodeGenImpl::visit(const ExternalCall* v) {
// Prepare a vector of bufs that we need to pass to the external function.
// This vector is the output buf followed by the buf_args.
std::vector<const Buf*> bufs(v->buf_args());
std::vector<Buf*> bufs(v->buf_args());
bufs.insert(bufs.begin(), v->buf());
int64_t bufs_num = bufs.size();
@ -1792,7 +1792,7 @@ void LLVMCodeGenImpl::visit(const ExternalCall* v) {
// Count the size of the dims array - it consists of the dimensions of all
// bufs concatenated together.
int64_t dims_num = 0;
for (const Buf* b : bufs) {
for (Buf* b : bufs) {
dims_num += b->dims().size();
}
@ -1809,7 +1809,7 @@ void LLVMCodeGenImpl::visit(const ExternalCall* v) {
int i = 0;
int dim_idx = 0;
for (const Buf* b : bufs) {
for (Buf* b : bufs) {
// Store value for buf pointer
auto gep = irb_.CreateInBoundsGEP(
buf_ptrs, {llvm::ConstantInt::getSigned(IntTy_, i)});
@ -1832,7 +1832,7 @@ void LLVMCodeGenImpl::visit(const ExternalCall* v) {
llvm::ConstantInt::getSigned(LongTy_, b->dims().size()), gep);
// Store dims of the buf
for (const auto dim : c10::irange(b->dims().size())) {
for (auto dim : c10::irange(b->dims().size())) {
gep = irb_.CreateInBoundsGEP(
buf_dims, {llvm::ConstantInt::getSigned(IntTy_, dim_idx)});
b->dims()[dim]->accept(this);
@ -1845,7 +1845,7 @@ void LLVMCodeGenImpl::visit(const ExternalCall* v) {
}
i = 0;
for (const Expr* arg : v->args()) {
for (Expr* arg : v->args()) {
auto gep = irb_.CreateInBoundsGEP(
extra_args, {llvm::ConstantInt::getSigned(IntTy_, i)});
arg->accept(this);
@ -1886,10 +1886,10 @@ void LLVMCodeGenImpl::visit(const ExternalCall* v) {
value_ = llvm::ConstantInt::get(IntTy_, 0);
}
void LLVMCodeGenImpl::visit(const Allocate* v) {
void LLVMCodeGenImpl::visit(Allocate* v) {
llvm::Value* size =
llvm::ConstantInt::getSigned(LongTy_, v->dtype().byte_size());
for (const Expr* e : v->dims()) {
for (Expr* e : v->dims()) {
e->accept(this);
size = irb_.CreateMul(size, irb_.CreateZExt(value_, LongTy_));
}
@ -1918,7 +1918,7 @@ void LLVMCodeGenImpl::visit(const Allocate* v) {
varToVal_[v->buffer_var()] = malloc;
}
void LLVMCodeGenImpl::visit(const Free* v) {
void LLVMCodeGenImpl::visit(Free* v) {
value_ = llvm::ConstantInt::get(IntTy_, 0);
llvm::Value* ptr = varToVal_.at(v->buffer_var());
if (!llvm::isa<llvm::AllocaInst>(ptr)) {
@ -1926,7 +1926,7 @@ void LLVMCodeGenImpl::visit(const Free* v) {
}
}
void LLVMCodeGenImpl::visit(const Let* v) {
void LLVMCodeGenImpl::visit(Let* v) {
v->value()->accept(this);
if (!varToVal_.count(v->var())) {
varToVal_.emplace(v->var(), value_);
@ -1936,7 +1936,7 @@ void LLVMCodeGenImpl::visit(const Let* v) {
}
}
void LLVMCodeGenImpl::visit(const Cond* v) {
void LLVMCodeGenImpl::visit(Cond* v) {
// Even if true_stmt and false_stmt are nullptr,
// in case condition is a function call with side effect,
// we still evaluate it.


@ -113,7 +113,7 @@ static void registerIntrinsics(
}
assertSuccess(JD.define(absoluteSymbols(symbols)));
for (const auto& kv : getNNCFunctionRegistry()) {
for (auto& kv : getNNCFunctionRegistry()) {
assertSuccess(
JD.define(absoluteSymbols({entry(kv.first.c_str(), kv.second)})));
}

File diff suppressed because it is too large.


@ -34,7 +34,7 @@ class TORCH_API LoopNest {
// A constructor for building a LoopNest from a Stmt and a list of output
// buffers.
LoopNest(Stmt* stmt, std::unordered_set<const Buf*> output_bufs);
LoopNest(Stmt* stmt, std::unordered_set<Buf*> output_bufs);
// A constructor for building a LoopNest from another loopnest. It clones the
// other loopnest's stmt.
@ -45,10 +45,10 @@ class TORCH_API LoopNest {
}
std::vector<For*> getLoopStmtsFor(Tensor*) const;
std::vector<For*> getLoopStmtsFor(const Buf*) const;
std::vector<For*> getLoopStmtsFor(Buf*) const;
std::vector<For*> getLoopStmtsFor(Stmt*) const;
Stmt* getLoopBodyFor(Tensor*) const;
Stmt* getLoopBodyFor(const Buf*) const;
Stmt* getLoopBodyFor(Buf*) const;
// Returns the For stmt indexed by 'indices' in the 'root' For stmt.
// 'indices' indicates the path to the returned loop from 'root' in AST, e.g.,
@ -71,14 +71,14 @@ class TORCH_API LoopNest {
For* getLoopAt(For* root, const std::vector<int>& indices) const;
// Returns the For stmt that is immediately enclosing the given stmt.
static For* getParentLoop(const Stmt* st);
static For* getParentLoop(Stmt* st);
// Returns the list of For stmts corresponding to the loopnest that is
// enclosing the given stmt.
static std::vector<For*> getEnclosingLoopNest(const Stmt* st);
static std::vector<For*> getEnclosingLoopNest(Stmt* st);
// Returns a list of all Stmts that write to the given buf.
std::vector<const Stmt*> getAllWritesToBuf(const Buf*) const;
std::vector<Stmt*> getAllWritesToBuf(Buf*) const;
// The following methods return the For loops that contain writes to
// the given buf.
@ -98,18 +98,18 @@ class TORCH_API LoopNest {
// to buf.
// For the above example:
// getAllInnermostLoopsWritingToBuf(a) => {j1, k2, j3}
std::vector<For*> getAllInnermostLoopsWritingToBuf(const Buf*) const;
std::vector<For*> getAllInnermostLoopsWritingToBuf(Buf*) const;
// Returns a list of For loopnests which contain a Stmt that writes to
// the given buf. Each loopnest here is a vector of For loops.
// For the above example:
// getAllLoopNestsWritingToBuf(a) => {{i1,j1}, {i2,j2,k2}, {i2,j3}}
std::vector<std::vector<For*>> getAllLoopNestsWritingToBuf(const Buf*) const;
std::vector<std::vector<For*>> getAllLoopNestsWritingToBuf(Buf*) const;
Stmt* simplify();
bool computeInline(Stmt* s);
bool computeInline(const Buf* b);
bool computeInline(Buf* b);
void inlineIntermediateBufs(bool allow_duplicated_work);
// Optimizes conditionals.
@ -463,12 +463,12 @@ class TORCH_API LoopNest {
static void sliceTail(For* f, int factor, For** head, For** tail);
static void sliceTail(For* f, int factor);
using AccessResult = std::pair<const Buf*, Stmt*>;
using AccessResult = std::pair<Buf*, Stmt*>;
// Insert a cache for the consumer's usages of the buffer produced in
// consumer, and redirect reads and writes in the consumer to that cache.
// Returns a pair of the new cache buffer, and the new rewritten consumer.
static AccessResult cacheAccesses(
const Buf* producer,
Buf* producer,
const std::string& name,
Stmt* consumer);
@ -535,8 +535,8 @@ class TORCH_API LoopNest {
void eliminateDeadStores();
void prepareForCodegen();
const std::unordered_set<const Buf*> getInputBufs() const;
const std::unordered_set<const Buf*> getOutputBufs() const {
const std::unordered_set<Buf*> getInputBufs() const;
const std::unordered_set<Buf*> getOutputBufs() const {
return output_bufs_;
}
@ -545,11 +545,11 @@ class TORCH_API LoopNest {
const std::vector<Tensor*>& output_tensors,
const std::vector<Tensor*>& tensors_to_compute);
Stmt* insertAllocFree(Stmt* stmt);
const std::unordered_set<const Buf*> getIntermediateBufs() const;
const std::unordered_set<Buf*> getIntermediateBufs() const;
Stmt* root_stmt_;
std::unordered_set<const Buf*> output_bufs_;
std::unordered_set<Buf*> output_bufs_;
};
TORCH_API Stmt* FlattenIndexes(Stmt* s);
@ -568,8 +568,8 @@ struct BufLoadOrStoreUse {
* in the vectors reflects the order in which the uses appear in the given
* statement.
*/
std::unordered_map<const Buf*, std::vector<BufLoadOrStoreUse>>
findLoadOrStoreUses(Stmt* s);
std::unordered_map<Buf*, std::vector<BufLoadOrStoreUse>> findLoadOrStoreUses(
Stmt* s);
} // namespace tensorexpr
} // namespace jit


@ -59,15 +59,15 @@ void getDependentsChain(
// AccessInfo
std::vector<const Expr*> AccessInfo::getIndices() const {
std::vector<const Expr*> indices;
std::vector<Expr*> AccessInfo::getIndices() const {
std::vector<Expr*> indices;
if (expr_) {
if (auto* load = dynamic_cast<const Load*>(expr_)) {
if (auto* load = dynamic_cast<Load*>(expr_)) {
indices = load->indices();
}
} else {
if (auto* store = dynamic_cast<const Store*>(stmt_)) {
if (auto* store = dynamic_cast<Store*>(stmt_)) {
indices = store->indices();
}
}
@ -255,8 +255,8 @@ MemDependencyChecker::MemDependencyChecker() {
}
MemDependencyChecker::MemDependencyChecker(
const std::unordered_set<const Buf*>& inputs,
const std::unordered_set<const Buf*>& outputs) {
const std::unordered_set<Buf*>& inputs,
const std::unordered_set<Buf*>& outputs) {
for (auto* s : inputs) {
inputs_[s] = nullptr;
}
@ -320,15 +320,15 @@ DependencySet MemDependencyChecker::getAllWriteDependencies(
return writes;
}
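
The dependsDirectly/dependsIndirectly family below encodes the usual reads-after-writes relation: "direct" means one access immediately reads the other's write, while "indirect" closes over chains of such accesses. A plain-C++ intuition pump (hypothetical accesses, not the checker's API):

#include <cassert>

int main() {
  int a[1], b[1], c[1];
  a[0] = 1;    // write W_a
  b[0] = a[0]; // reads W_a: the b-store depends *directly* on the a-store
  c[0] = b[0]; // reads the b-store: depends *indirectly* on the a-store
  assert(c[0] == 1);
  return 0;
}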
bool MemDependencyChecker::dependsDirectly(const Expr* A, const Stmt* B) {
bool MemDependencyChecker::dependsDirectly(Expr* A, Stmt* B) {
return dependsDirectlyHelper(A, B);
}
bool MemDependencyChecker::dependsDirectly(const Stmt* A, const Stmt* B) {
bool MemDependencyChecker::dependsDirectly(Stmt* A, Stmt* B) {
return dependsDirectlyHelper(A, B);
}
bool MemDependencyChecker::dependsDirectly(const Buf* O, const Stmt* B) {
bool MemDependencyChecker::dependsDirectly(Buf* O, Stmt* B) {
auto outputAccess = output(O);
auto bWrites = getAllWritesWithin(B);
@ -341,7 +341,7 @@ bool MemDependencyChecker::dependsDirectly(const Buf* O, const Stmt* B) {
return false;
}
bool MemDependencyChecker::dependsDirectly(const Stmt* A, const Buf* I) {
bool MemDependencyChecker::dependsDirectly(Stmt* A, Buf* I) {
auto aReads = getAllReadsWithin(A);
auto inputAccess = input(I);
@ -354,7 +354,7 @@ bool MemDependencyChecker::dependsDirectly(const Stmt* A, const Buf* I) {
return false;
}
bool MemDependencyChecker::dependsDirectly(const Expr* A, const Buf* I) {
bool MemDependencyChecker::dependsDirectly(Expr* A, Buf* I) {
auto aReads = getAllReadsWithin(A);
auto inputAccess = input(I);
@ -373,15 +373,15 @@ bool MemDependencyChecker::dependsDirectly(
return A->hasDependency(B) && B->isWrite();
}
bool MemDependencyChecker::dependsIndirectly(const Expr* A, const Stmt* B) {
bool MemDependencyChecker::dependsIndirectly(Expr* A, Stmt* B) {
return dependsIndirectlyHelper(A, B);
}
bool MemDependencyChecker::dependsIndirectly(const Stmt* A, const Stmt* B) {
bool MemDependencyChecker::dependsIndirectly(Stmt* A, Stmt* B) {
return dependsIndirectlyHelper(A, B);
}
bool MemDependencyChecker::dependsIndirectly(const Buf* O, const Stmt* B) {
bool MemDependencyChecker::dependsIndirectly(Buf* O, Stmt* B) {
auto outputAccess = output(O);
DependencySet dependencies;
@ -397,7 +397,7 @@ bool MemDependencyChecker::dependsIndirectly(const Buf* O, const Stmt* B) {
return false;
}
bool MemDependencyChecker::dependsIndirectly(const Stmt* A, const Buf* I) {
bool MemDependencyChecker::dependsIndirectly(Stmt* A, Buf* I) {
auto aReads = getAllReadsWithin(A);
auto inputAccess = input(I);
@ -406,7 +406,7 @@ bool MemDependencyChecker::dependsIndirectly(const Stmt* A, const Buf* I) {
return aDeps.count(inputAccess) != 0;
}
bool MemDependencyChecker::dependsIndirectly(const Expr* A, const Buf* I) {
bool MemDependencyChecker::dependsIndirectly(Expr* A, Buf* I) {
auto aReads = getAllReadsWithin(A);
auto inputAccess = input(I);
@ -415,7 +415,7 @@ bool MemDependencyChecker::dependsIndirectly(const Expr* A, const Buf* I) {
return aDeps.count(inputAccess) != 0;
}
bool MemDependencyChecker::dependsIndirectly(const Buf* O, const Buf* I) {
bool MemDependencyChecker::dependsIndirectly(Buf* O, Buf* I) {
auto outputAccess = output(O);
auto inputAccess = input(I);
@ -438,8 +438,7 @@ bool MemDependencyChecker::dependsIndirectly(
return true;
}
std::shared_ptr<AccessInfo> MemDependencyChecker::accessFor(
const Stmt* A) const {
std::shared_ptr<AccessInfo> MemDependencyChecker::accessFor(Stmt* A) const {
auto bound = stmtToAccess_.equal_range(A);
for (auto it = bound.first; it != bound.second; ++it) {
if (it->second->expr() == nullptr) {
@ -449,8 +448,7 @@ std::shared_ptr<AccessInfo> MemDependencyChecker::accessFor(
return nullptr;
}
std::shared_ptr<AccessInfo> MemDependencyChecker::accessFor(
const Expr* A) const {
std::shared_ptr<AccessInfo> MemDependencyChecker::accessFor(Expr* A) const {
// TODO: exprs can have multiple accesses... we're returning the first but that
// isn't great. Can't do much here.
auto bound = exprToAccess_.equal_range(A);
@ -462,7 +460,7 @@ std::shared_ptr<AccessInfo> MemDependencyChecker::accessFor(
}
std::unordered_set<std::shared_ptr<AccessInfo>> MemDependencyChecker::
accessesWithin(const Stmt* A) const {
accessesWithin(Stmt* A) const {
auto it = scopeToAccesses_.find(A);
if (it != scopeToAccesses_.end()) {
return std::unordered_set<std::shared_ptr<AccessInfo>>(
@ -478,11 +476,11 @@ std::unordered_set<std::shared_ptr<AccessInfo>> MemDependencyChecker::
}
std::unordered_set<std::shared_ptr<AccessInfo>> MemDependencyChecker::
accessesWithin(const Expr* A) const {
accessesWithin(Expr* A) const {
return {accessFor(A)};
}
std::shared_ptr<AccessInfo> MemDependencyChecker::input(const Buf* b) const {
std::shared_ptr<AccessInfo> MemDependencyChecker::input(Buf* b) const {
auto it = inputs_.find(b);
if (it == inputs_.end()) {
return nullptr;
@ -490,7 +488,7 @@ std::shared_ptr<AccessInfo> MemDependencyChecker::input(const Buf* b) const {
return it->second;
}
std::shared_ptr<AccessInfo> MemDependencyChecker::output(const Buf* b) const {
std::shared_ptr<AccessInfo> MemDependencyChecker::output(Buf* b) const {
auto it = outputs_.find(b);
if (it == outputs_.end()) {
return nullptr;
@ -500,18 +498,18 @@ std::shared_ptr<AccessInfo> MemDependencyChecker::output(const Buf* b) const {
// Node visitors:
void MemDependencyChecker::visit(const Store* v) {
const Stmt* last = lastStmt_;
void MemDependencyChecker::visit(Store* v) {
Stmt* last = lastStmt_;
lastStmt_ = v;
v->value()->accept(this);
for (const Expr* ind : v->indices()) {
for (Expr* ind : v->indices()) {
ind->accept(this);
}
lastStmt_ = last;
// Create a new AccessInfo for the store.
const Var* var = v->buf()->base_handle();
Var* var = v->buf()->base_handle();
auto info = std::make_shared<AccessInfo>(
nextAccess_++, AccessType::Store, v, var, getIndicesBounds(v->indices()));
@ -532,19 +530,19 @@ void MemDependencyChecker::visit(const Store* v) {
currentScope_->accesses_.push_back(info);
}
void MemDependencyChecker::visit(const Load* v) {
void MemDependencyChecker::visit(Load* v) {
// Create a temporary scope to hold any loads that occur within the indices of
// this load.
auto indicesScope =
std::make_shared<Scope>(currentScope_->block, currentScope_);
currentScope_ = indicesScope;
for (const Expr* ind : v->indices()) {
for (Expr* ind : v->indices()) {
ind->accept(this);
}
// Create a new AccessInfo for the load.
const Var* var = v->buf()->base_handle();
Var* var = v->buf()->base_handle();
auto load = std::make_shared<AccessInfo>(
nextAccess_++,
AccessType::Load,
@ -584,24 +582,24 @@ void MemDependencyChecker::visit(const Load* v) {
bool executionSafetyCheck(
const std::shared_ptr<AccessInfo>& info,
const std::shared_ptr<AccessInfo>& other,
const std::vector<const Expr*>& aStrides,
const std::vector<const Expr*>& oStrides,
const std::vector<Expr*>& aStrides,
const std::vector<Expr*>& oStrides,
bool parallelized) {
if (aStrides.empty() || oStrides.empty()) {
return false;
}
TORCH_INTERNAL_ASSERT(info->bounds().size() == other->bounds().size());
for (size_t b = 0; b < info->bounds().size(); ++b) {
const Expr* aIndexStride = aStrides[b];
const Expr* oIndexStride = oStrides[b];
Expr* aIndexStride = aStrides[b];
Expr* oIndexStride = oStrides[b];
// can't be safe on this index if we can't determine stride.
if (!aIndexStride->isConstant() || !oIndexStride->isConstant()) {
continue;
}
const Expr* minStride =
Expr* minStride =
IRSimplifier::simplify(new Min(aIndexStride, oIndexStride, true));
const Expr* maxStride =
Expr* maxStride =
IRSimplifier::simplify(new Max(aIndexStride, oIndexStride, true));
// If the first access has no stride, don't apply safety.
@ -609,8 +607,7 @@ bool executionSafetyCheck(
continue;
}
const Expr* modCheck =
IRSimplifier::simplify(new Mod(maxStride, minStride));
Expr* modCheck = IRSimplifier::simplify(new Mod(maxStride, minStride));
// if the strides can't have easily inferable distinct offsets, they're not
// safe.
@ -624,7 +621,7 @@ bool executionSafetyCheck(
// axis is the same sign as the common stride, then they will not
// overlap.
const Expr* startDiff = IRSimplifier::simplify(
Expr* startDiff = IRSimplifier::simplify(
new Sub(info->bounds()[b].start, other->bounds()[b].start));
bool diffNegative = immediateIsNegative(startDiff);
@ -638,7 +635,7 @@ bool executionSafetyCheck(
// If both accesses have the same stride, and the difference in start
// element is smaller than this stride then the entire range is distinct.
if (exprEquals(minStride, maxStride)) {
const Expr* check1 =
Expr* check1 =
IRSimplifier::simplify(new CompareSelect(startDiff, minStride, kLT));
if (check1->isConstant() && immediateEquals(check1, 1)) {
return true;
@ -649,7 +646,7 @@ bool executionSafetyCheck(
CompareSelectOperation op = strideNegative ? kLT : kGT;
const Expr* check =
Expr* check =
IRSimplifier::simplify(new CompareSelect(startDiff, new IntImm(0), op));
// If the start difference modulo the minimum stride is offset from that
@ -670,10 +667,10 @@ bool executionSafetyCheck(
return false;
}
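A worked illustration of the stride reasoning above, using hypothetical accesses A[2*i] and B[4*i + 1]:

  // minStride = 2, maxStride = 4
  // Mod(maxStride, minStride) == 0    -> offsets are easily inferable
  // startDiff = 1, 1 % minStride != 0 -> the two access patterns touch
  //                                      disjoint elements, so the check
  //                                      reports execution-order safety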
void MemDependencyChecker::visit(const For* v) {
const Var* var = v->var();
void MemDependencyChecker::visit(For* v) {
Var* var = v->var();
const Stmt* last = lastStmt_;
Stmt* last = lastStmt_;
lastStmt_ = v;
v->var()->accept(this);
@ -716,19 +713,19 @@ void MemDependencyChecker::visit(const For* v) {
// access, which we do via substituting the loop var with (var+1) into the
// indices expr.
std::vector<std::vector<const Expr*>> loopStrides;
std::vector<std::vector<Expr*>> loopStrides;
loopStrides.resize(currentScope_->accesses_.size());
for (size_t a = 0; a < currentScope_->accesses_.size(); ++a) {
auto& info = currentScope_->accesses_[a];
std::vector<const Expr*> indices = info->getIndices();
std::vector<Expr*> indices = info->getIndices();
std::vector<const Expr*>& loopIndicesStride = loopStrides[a];
std::vector<Expr*>& loopIndicesStride = loopStrides[a];
loopIndicesStride.resize(indices.size());
// index expr must depend on the loop var in some way to have a stride.
for (const auto i : c10::irange(indices.size())) {
for (auto i : c10::irange(indices.size())) {
VarFinder vf;
if (vf.find(indices[i]).count(var) == 0) {
loopIndicesStride[i] = new IntImm(0);
@ -750,14 +747,14 @@ void MemDependencyChecker::visit(const For* v) {
{{var, new Sub(v->stop(), new IntImm(1))}}));
}
const Expr* zeroStep = indices[i];
const Expr* oneStep =
Expr* zeroStep = indices[i];
Expr* oneStep =
Substitute(indices[i], {{var, new Add(var, new IntImm(1))}});
loopIndicesStride[i] =
IRSimplifier::simplify(new Sub(oneStep, zeroStep));
// If start > end, swap the order of the bound.
const Expr* diff = IRSimplifier::simplify(
Expr* diff = IRSimplifier::simplify(
new Sub(info->bounds()[i].end, info->bounds()[i].start));
if (diff->isConstant() && immediateIsNegative(diff)) {
info->bounds()[i].swap();
@ -788,8 +785,7 @@ void MemDependencyChecker::visit(const For* v) {
Substitute(bound.end, {{var, new Sub(v->stop(), new IntImm(1))}}));
// If start > end, swap the order of the bound.
const Expr* diff =
IRSimplifier::simplify(new Sub(bound.end, bound.start));
Expr* diff = IRSimplifier::simplify(new Sub(bound.end, bound.start));
if (diff->isConstant() && immediateIsNegative(diff)) {
bound.swap();
}
@ -802,7 +798,7 @@ void MemDependencyChecker::visit(const For* v) {
v->loop_options().is_gpu_thread_index();
// Store buffers allocated at this scope.
std::unordered_set<const Var*> local_intermediates;
std::unordered_set<Var*> local_intermediates;
// Scanning from the top of the loop, we look for accesses which may depend
// on a previous or parallel loop iteration.
@ -905,8 +901,8 @@ void MemDependencyChecker::visit(const For* v) {
currentScope_ = currentScope_->parent;
}
void MemDependencyChecker::visit(const Cond* v) {
const Stmt* last = lastStmt_;
void MemDependencyChecker::visit(Cond* v) {
Stmt* last = lastStmt_;
lastStmt_ = v;
auto enclosingScope =
@ -954,12 +950,12 @@ void MemDependencyChecker::visit(const Cond* v) {
lastStmt_ = last;
}
void MemDependencyChecker::visit(const IfThenElse* v) {
void MemDependencyChecker::visit(IfThenElse* v) {
// condition is in enclosing scope.
v->condition()->accept(this);
const Expr* true_value = v->true_value();
const Expr* false_value = v->false_value();
Expr* true_value = v->true_value();
Expr* false_value = v->false_value();
auto enclosingScope = currentScope_;
@ -990,13 +986,13 @@ void MemDependencyChecker::visit(const IfThenElse* v) {
currentScope_ = enclosingScope;
}
void MemDependencyChecker::visit(const CompareSelect* v) {
void MemDependencyChecker::visit(CompareSelect* v) {
// condition is in enclosing scope.
v->lhs()->accept(this);
v->rhs()->accept(this);
const Expr* true_value = v->ret_val1();
const Expr* false_value = v->ret_val2();
Expr* true_value = v->ret_val1();
Expr* false_value = v->ret_val2();
auto enclosingScope = currentScope_;
@ -1029,11 +1025,11 @@ void MemDependencyChecker::visit(const CompareSelect* v) {
// Inserts accesses for a map of buffers (i.e., for inputs and outputs).
void MemDependencyChecker::insertBuffers(
std::unordered_map<const Buf*, std::shared_ptr<AccessInfo>>& bufs,
std::unordered_map<Buf*, std::shared_ptr<AccessInfo>>& bufs,
AccessType type) {
for (auto& pair : bufs) {
const Buf* b = pair.first;
const Var* var = b->base_handle();
Buf* b = pair.first;
Var* var = b->base_handle();
IndexBounds bounds;
for (auto* d : b->dims()) {
bounds.push_back(
@ -1050,7 +1046,7 @@ void MemDependencyChecker::insertBuffers(
}
}
void MemDependencyChecker::visit(const Block* v) {
void MemDependencyChecker::visit(Block* v) {
auto prev_scope = currentScope_;
// handle kernel inputs.
@ -1086,15 +1082,15 @@ void MemDependencyChecker::visit(const Block* v) {
}
}
void MemDependencyChecker::visit(const Let* v) {
const Stmt* last = lastStmt_;
void MemDependencyChecker::visit(Let* v) {
Stmt* last = lastStmt_;
lastStmt_ = v;
IRVisitor::visit(v);
lastStmt_ = last;
const Var* var = v->var();
Var* var = v->var();
if (knownVarBounds_.count(var) != 0) {
currentScope_->shadowedVarBounds[var] = knownVarBounds_[var];
}
@ -1105,17 +1101,17 @@ void MemDependencyChecker::visit(const Let* v) {
// AtomicAdd is not supported yet; it's a bit more complex since it's both a
// read and a write. It's only inserted during Cuda codegen, so this should be okay.
void MemDependencyChecker::visit(const AtomicAdd* v) {
void MemDependencyChecker::visit(AtomicAdd* v) {
throw std::runtime_error("MemDependencyChecker AtomicAdd unimplemented");
}
void MemDependencyChecker::visit(const Allocate* v) {
const Stmt* last = lastStmt_;
void MemDependencyChecker::visit(Allocate* v) {
Stmt* last = lastStmt_;
lastStmt_ = v;
IRVisitor::visit(v);
const Var* var = v->buffer_var();
Var* var = v->buffer_var();
IndexBounds bounds;
// TODO: remove the "buf_flat_size" process below and extend the buf bound
// check to support N-d indices access and 1-d index access.
@ -1124,7 +1120,7 @@ void MemDependencyChecker::visit(const Allocate* v) {
// identify 1-d index access for N-d bufs. Thus we flatten N-d bufs here to
// avoid failing the bound check. But this is not the correct approach and
// should be fixed.
const Expr* flat_size = buf_flat_size(v->buf());
Expr* flat_size = buf_flat_size(v->buf());
flat_size = IRSimplifier::simplify(new Sub(flat_size, new IntImm(1)));
bounds.push_back({new IntImm(0), flat_size});
@ -1140,13 +1136,13 @@ void MemDependencyChecker::visit(const Allocate* v) {
lastStmt_ = last;
}
void MemDependencyChecker::visit(const Free* v) {
const Stmt* last = lastStmt_;
void MemDependencyChecker::visit(Free* v) {
Stmt* last = lastStmt_;
lastStmt_ = v;
IRVisitor::visit(v);
const Var* var = v->buffer_var();
Var* var = v->buffer_var();
auto it = intermediates_.find(var);
TORCH_INTERNAL_ASSERT(it != intermediates_.end());
@ -1247,7 +1243,7 @@ void MemDependencyChecker::mergeScope(
// Copy open writes up.
for (auto& pair : child->openWrites_) {
const Var* var = pair.first;
Var* var = pair.first;
// Intentionally using operator[], we want it to be created if it does not
// exist.
@ -1270,7 +1266,7 @@ class VarBoundBinder : public IRVisitor {
public:
VarBoundBinder(const VarBoundMap& vars) : vars_(vars) {}
Bound getBounds(const Expr* e) {
Bound getBounds(Expr* e) {
min_ = e;
max_ = e;
e->accept(this);
@ -1280,7 +1276,7 @@ class VarBoundBinder : public IRVisitor {
}
private:
void visit(const Var* v) override {
void visit(Var* v) override {
auto it = vars_.find(v);
if (it == vars_.end()) {
return;
@ -1290,13 +1286,13 @@ class VarBoundBinder : public IRVisitor {
max_ = Substitute(max_, {{v, it->second.end}});
}
const Expr* min_{nullptr};
const Expr* max_{nullptr};
Expr* min_{nullptr};
Expr* max_{nullptr};
const VarBoundMap& vars_;
};
std::vector<Bound> MemDependencyChecker::getIndicesBounds(
const std::vector<const Expr*>& indices) {
const std::vector<Expr*>& indices) {
std::vector<Bound> bounds;
bounds.reserve(indices.size());
VarBoundBinder binder(knownVarBounds_);

View File

@ -40,8 +40,8 @@ class TORCH_API AccessInfo {
AccessInfo(
size_t id,
AccessType type,
const Stmt* stmt,
const Var* var,
Stmt* stmt,
Var* var,
IndexBounds bounds)
: id_(id),
type_(type),
@ -53,9 +53,9 @@ class TORCH_API AccessInfo {
AccessInfo(
size_t id,
AccessType type,
const Expr* expr,
const Stmt* stmt,
const Var* var,
Expr* expr,
Stmt* stmt,
Var* var,
IndexBounds bounds)
: id_(id),
type_(type),
@ -77,18 +77,18 @@ class TORCH_API AccessInfo {
// The enclosing Stmt this access represents. E.g. if this is a Store then
// Stmt is the Store itself, while if the access is caused by an Expr, this is
// the most immediate parent Stmt.
const Stmt* stmt() const {
Stmt* stmt() const {
return stmt_;
}
// If the access is represented by an Expr (such as Load or Call) then this is
// it, otherwise it's nullptr.
const Expr* expr() const {
Expr* expr() const {
return expr_;
}
// The Var representing the underlying Buffer.
const Var* var() const {
Var* var() const {
return var_;
}
@ -114,7 +114,7 @@ class TORCH_API AccessInfo {
}
// Returns the symbolic expression of the indices of this access.
std::vector<const Expr*> getIndices() const;
std::vector<Expr*> getIndices() const;
// Establishes a dependency or dependent relationship with another access.
void addDependency(const std::shared_ptr<AccessInfo>& write);
@ -149,9 +149,9 @@ class TORCH_API AccessInfo {
private:
size_t id_;
AccessType type_;
const Stmt* stmt_;
const Expr* expr_;
const Var* var_;
Stmt* stmt_;
Expr* expr_;
Var* var_;
IndexBounds bounds_;
// Yes these should be sorted.
@ -159,7 +159,7 @@ class TORCH_API AccessInfo {
std::map<size_t, std::shared_ptr<AccessInfo>> dependents_;
};
using VarBoundMap = std::unordered_map<const Var*, Bound>;
using VarBoundMap = std::unordered_map<Var*, Bound>;
/* MemDependencyChecker analyses an IR fragment and builds a dependency graph of
* accesses contained within.
@ -176,8 +176,8 @@ class TORCH_API MemDependencyChecker : public IRVisitor {
public:
MemDependencyChecker();
MemDependencyChecker(
const std::unordered_set<const Buf*>& inputs,
const std::unordered_set<const Buf*>& outputs);
const std::unordered_set<Buf*>& inputs,
const std::unordered_set<Buf*>& outputs);
MemDependencyChecker(
const std::vector<BufHandle>& inputs,
const std::vector<BufHandle>& outputs);
@ -193,15 +193,15 @@ class TORCH_API MemDependencyChecker : public IRVisitor {
// about it.
// Returns true if any read in A has a direct dependence on a write in B.
bool dependsDirectly(const Stmt* A, const Stmt* B);
bool dependsDirectly(const Expr* A, const Stmt* B);
bool dependsDirectly(Stmt* A, Stmt* B);
bool dependsDirectly(Expr* A, Stmt* B);
// Returns true if the output depends directly on a write contained in B.
bool dependsDirectly(const Buf* output, const Stmt* B);
bool dependsDirectly(Buf* output, Stmt* B);
// Returns true if a read in A depends directly on the provided input.
bool dependsDirectly(const Stmt* A, const Buf* input);
bool dependsDirectly(const Expr* A, const Buf* input);
bool dependsDirectly(Stmt* A, Buf* input);
bool dependsDirectly(Expr* A, Buf* input);
// Outputs/inputs cannot depend directly.
@ -211,18 +211,18 @@ class TORCH_API MemDependencyChecker : public IRVisitor {
const std::shared_ptr<AccessInfo>& B);
// Returns true if any read in A has an ancestor write contained in B.
bool dependsIndirectly(const Stmt* A, const Stmt* B);
bool dependsIndirectly(const Expr* A, const Stmt* B);
bool dependsIndirectly(Stmt* A, Stmt* B);
bool dependsIndirectly(Expr* A, Stmt* B);
// Returns true if the output depends indirectly on a write contained in B.
bool dependsIndirectly(const Buf* output, const Stmt* B);
bool dependsIndirectly(Buf* output, Stmt* B);
// Returns true if a read in A depends indirectly on the provided input.
bool dependsIndirectly(const Stmt* A, const Buf* input);
bool dependsIndirectly(const Expr* A, const Buf* input);
bool dependsIndirectly(Stmt* A, Buf* input);
bool dependsIndirectly(Expr* A, Buf* input);
// returns true if the output uses any load of the input.
bool dependsIndirectly(const Buf* output, const Buf* input);
bool dependsIndirectly(Buf* output, Buf* input);
// Returns true if the access A has a dependency chain to access B.
bool dependsIndirectly(
@ -230,21 +230,19 @@ class TORCH_API MemDependencyChecker : public IRVisitor {
const std::shared_ptr<AccessInfo>& B);
// Returns the AccessInfo
std::shared_ptr<AccessInfo> accessFor(const Stmt* A) const;
std::shared_ptr<AccessInfo> accessFor(const Expr* A) const;
std::shared_ptr<AccessInfo> accessFor(Stmt* A) const;
std::shared_ptr<AccessInfo> accessFor(Expr* A) const;
// Returns all AccessInfos.
std::unordered_set<std::shared_ptr<AccessInfo>> accessesWithin(
const Stmt* A) const;
std::unordered_set<std::shared_ptr<AccessInfo>> accessesWithin(Stmt* A) const;
// TODO: this will return only the AccessInfo for A. It's included for
// completeness but be aware it won't return accesses used in the computation
// of A.
std::unordered_set<std::shared_ptr<AccessInfo>> accessesWithin(
const Expr* A) const;
std::unordered_set<std::shared_ptr<AccessInfo>> accessesWithin(Expr* A) const;
// Accesses relating to input and output buffers.
std::shared_ptr<AccessInfo> input(const Buf* B) const;
std::shared_ptr<AccessInfo> output(const Buf* B) const;
std::shared_ptr<AccessInfo> input(Buf* B) const;
std::shared_ptr<AccessInfo> output(Buf* B) const;
// Returns the full history of reads and writes.
const std::vector<std::shared_ptr<AccessInfo>>& getHistory() const;
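A minimal usage sketch of this query API (the buffers and root statement are hypothetical; note that Bufs are now passed as plain pointers):

  MemDependencyChecker checker({in_buf}, {out_buf}); // inputs, outputs
  root_stmt->accept(&checker);
  bool depends = checker.dependsIndirectly(out_buf, in_buf);
  std::shared_ptr<AccessInfo> acc = checker.accessFor(root_stmt);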
@ -254,17 +252,17 @@ class TORCH_API MemDependencyChecker : public IRVisitor {
private:
// Node visitors.
void visit(const Store* v) override;
void visit(const Load* v) override;
void visit(const For* v) override;
void visit(const Cond* v) override;
void visit(const IfThenElse* v) override;
void visit(const CompareSelect* v) override;
void visit(const Block* v) override;
void visit(const Let* v) override;
void visit(const AtomicAdd* v) override;
void visit(const Allocate* v) override;
void visit(const Free* v) override;
void visit(Store* v) override;
void visit(Load* v) override;
void visit(For* v) override;
void visit(Cond* v) override;
void visit(IfThenElse* v) override;
void visit(CompareSelect* v) override;
void visit(Block* v) override;
void visit(Let* v) override;
void visit(AtomicAdd* v) override;
void visit(Allocate* v) override;
void visit(Free* v) override;
using BoundRelationship = std::pair<IndexBounds, std::shared_ptr<AccessInfo>>;
@ -276,29 +274,27 @@ class TORCH_API MemDependencyChecker : public IRVisitor {
Block* block;
std::shared_ptr<Scope> parent;
std::unordered_map<const Var*, Bound> shadowedVarBounds;
std::unordered_set<const Var*> localVars;
std::unordered_map<Var*, Bound> shadowedVarBounds;
std::unordered_set<Var*> localVars;
std::vector<std::shared_ptr<AccessInfo>> accesses_;
std::unordered_map<const Var*, std::list<BoundRelationship>> openWrites_;
std::unordered_map<Var*, std::list<BoundRelationship>> openWrites_;
};
std::shared_ptr<Scope> currentScope_;
bool allowExecutionOrderAnalysis_{false};
std::unordered_multimap<const Stmt*, std::shared_ptr<AccessInfo>>
stmtToAccess_;
std::unordered_multimap<const Expr*, std::shared_ptr<AccessInfo>>
exprToAccess_;
std::unordered_map<const Stmt*, std::vector<std::shared_ptr<AccessInfo>>>
std::unordered_multimap<Stmt*, std::shared_ptr<AccessInfo>> stmtToAccess_;
std::unordered_multimap<Expr*, std::shared_ptr<AccessInfo>> exprToAccess_;
std::unordered_map<Stmt*, std::vector<std::shared_ptr<AccessInfo>>>
scopeToAccesses_;
VarBoundMap knownVarBounds_;
// Finds all accesses that are reads within the scope of v.
template <typename StmtOrExpr>
DependencySet getAllReadsWithin(const StmtOrExpr* v) {
DependencySet getAllReadsWithin(StmtOrExpr* v) {
DependencySet reads;
auto insertAllReads = [&](const auto& nodes) {
for (auto* l : nodes) {
@ -321,7 +317,7 @@ class TORCH_API MemDependencyChecker : public IRVisitor {
// Finds all accesses that are writes within the scope of v.
// Writes cannot occur in Exprs, so this is a little simpler.
DependencySet getAllWritesWithin(const Stmt* v) {
DependencySet getAllWritesWithin(Stmt* v) {
DependencySet writes;
// writes just Store currently.
@ -339,7 +335,7 @@ class TORCH_API MemDependencyChecker : public IRVisitor {
// Templated helpers to work on either Exprs or Stmts.
template <typename StmtOrExpr>
bool dependsDirectlyHelper(const StmtOrExpr* A, const Stmt* B) {
bool dependsDirectlyHelper(StmtOrExpr* A, Stmt* B) {
auto aReads = getAllReadsWithin(A);
auto bWrites = getAllWritesWithin(B);
@ -355,7 +351,7 @@ class TORCH_API MemDependencyChecker : public IRVisitor {
}
template <typename StmtOrExpr>
bool dependsIndirectlyHelper(const StmtOrExpr* A, const Stmt* B) {
bool dependsIndirectlyHelper(StmtOrExpr* A, Stmt* B) {
auto aReads = getAllReadsWithin(A);
auto bWrites = getAllWritesWithin(B);
@ -373,13 +369,13 @@ class TORCH_API MemDependencyChecker : public IRVisitor {
DependencySet getAllWriteDependencies(const DependencySet& products);
// Maps for inputs and outputs, since they aren't present directly in the IR.
std::unordered_map<const Buf*, std::shared_ptr<AccessInfo>> inputs_;
std::unordered_map<const Buf*, std::shared_ptr<AccessInfo>> outputs_;
std::unordered_map<const Var*, std::shared_ptr<AccessInfo>> intermediates_;
std::unordered_map<Buf*, std::shared_ptr<AccessInfo>> inputs_;
std::unordered_map<Buf*, std::shared_ptr<AccessInfo>> outputs_;
std::unordered_map<Var*, std::shared_ptr<AccessInfo>> intermediates_;
// Inserts accesses for Bufs: specifically for inputs and outputs.
void insertBuffers(
std::unordered_map<const Buf*, std::shared_ptr<AccessInfo>>& bufs,
std::unordered_map<Buf*, std::shared_ptr<AccessInfo>>& bufs,
AccessType type);
// Update the write history with a new write, adding dependencies and closing
@ -399,10 +395,10 @@ class TORCH_API MemDependencyChecker : public IRVisitor {
bool closeOverlapped = true);
// Binds symbolic vars in indices with the low and high bound for those vars.
std::vector<Bound> getIndicesBounds(const std::vector<const Expr*>& indices);
std::vector<Bound> getIndicesBounds(const std::vector<Expr*>& indices);
size_t nextAccess_{0};
const Stmt* lastStmt_{nullptr};
Stmt* lastStmt_{nullptr};
};
} // namespace analysis

View File

@ -7,19 +7,19 @@ namespace jit {
namespace tensorexpr {
ReduceOp* Reducer::operator()(
const Buf* result_buf,
Buf* result_buf,
ExprHandle body,
const std::vector<const Expr*>& output,
const std::vector<const Var*>& inner) const {
const std::vector<Expr*>& output,
const std::vector<Var*>& inner) const {
return new ReduceOp(
complete(result_buf, interaction_, body, output, inner), inner, *this);
}
ReduceOp* Reducer::operator()(
const Buf* result_buf,
const Expr* body,
const std::vector<const Expr*>& output,
const std::vector<const Var*>& inner) const {
Buf* result_buf,
Expr* body,
const std::vector<Expr*>& output,
const std::vector<Var*>& inner) const {
return new ReduceOp(
complete(result_buf, interaction_, ExprHandle(body), output, inner),
inner,

View File

@ -35,21 +35,21 @@ class TORCH_API Reducer {
}
virtual ~Reducer() = default;
const Expr* initializer() const {
Expr* initializer() const {
return init_;
}
ReduceOp* operator()(
const Buf* result_buf,
Buf* result_buf,
ExprHandle body,
const std::vector<const Expr*>& output,
const std::vector<const Var*>& inner) const;
const std::vector<Expr*>& output,
const std::vector<Var*>& inner) const;
ReduceOp* operator()(
const Buf* result_buf,
const Expr* body,
const std::vector<const Expr*>& output,
const std::vector<const Var*>& inner) const;
Buf* result_buf,
Expr* body,
const std::vector<Expr*>& output,
const std::vector<Var*>& inner) const;
// Polymorphic handling of Body functions with a variety of parameters.
static ExprHandle getReduceBody(
@ -104,11 +104,11 @@ class TORCH_API Reducer {
// Completes the reduction operator by applying the interaction function to
// the accumulation and the body expression.
static Expr* complete(
const Buf* accumulator,
Buf* accumulator,
ReduceInteraction interaction,
ExprHandle body,
const std::vector<const Expr*>& output_args,
const std::vector<const Var*>& reduce_args) {
const std::vector<Expr*>& output_args,
const std::vector<Var*>& reduce_args) {
ExprHandle accum =
ExprHandle(new Load(body.dtype(), accumulator, output_args));
auto e = interaction(accum, body);
@ -116,7 +116,7 @@ class TORCH_API Reducer {
}
private:
const Expr* init_;
Expr* init_;
ReduceInteraction interaction_;
};
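As a sketch, the updated operator() can be invoked as below (the reducer, bufs, and vars are hypothetical; `sum` stands for any Reducer configured elsewhere, and the Load constructor form is taken from complete() above):

  std::vector<Expr*> output_args = {i};  // indices into result_buf
  std::vector<Var*> reduce_args = {j};   // the reduction axis
  Expr* body = new Load(kFloat, a_buf, {i, j});
  ReduceOp* op = sum(result_buf, body, output_args, reduce_args);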
@ -128,17 +128,14 @@ class TORCH_API Reducer {
class TORCH_API ReduceOp : public ExprNode<ReduceOp> {
public:
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
ReduceOp(
const Expr* body,
std::vector<const Var*> reduce_args,
const Reducer& reducer)
ReduceOp(Expr* body, std::vector<Var*> reduce_args, const Reducer& reducer)
: ExprNodeBase(body->dtype()),
body_(body),
reduce_args_(std::move(reduce_args)),
reducer_(reducer) {}
// return the body expression which obtains the value to be reduced.
const Expr* body() const {
Expr* body() const {
return body_;
}
@ -148,13 +145,13 @@ class TORCH_API ReduceOp : public ExprNode<ReduceOp> {
}
// returns variables associated with the axes of reduction.
const std::vector<const Var*>& reduce_args() const {
const std::vector<Var*>& reduce_args() const {
return reduce_args_;
}
private:
const Expr* body_;
std::vector<const Var*> reduce_args_;
Expr* body_;
std::vector<Var*> reduce_args_;
const Reducer reducer_;
};
@ -223,7 +220,7 @@ class ReductionExpander : public IRMutator {
return s->accept_mutator(this);
}
const Expr* mutate(const ReduceOp* v) override {
Expr* mutate(ReduceOp* v) override {
return v->body();
}
};

View File

@ -7,9 +7,7 @@ namespace registerizer {
// AccessInfo
void AccessInfo::addStore(
const Store* store,
const std::shared_ptr<Scope>& scope) {
void AccessInfo::addStore(Store* store, const std::shared_ptr<Scope>& scope) {
block_ =
block_ ? Block::getSharedParent(block_, scope->block()) : scope->block();
@ -27,9 +25,9 @@ void AccessInfo::addStore(
}
void AccessInfo::addLoad(
const Load* load,
Load* load,
const std::shared_ptr<Scope>& scope,
const Stmt* usage) {
Stmt* usage) {
block_ =
block_ ? Block::getSharedParent(block_, scope->block()) : scope->block();
first_usage_ = first_usage_ ? block_->getEnclosedRoot(first_usage_) : usage;
@ -69,13 +67,13 @@ bool AccessInfo::overlaps(const std::shared_ptr<AccessInfo>& other) {
// All accesses to a buf must have the same dimensionality.
TORCH_INTERNAL_ASSERT(indices_.size() == other->indices().size());
const auto& other_indices = other->indices();
auto& other_indices = other->indices();
// They don't overlap if there is a guaranteed difference in any
// dimension.
bool overlap = true;
for (size_t i = 0; i < indices_.size(); ++i) {
const Expr* diff = new Sub(indices_[i], other_indices[i]);
Expr* diff = new Sub(indices_[i], other_indices[i]);
diff = IRSimplifier::simplify(diff);
if (diff->isConstant() && !immediateEquals(diff, 0)) {
@ -87,7 +85,7 @@ bool AccessInfo::overlaps(const std::shared_ptr<AccessInfo>& other) {
return overlap;
}
bool AccessInfo::dependsOnVar(const Var* v) {
bool AccessInfo::dependsOnVar(Var* v) {
VarFinder vf;
for (auto* i : indices_) {
i->accept(&vf);
@ -139,7 +137,7 @@ void Scope::closeAccess(const std::shared_ptr<AccessInfo>& info) {
closedAccesses_.push_back(info);
}
AccessHashMap& Scope::getAccessMapByBuf(const Buf* b) {
AccessHashMap& Scope::getAccessMapByBuf(Buf* b) {
auto it = openAccesses_.find(b);
if (it == openAccesses_.end()) {
// create and return
@ -179,7 +177,7 @@ void RegisterizerAnalysis::closeAccessIntoScope(
scope->closeAccess(info);
}
void RegisterizerAnalysis::visit(const For* v) {
void RegisterizerAnalysis::visit(For* v) {
if (v->loop_options().is_gpu_block_index() ||
v->loop_options().is_gpu_thread_index()) {
throw malformed_input(
@ -195,8 +193,7 @@ void RegisterizerAnalysis::visit(const For* v) {
v->body()->accept(this);
stmtStack_.pop_front();
const Expr* loopExtent =
IRSimplifier::simplify(new Sub(v->stop(), v->start()));
Expr* loopExtent = IRSimplifier::simplify(new Sub(v->stop(), v->start()));
// now we need to see which accesses we can hoist out of the for loop, their
// costs should be multiplied by the loop extent.
@ -263,8 +260,8 @@ void RegisterizerAnalysis::visit(const For* v) {
mergeCurrentScopeIntoParent();
};
void RegisterizerAnalysis::visit(const Cond* v) {
const Expr* condition = v->condition();
void RegisterizerAnalysis::visit(Cond* v) {
Expr* condition = v->condition();
Block* true_stmt = v->true_stmt();
Block* false_stmt = v->false_stmt();
@ -303,10 +300,10 @@ void RegisterizerAnalysis::visit(const Cond* v) {
// IfThenElses are just like Conds except they are not Stmts, which means no
// registerization can occur internally. However, the first reference to an
// access can occur within one if it's visible outside the condition.
void RegisterizerAnalysis::visit(const IfThenElse* v) {
const Expr* condition = v->condition();
const Expr* true_value = v->true_value();
const Expr* false_value = v->false_value();
void RegisterizerAnalysis::visit(IfThenElse* v) {
Expr* condition = v->condition();
Expr* true_value = v->true_value();
Expr* false_value = v->false_value();
// condition is in enclosing scope.
condition->accept(this);
@ -338,7 +335,7 @@ void RegisterizerAnalysis::visit(const IfThenElse* v) {
}
}
void RegisterizerAnalysis::visit(const Let* v) {
void RegisterizerAnalysis::visit(Let* v) {
currentScope_->addLocalVar(v->var());
stmtStack_.push_front(v);
@ -346,7 +343,7 @@ void RegisterizerAnalysis::visit(const Let* v) {
stmtStack_.pop_front();
}
void RegisterizerAnalysis::visit(const Block* v) {
void RegisterizerAnalysis::visit(Block* v) {
auto prev_scope = currentScope_;
if (currentScope_->block() != v) {
currentScope_ = std::make_shared<Scope>(v, prev_scope);
@ -374,7 +371,7 @@ void RegisterizerAnalysis::visit(const Block* v) {
}
}
void RegisterizerAnalysis::visit(const Store* v) {
void RegisterizerAnalysis::visit(Store* v) {
stmtStack_.push_front(v);
v->value()->accept(this);
stmtStack_.pop_front();
@ -428,7 +425,7 @@ void RegisterizerAnalysis::visit(const Store* v) {
}
}
void RegisterizerAnalysis::visit(const Load* v) {
void RegisterizerAnalysis::visit(Load* v) {
if (v->indices().empty()) {
// already a scalar.
return;
@ -563,7 +560,7 @@ void RegisterizerAnalysis::mergeCurrentScopeIntoParent() {
// copy across current open accesses, merging as necessary.
// for each Buf with an open access:
for (auto& pair : currentScope_->openAccesses()) {
const Buf* buf = pair.first;
Buf* buf = pair.first;
if (pair.second.empty()) {
continue;
}
@ -640,7 +637,7 @@ std::vector<std::shared_ptr<AccessInfo>> RegisterizerAnalysis::getCandidates() {
// RegisterizerReplacer
const Expr* RegisterizerReplacer::mutate(const Load* v) {
Expr* RegisterizerReplacer::mutate(Load* v) {
auto it = loadToAccess_.find(v);
if (it == loadToAccess_.end()) {
// This access cannot be registerized.
@ -652,7 +649,7 @@ const Expr* RegisterizerReplacer::mutate(const Load* v) {
return info->replacement().var;
}
Stmt* RegisterizerReplacer::mutate(const Store* v) {
Stmt* RegisterizerReplacer::mutate(Store* v) {
if (eliminatedIntializers_.count(v) != 0) {
// This store is the initializer for a scalar var that is already inserted.
return nullptr;
@ -666,12 +663,12 @@ Stmt* RegisterizerReplacer::mutate(const Store* v) {
auto& info = it->second;
const Expr* new_val = v->value()->accept_mutator(this);
Expr* new_val = v->value()->accept_mutator(this);
return new Store(info->replacement().var_wrapper, {}, new_val);
}
Stmt* RegisterizerReplacer::mutate(const Block* v) {
Stmt* RegisterizerReplacer::mutate(Block* v) {
auto& scope = parentToAccesses_[v];
std::vector<Stmt*> stmts;

View File

@ -23,7 +23,7 @@ For example it can replace:
{
A[0] = 0;
for(const auto x : c10::irange(10)) {
for(auto x : c10::irange(10)) {
A[0] = (A[0]) + x;
}
}
@ -32,7 +32,7 @@ with:
{
int A_ = 0;
for(const auto x : c10::irange(10)) {
for(auto x : c10::irange(10)) {
A_ = x + A_;
}
A[0] = A_;
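To exercise the pass end to end, a rough driver sketch (hypothetical: the RegisterizerReplacer constructor is assumed to take the candidate set, per its infoSet_ reference member; it is not shown in this diff):

  RegisterizerAnalysis analysis;
  stmt->accept(&analysis);
  auto candidates = analysis.getCandidates();
  RegisterizerReplacer replacer(candidates); // assumed ctor
  Stmt* registerized = stmt->accept_mutator(&replacer);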
@ -54,8 +54,8 @@ class AccessInfo {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
AccessInfo(
SimplifierHashType h,
const Buf* b,
std::vector<const Expr*> i,
Buf* b,
std::vector<Expr*> i,
size_t accessOrder)
: hash_(h),
buf_(b),
@ -65,14 +65,11 @@ class AccessInfo {
accessOrder_(accessOrder) {}
// Adds a Store to this access, which is in the provided scope.
void addStore(const Store* store, const std::shared_ptr<Scope>& scope);
void addStore(Store* store, const std::shared_ptr<Scope>& scope);
// Adds a Load to this access, which occurs in the usage Stmt in the provided
// scope.
void addLoad(
const Load* load,
const std::shared_ptr<Scope>& scope,
const Stmt* usage);
void addLoad(Load* load, const std::shared_ptr<Scope>& scope, Stmt* usage);
// Merge another AccessInfo into this one.
void merge(const std::shared_ptr<AccessInfo>& other);
@ -81,7 +78,7 @@ class AccessInfo {
bool overlaps(const std::shared_ptr<AccessInfo>& other);
// Returns true if the indices of this access depend on the provided Var.
bool dependsOnVar(const Var* v);
bool dependsOnVar(Var* v);
// Clone this AccessInfo, and set this as the new access's hiddenAccess.
static std::shared_ptr<AccessInfo> cloneWithHiddenInfo(
@ -94,30 +91,30 @@ class AccessInfo {
return hash_;
}
const Buf* buf() const {
Buf* buf() const {
return buf_;
}
const std::vector<const Expr*>& indices() const {
const std::vector<Expr*>& indices() const {
return indices_;
}
const Block* block() const {
Block* block() const {
return block_;
}
void setEnclosingBlock(const Block* b) {
void setEnclosingBlock(Block* b) {
block_ = b;
}
const Stmt* first_usage() const {
Stmt* first_usage() const {
return first_usage_;
}
const Stmt* last_usage() const {
Stmt* last_usage() const {
return last_usage_;
}
void setUsageMarks(const Stmt* first, const Stmt* last) {
void setUsageMarks(Stmt* first, Stmt* last) {
first_usage_ = first;
last_usage_ = last;
}
@ -126,23 +123,23 @@ class AccessInfo {
return firstUsageOverlapped_;
}
const Expr* store_cost() const {
Expr* store_cost() const {
return store_cost_;
}
const Expr* load_cost() const {
Expr* load_cost() const {
return load_cost_;
}
const std::vector<const Store*>& stores() const {
const std::vector<Store*>& stores() const {
return stores_;
}
const std::vector<const Load*>& loads() const {
const std::vector<Load*>& loads() const {
return loads_;
}
void hoistCosts(const Expr* extent) {
void hoistCosts(Expr* extent) {
store_cost_ = IRSimplifier::simplify(new Mul(store_cost_, extent));
load_cost_ = IRSimplifier::simplify(new Mul(load_cost_, extent));
}
@ -177,12 +174,12 @@ class AccessInfo {
private:
SimplifierHashType hash_;
const Buf* buf_;
std::vector<const Expr*> indices_;
const Block* block_{nullptr};
Buf* buf_;
std::vector<Expr*> indices_;
Block* block_{nullptr};
const Stmt* first_usage_{nullptr};
const Stmt* last_usage_{nullptr};
Stmt* first_usage_{nullptr};
Stmt* last_usage_{nullptr};
// Whether or not this access is overlapped in the first Stmt in which it
// appears. This means we cannot use its first Store as the initializer.
@ -190,13 +187,13 @@ class AccessInfo {
// The cost in real ops that this access represents, to enable
// filtering accesses that won't save any loads or stores.
const Expr* store_cost_;
const Expr* load_cost_;
Expr* store_cost_;
Expr* load_cost_;
// The actual Stores and Loads which represent this access.
// Be careful with these, any mutator will invalidate these pointers.
std::vector<const Store*> stores_;
std::vector<const Load*> loads_;
std::vector<Store*> stores_;
std::vector<Load*> loads_;
// An identifier representing the conditional block, if any, this access
// depends on.
@ -222,12 +219,12 @@ using AccessHashMap =
class Scope {
public:
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
Scope(const Block* b, std::shared_ptr<Scope> parent, size_t conditionId = 0)
Scope(Block* b, std::shared_ptr<Scope> parent, size_t conditionId = 0)
: block_(b), parent_(std::move(parent)), conditionId_(conditionId) {}
AccessHashMap& getAccessMapByBuf(const Buf* b);
AccessHashMap& getAccessMapByBuf(Buf* b);
std::unordered_map<const Buf*, AccessHashMap>& openAccesses() {
std::unordered_map<Buf*, AccessHashMap>& openAccesses() {
return openAccesses_;
}
@ -235,7 +232,7 @@ class Scope {
return closedAccesses_;
}
const Block* block() const {
Block* block() const {
return block_;
}
@ -247,10 +244,10 @@ class Scope {
return conditionId_;
}
const std::unordered_set<const Var*>& localVars() const {
const std::unordered_set<Var*>& localVars() const {
return localVars_;
}
void addLocalVar(const Var* v) {
void addLocalVar(Var* v) {
localVars_.insert(v);
}
@ -264,11 +261,11 @@ class Scope {
// overlap with other accesses to the same buf. Buf ->
// Hash ->
// Access
std::unordered_map<const Buf*, AccessHashMap> openAccesses_;
std::unordered_map<Buf*, AccessHashMap> openAccesses_;
std::vector<std::shared_ptr<AccessInfo>> closedAccesses_;
// The Block object this scope represents.
const Block* block_;
Block* block_;
// The enclosing scope object.
std::shared_ptr<Scope> parent_;
@ -277,7 +274,7 @@ class Scope {
size_t conditionId_;
// A set of variables local to this scope (e.g. loop vars).
std::unordered_set<const Var*> localVars_;
std::unordered_set<Var*> localVars_;
};
/* Analyzes the graph and collects accesses to the same symbolic tensor element
@ -323,25 +320,25 @@ class TORCH_API RegisterizerAnalysis : public IRVisitor {
: currentScope_(std::make_shared<Scope>(nullptr, nullptr, 0)) {}
~RegisterizerAnalysis() override = default;
void visit(const For* v) override;
void visit(For* v) override;
void visit(const Cond* v) override;
void visit(Cond* v) override;
void visit(const Block* v) override;
void visit(Block* v) override;
void visit(const Store* v) override;
void visit(Store* v) override;
void visit(const Load* v) override;
void visit(Load* v) override;
void visit(const IfThenElse* v) override;
void visit(IfThenElse* v) override;
void visit(const Let* v) override;
void visit(Let* v) override;
#define STMT_ON_STACK(Op) \
void visit(const Op* v) override { \
stmtStack_.push_front(v); \
IRVisitor::visit(v); \
stmtStack_.pop_front(); \
#define STMT_ON_STACK(Op) \
void visit(Op* v) override { \
stmtStack_.push_front(v); \
IRVisitor::visit(v); \
stmtStack_.pop_front(); \
}
STMT_ON_STACK(AtomicAdd);
@ -362,7 +359,7 @@ class TORCH_API RegisterizerAnalysis : public IRVisitor {
std::unordered_set<size_t> exprConditionals_;
// A stack of enclosing Stmts for tracking the usage Stmt of Loads.
std::deque<const Stmt*> stmtStack_;
std::deque<Stmt*> stmtStack_;
// The current scope being analyzed.
std::shared_ptr<Scope> currentScope_;
@ -384,17 +381,17 @@ class TORCH_API RegisterizerReplacer : public IRMutator {
buildReplacements();
}
const Expr* mutate(const Load* v) override;
Expr* mutate(Load* v) override;
Stmt* mutate(const Store* v) override;
Stmt* mutate(Store* v) override;
Stmt* mutate(const Block* v) override;
Stmt* mutate(Block* v) override;
private:
struct ReplacerScope {
std::unordered_map<const Stmt*, std::deque<std::shared_ptr<AccessInfo>>>
std::unordered_map<Stmt*, std::deque<std::shared_ptr<AccessInfo>>>
initializerPoints_;
std::unordered_map<const Stmt*, std::deque<std::shared_ptr<AccessInfo>>>
std::unordered_map<Stmt*, std::deque<std::shared_ptr<AccessInfo>>>
finalizePoints_;
};
@ -403,18 +400,18 @@ class TORCH_API RegisterizerReplacer : public IRMutator {
// State relating to the accesses yet to be replaced.
std::vector<std::shared_ptr<AccessInfo>>& infoSet_;
std::unordered_map<const Store*, std::shared_ptr<AccessInfo>> storeToAccess_;
std::unordered_map<const Load*, std::shared_ptr<AccessInfo>> loadToAccess_;
std::unordered_map<const Block*, ReplacerScope> parentToAccesses_;
std::unordered_map<Store*, std::shared_ptr<AccessInfo>> storeToAccess_;
std::unordered_map<Load*, std::shared_ptr<AccessInfo>> loadToAccess_;
std::unordered_map<Block*, ReplacerScope> parentToAccesses_;
// Holds the set of Stores that should be pulled into an initializer, so they
// can be eliminated.
std::set<const Store*> eliminatedIntializers_;
std::set<Store*> eliminatedIntializers_;
// Tracks the number of times we've seen each buffer, so we can name the
// scalar Vars appropriately.
std::unordered_map<const Buf*, unsigned int> bufferAccessCounts_;
unsigned int getBufferAccessCount(const Buf* b) {
std::unordered_map<Buf*, unsigned int> bufferAccessCounts_;
unsigned int getBufferAccessCount(Buf* b) {
return ++bufferAccessCounts_[b];
}
};

View File

@ -17,7 +17,7 @@ class Placeholder;
class TORCH_API Stmt : public KernelScopedObject {
public:
Stmt() = default;
virtual void accept(IRVisitor* visitor) const = 0;
virtual void accept(IRVisitor* visitor) = 0;
virtual Stmt* accept_mutator(IRMutator* mutator) = 0;
Stmt* get_parent() const {
@ -46,8 +46,8 @@ template <class Op>
class StmtNode : public Stmt {
public:
using StmtNodeBase = StmtNode<Op>;
void accept(IRVisitor* visitor) const override {
visitor->visit(static_cast<const Op*>(this));
void accept(IRVisitor* visitor) override {
visitor->visit(static_cast<Op*>(this));
}
Stmt* accept_mutator(IRMutator* mutator) override;
StmtNode() = default;
@ -102,7 +102,7 @@ class TORCH_API Block : public StmtNode<Block> {
set_parent(s, this);
}
void insert_stmt_before(Stmt* s, const Stmt* before) {
void insert_stmt_before(Stmt* s, Stmt* before) {
if (s->get_parent()) {
throw malformed_input("Block append Stmt with existing parent", s);
}
@ -117,7 +117,7 @@ class TORCH_API Block : public StmtNode<Block> {
set_parent(s, this);
}
void insert_stmt_after(Stmt* s, const Stmt* after) {
void insert_stmt_after(Stmt* s, Stmt* after) {
if (s->get_parent()) {
throw malformed_input("Block append Stmt with existing parent", s);
}
@ -240,7 +240,7 @@ class TORCH_API Block : public StmtNode<Block> {
return stmts_.front();
}
const Stmt* front() const {
Stmt* front() const {
return stmts_.front();
}
@ -248,7 +248,7 @@ class TORCH_API Block : public StmtNode<Block> {
return stmts_.back();
}
const Stmt* back() const {
Stmt* back() const {
return stmts_.back();
}
@ -260,13 +260,13 @@ class TORCH_API Block : public StmtNode<Block> {
stmts_.splice(it, other->stmts_);
}
static const Block* getSharedParent(const Stmt* p1, const Stmt* p2) {
static Block* getSharedParent(Stmt* p1, Stmt* p2) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::unordered_set<const Block*> enclosing;
std::unordered_set<Block*> enclosing;
const Stmt* p1_p = p1;
Stmt* p1_p = p1;
while (p1_p) {
if (const Block* b = dynamic_cast<const Block*>(p1_p)) {
if (Block* b = dynamic_cast<Block*>(p1_p)) {
if (b) {
enclosing.insert(b);
}
@ -274,9 +274,9 @@ class TORCH_API Block : public StmtNode<Block> {
p1_p = p1_p->get_parent();
}
const Stmt* p2_p = p2;
Stmt* p2_p = p2;
while (p2_p) {
if (const Block* b = dynamic_cast<const Block*>(p2_p)) {
if (Block* b = dynamic_cast<Block*>(p2_p)) {
if (enclosing.count(b) != 0) {
return b;
}
@ -288,7 +288,7 @@ class TORCH_API Block : public StmtNode<Block> {
}
// returns the immediate child containing statement s.
const Stmt* getEnclosedRoot(const Stmt* s) const {
Stmt* getEnclosedRoot(Stmt* s) const {
while (s && s->get_parent() != this) {
s = s->get_parent();
}
@ -301,20 +301,20 @@ class TORCH_API Block : public StmtNode<Block> {
class TORCH_API Store : public StmtNode<Store> {
public:
const Var* base_handle() const {
Var* base_handle() const {
return buf_->base_handle();
}
std::vector<const Expr*> indices() const {
std::vector<Expr*> indices() const {
return indices_;
}
const Expr* flat_index() const {
Expr* flat_index() const {
TORCH_CHECK(indices_.size() == 1, "Indices haven't been flattened.");
return indices_[0];
}
const Expr* value() const {
Expr* value() const {
return value_;
}
const Buf* buf() const {
Buf* buf() const {
return buf_;
}
@ -323,16 +323,16 @@ class TORCH_API Store : public StmtNode<Store> {
const std::vector<ExprHandle>& indices,
const ExprHandle& value);
Store(const Buf* buf, std::vector<const Expr*> indices, const Expr* value);
Store(Buf* buf, std::vector<Expr*> indices, Expr* value);
void set_indices(std::vector<const Expr*> indices) {
void set_indices(std::vector<Expr*> indices) {
indices_ = indices;
};
private:
const Buf* buf_;
std::vector<const Expr*> indices_;
const Expr* value_;
Buf* buf_;
std::vector<Expr*> indices_;
Expr* value_;
};
// Allocate a buffer of given shapes and dtypes and bind it with the given
@ -344,7 +344,7 @@ class TORCH_API Allocate : public StmtNode<Allocate> {
return new Allocate(buf_handle.node());
}
const Var* buffer_var() const {
Var* buffer_var() const {
return buf_->base_handle();
}
@ -352,18 +352,18 @@ class TORCH_API Allocate : public StmtNode<Allocate> {
return buf_->dtype();
}
const std::vector<const Expr*> dims() const {
const std::vector<Expr*> dims() const {
return buf_->dims();
}
const Buf* buf() const {
Buf* buf() const {
return buf_;
}
explicit Allocate(const Buf* buf) : buf_(buf) {}
explicit Allocate(Buf* buf) : buf_(buf) {}
private:
const Buf* buf_;
Buf* buf_;
// TODO: add memory types.
};
@ -374,18 +374,18 @@ class TORCH_API Free : public StmtNode<Free> {
return new Free(buf_handle.node());
}
const Var* buffer_var() const {
Var* buffer_var() const {
return buf_->base_handle();
}
const Buf* buf() const {
Buf* buf() const {
return buf_;
}
explicit Free(const Buf* buf) : buf_(buf) {}
explicit Free(Buf* buf) : buf_(buf) {}
private:
const Buf* buf_;
Buf* buf_;
};
class TORCH_API Let : public StmtNode<Let> {
@ -394,25 +394,24 @@ class TORCH_API Let : public StmtNode<Let> {
return new Let(var.node(), val.node());
}
Let(const Var* var, const Expr* val)
: dtype_(var->dtype()), var_(var), val_(val) {}
Let(Var* var, Expr* val) : dtype_(var->dtype()), var_(var), val_(val) {}
Dtype dtype() const {
return dtype_;
}
const Var* var() const {
Var* var() const {
return var_;
}
const Expr* value() const {
Expr* value() const {
return val_;
}
private:
Dtype dtype_;
const Var* var_;
const Expr* val_;
Var* var_;
Expr* val_;
};
class TORCH_API Cond : public StmtNode<Cond> {
@ -424,7 +423,7 @@ class TORCH_API Cond : public StmtNode<Cond> {
return new Cond(condition.node(), true_stmt, false_stmt);
}
const Expr* condition() const {
Expr* condition() const {
return condition_;
}
@ -436,7 +435,7 @@ class TORCH_API Cond : public StmtNode<Cond> {
return false_stmt_;
}
Cond(const Expr* condition, Stmt* true_stmt, Stmt* false_stmt)
Cond(Expr* condition, Stmt* true_stmt, Stmt* false_stmt)
: condition_(condition) {
if (true_stmt) {
Block* b = dynamic_cast<Block*>(true_stmt);
@ -465,7 +464,7 @@ class TORCH_API Cond : public StmtNode<Cond> {
}
private:
const Expr* condition_;
Expr* condition_;
Block* true_stmt_ = nullptr;
Block* false_stmt_ = nullptr;
};
@ -586,12 +585,11 @@ class TORCH_API LoopOptions {
!is_parallel_;
}
void set_buffer_mapping(
const std::unordered_map<std::string, const Buf*>& map) {
void set_buffer_mapping(const std::unordered_map<std::string, Buf*>& map) {
map_input_to_tensor_bufs_ = map;
}
std::unordered_map<std::string, const Buf*> get_buffer_mapping() const {
std::unordered_map<std::string, Buf*> get_buffer_mapping() const {
return map_input_to_tensor_bufs_;
}
@ -599,18 +597,18 @@ class TORCH_API LoopOptions {
int gpu_block_index_{IDX_UNSET};
int gpu_thread_index_{IDX_UNSET};
bool is_parallel_{false};
std::unordered_map<std::string, const Buf*> map_input_to_tensor_bufs_;
std::unordered_map<std::string, Buf*> map_input_to_tensor_bufs_;
};
class TORCH_API For : public StmtNode<For> {
public:
const Var* var() const {
Var* var() const {
return var_;
}
const Expr* start() const {
Expr* start() const {
return start_;
}
const Expr* stop() const {
Expr* stop() const {
return stop_;
}
Block* body() const {
@ -641,7 +639,7 @@ class TORCH_API For : public StmtNode<For> {
return loop_options_;
}
For(const Var* var, const Expr* start, const Expr* stop, Stmt* body)
For(Var* var, Expr* start, Expr* stop, Stmt* body)
: var_(var), start_(start), stop_(stop) {
Block* b = dynamic_cast<Block*>(body);
if (!b) {
@ -651,11 +649,7 @@ class TORCH_API For : public StmtNode<For> {
set_parent(body_, this);
}
For(const Var* var,
const Expr* start,
const Expr* stop,
Stmt* body,
LoopOptions loop_options)
For(Var* var, Expr* start, Expr* stop, Stmt* body, LoopOptions loop_options)
: var_(var),
start_(start),
stop_(stop),
@ -694,7 +688,7 @@ class TORCH_API For : public StmtNode<For> {
return loop_options_.is_parallel();
}
void set_buffer_map(const std::unordered_map<std::string, const Buf*>& map) {
void set_buffer_map(const std::unordered_map<std::string, Buf*>& map) {
loop_options_.set_buffer_mapping(map);
}
@ -719,25 +713,25 @@ class TORCH_API For : public StmtNode<For> {
return body_;
}
const Expr* setStart(const Expr* start) {
Expr* setStart(Expr* start) {
start_ = start;
return start_;
}
const Expr* setStop(const Expr* stop) {
Expr* setStop(Expr* stop) {
stop_ = stop;
return stop_;
}
const Var* setVar(const Var* var) {
Var* setVar(Var* var) {
var_ = var;
return var_;
}
private:
const Var* var_;
const Expr* start_;
const Expr* stop_;
Var* var_;
Expr* start_;
Expr* stop_;
Block* body_;
LoopOptions loop_options_;
};
@ -749,34 +743,34 @@ class TORCH_API For : public StmtNode<For> {
class TORCH_API AtomicAdd : public StmtNode<AtomicAdd> {
public:
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
AtomicAdd(const Buf* buf, std::vector<const Expr*> indices, const Expr* value)
AtomicAdd(Buf* buf, std::vector<Expr*> indices, Expr* value)
: buf_(buf), indices_(std::move(indices)), value_(value) {}
const Var* base_handle() const {
Var* base_handle() const {
return buf_->base_handle();
}
const Buf* buf() const {
Buf* buf() const {
return buf_;
}
const Expr* flat_index() const {
Expr* flat_index() const {
TORCH_CHECK(indices_.size() == 1, "Indices haven't been flattened.");
return indices_[0];
}
const Expr* value() const {
Expr* value() const {
return value_;
}
const std::vector<const Expr*>& indices() const {
const std::vector<Expr*>& indices() const {
return indices_;
}
private:
const Buf* buf_;
std::vector<const Expr*> indices_;
const Expr* value_;
Buf* buf_;
std::vector<Expr*> indices_;
Expr* value_;
};
class TORCH_API SyncThreads : public StmtNode<SyncThreads> {
@ -811,7 +805,7 @@ class TORCH_API ExternalCall : public StmtNode<ExternalCall> {
const std::vector<BufHandle>& buf_args,
const std::vector<ExprHandle>& args);
const Buf* buf() const {
Buf* buf() const {
return buf_;
}
@ -819,30 +813,30 @@ class TORCH_API ExternalCall : public StmtNode<ExternalCall> {
return func_name_;
}
std::vector<const Buf*> buf_args() const {
std::vector<Buf*> buf_args() const {
return buf_args_;
}
std::vector<const Expr*> args() const {
std::vector<Expr*> args() const {
return args_;
}
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
ExternalCall(
const Buf* buf,
Buf* buf,
std::string func_name,
std::vector<const Buf*> buf_args,
std::vector<const Expr*> args)
std::vector<Buf*> buf_args,
std::vector<Expr*> args)
: buf_(buf),
func_name_(std::move(func_name)),
buf_args_(std::move(buf_args)),
args_(std::move(args)) {}
private:
const Buf* buf_;
Buf* buf_;
std::string func_name_;
std::vector<const Buf*> buf_args_;
std::vector<const Expr*> args_;
std::vector<Buf*> buf_args_;
std::vector<Expr*> args_;
};
} // namespace tensorexpr

View File

@ -10,11 +10,11 @@ namespace jit {
namespace tensorexpr {
Stmt* Tensor::constructStmt(
const std::vector<const Var*>& args,
const Expr* body,
const std::vector<const Expr*>& reduce_dims,
const std::vector<const Var*>& reduce_args) const {
std::vector<const Expr*> indices(args.begin(), args.end());
const std::vector<Var*>& args,
Expr* body,
const std::vector<Expr*>& reduce_dims,
const std::vector<Var*>& reduce_args) const {
std::vector<Expr*> indices(args.begin(), args.end());
Stmt* s = new Store(buf_, indices, body);
@ -25,10 +25,10 @@ Stmt* Tensor::constructStmt(
return s;
}
const Expr* init_expr = buf()->initializer();
Expr* init_expr = buf()->initializer();
if (reduce_ndim > 0) {
for (const auto i : c10::irange(reduce_ndim)) {
for (auto i : c10::irange(reduce_ndim)) {
// Going in reverse order: from innermost loop to the outermost
size_t dim_index = reduce_ndim - i - 1;
s = new For(
@ -40,7 +40,7 @@ Stmt* Tensor::constructStmt(
}
}
for (const auto i : c10::irange(ndim)) {
for (auto i : c10::irange(ndim)) {
// Going in reverse order: from innermost loop to the outermost
size_t dim_index = ndim - i - 1;
s = new For(args[dim_index], new IntImm(0), buf()->dim(dim_index), s);
@ -52,11 +52,11 @@ Tensor* Compute(
const std::string& name,
const std::vector<DimArg>& dim_args,
const std::function<ExprHandle(const std::vector<VarHandle>&)>& body_func) {
std::vector<const Expr*> dims;
std::vector<const Var*> args;
std::vector<Expr*> dims;
std::vector<Var*> args;
unpack_dim_args(dim_args, &dims, &args);
const Expr* body = body_func(VarVectorToVarHandleVector(args)).node();
const Buf* buf = new Buf(name, dims, body->dtype());
Expr* body = body_func(VarVectorToVarHandleVector(args)).node();
Buf* buf = new Buf(name, dims, body->dtype());
return new Tensor(buf, args, body);
}
@ -68,11 +68,11 @@ Tensor* Compute(
throw malformed_input("mismatch between body and arg size (1)");
}
std::vector<const Expr*> dims;
std::vector<const Var*> args;
std::vector<Expr*> dims;
std::vector<Var*> args;
unpack_dim_args(dim_args, &dims, &args);
const Expr* body = body_func(VarHandle(args[0])).node();
const Buf* buf = new Buf(name, dims, body->dtype());
Expr* body = body_func(VarHandle(args[0])).node();
Buf* buf = new Buf(name, dims, body->dtype());
return new Tensor(buf, args, body);
}
@ -84,11 +84,11 @@ Tensor* Compute(
if (dim_args.size() != 2) {
throw malformed_input("mismatch between body and arg size (2)");
}
std::vector<const Expr*> dims;
std::vector<const Var*> args;
std::vector<Expr*> dims;
std::vector<Var*> args;
unpack_dim_args(dim_args, &dims, &args);
const Expr* body = body_func(VarHandle(args[0]), VarHandle(args[1])).node();
const Buf* buf = new Buf(name, dims, body->dtype());
Expr* body = body_func(VarHandle(args[0]), VarHandle(args[1])).node();
Buf* buf = new Buf(name, dims, body->dtype());
return new Tensor(buf, args, body);
}
@ -101,13 +101,13 @@ Tensor* Compute(
if (dim_args.size() != 3) {
throw malformed_input("mismatch between body and arg size (3)");
}
std::vector<const Expr*> dims;
std::vector<const Var*> args;
std::vector<Expr*> dims;
std::vector<Var*> args;
unpack_dim_args(dim_args, &dims, &args);
const Expr* body =
Expr* body =
body_func(VarHandle(args[0]), VarHandle(args[1]), VarHandle(args[2]))
.node();
const Buf* buf = new Buf(name, dims, body->dtype());
Buf* buf = new Buf(name, dims, body->dtype());
return new Tensor(buf, args, body);
}
@ -122,16 +122,16 @@ Tensor* Compute(
if (dim_args.size() != 4) {
throw malformed_input("mismatch between body and arg size (4)");
}
std::vector<const Expr*> dims;
std::vector<const Var*> args;
std::vector<Expr*> dims;
std::vector<Var*> args;
unpack_dim_args(dim_args, &dims, &args);
const Expr* body = body_func(
VarHandle(args[0]),
VarHandle(args[1]),
VarHandle(args[2]),
VarHandle(args[3]))
.node();
const Buf* buf = new Buf(name, dims, body->dtype());
Expr* body = body_func(
VarHandle(args[0]),
VarHandle(args[1]),
VarHandle(args[2]),
VarHandle(args[3]))
.node();
Buf* buf = new Buf(name, dims, body->dtype());
return new Tensor(buf, args, body);
}
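
For context, a small sketch (not part of the diff) of the two-argument Compute overload above; the names are illustrative, and the point is that the resulting Tensor now hands back a mutable Buf*:

  KernelScope kernel_scope;
  Tensor* t = Compute(
      "add_ij",
      {{64, "i"}, {32, "j"}},
      [](const VarHandle& i, const VarHandle& j) { return i + j; });
  Buf* out_buf = t->buf(); // mutable now; was const Buf* before this change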

View File

@ -15,25 +15,24 @@ namespace tensorexpr {
class TORCH_API Tensor : KernelScopedObject {
public:
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
Tensor(const Buf* buf, const std::vector<const Var*>& args, const Expr* body)
: buf_(buf) {
Tensor(Buf* buf, const std::vector<Var*>& args, Expr* body) : buf_(buf) {
stmt_ = constructStmt(args, body, {}, {});
}
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
Tensor(
const Buf* buf,
const std::vector<const Var*>& args,
const std::vector<const Expr*>& reduce_dims,
const std::vector<const Var*>& reduce_args,
const Expr* body)
Buf* buf,
const std::vector<Var*>& args,
const std::vector<Expr*>& reduce_dims,
const std::vector<Var*>& reduce_args,
Expr* body)
: buf_(buf) {
stmt_ = constructStmt(args, body, reduce_dims, reduce_args);
}
Tensor(const Buf* buf, Stmt* stmt) : buf_(buf), stmt_(stmt) {}
Tensor(Buf* buf, Stmt* stmt) : buf_(buf), stmt_(stmt) {}
const Buf* buf() const {
Buf* buf() const {
return buf_;
}
@ -48,12 +47,12 @@ class TORCH_API Tensor : KernelScopedObject {
private:
Stmt* constructStmt(
const std::vector<const Var*>& args,
const Expr* body,
const std::vector<const Expr*>& reduce_dims,
const std::vector<const Var*>& reduce_args) const;
const std::vector<Var*>& args,
Expr* body,
const std::vector<Expr*>& reduce_dims,
const std::vector<Var*>& reduce_args) const;
const Buf* buf_;
Buf* buf_;
Stmt* stmt_;
};
@ -96,7 +95,7 @@ class Placeholder {
explicit Placeholder(const std::vector<ExprHandle>& dims)
: Placeholder(BufHandle("_", dims, kFloat)) {}
const Buf* data() const {
Buf* data() const {
return data_;
}
BufHandle handle() const {
@ -108,10 +107,10 @@ class Placeholder {
int ndim() const {
return data_->ndim();
}
const Expr* dim(int index) const {
Expr* dim(int index) const {
return data_->dim(index);
}
std::vector<const Expr*> dims() const {
std::vector<Expr*> dims() const {
return data_->dims();
}
@ -130,8 +129,8 @@ class Placeholder {
}
private:
const Buf* data_;
std::vector<const Expr*> strides_;
Buf* data_;
std::vector<Expr*> strides_;
};
TORCH_API Tensor* Compute(
@ -164,12 +163,12 @@ TORCH_API Tensor* Compute(
inline void unpack_dim_args(
const std::vector<DimArg>& dim_args,
std::vector<const Expr*>* dims,
std::vector<const Var*>* vars) {
std::vector<Expr*>* dims,
std::vector<Var*>* vars) {
dims->clear();
vars->clear();
for (const DimArg& dim_arg : dim_args) {
const Expr* expr = dim_arg.dim().node();
Expr* expr = dim_arg.dim().node();
dims->push_back(expr);
vars->push_back(new Var(
dim_arg.name_hint(),
@ -187,22 +186,22 @@ Tensor* Reduce(
const BodyFunc& body_func,
const std::vector<DimArg>& reduce_args) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::vector<const Expr*> dims;
std::vector<Expr*> dims;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::vector<const Var*> vars;
std::vector<Var*> vars;
unpack_dim_args(dim_args, &dims, &vars);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::vector<const Expr*> reduce_dims;
std::vector<Expr*> reduce_dims;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::vector<const Var*> reduce_vars;
std::vector<Var*> reduce_vars;
unpack_dim_args(reduce_args, &reduce_dims, &reduce_vars);
// If reduce_vars is empty, then it's not a reduction, but rather a simple
// copy
if (reduce_vars.empty()) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
const Expr* body =
Expr* body =
Reducer::getReduceBody(body_func, VarVectorToVarHandleVector(vars))
.node();
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
@ -211,22 +210,21 @@ Tensor* Reduce(
}
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::vector<const Var*> all_vars;
std::vector<Var*> all_vars;
all_vars.insert(all_vars.end(), vars.begin(), vars.end());
all_vars.insert(all_vars.end(), reduce_vars.begin(), reduce_vars.end());
ExprHandle body =
Reducer::getReduceBody(body_func, VarVectorToVarHandleVector(all_vars));
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::vector<const Expr*> output_args(vars.begin(), vars.end());
std::vector<Expr*> output_args(vars.begin(), vars.end());
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
const Expr* init_expr = new Cast(
Expr* init_expr = new Cast(
body.dtype(), init_func(VarVectorToVarHandleVector(vars)).node());
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
Buf* func_result = new Buf(func_name, dims, body.dtype(), init_expr);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
const ReduceOp* reduce_op =
reducer(func_result, body, output_args, reduce_vars);
ReduceOp* reduce_op = reducer(func_result, body, output_args, reduce_vars);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
Tensor* t =
new Tensor(func_result, vars, reduce_dims, reduce_vars, reduce_op);
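
A sketch (not part of the diff) of the Reduce path above: a row-sum over the second axis, where the initializer and the ReduceOp body are now held internally as plain Expr* values. Placeholder and axis names are illustrative:

  KernelScope kernel_scope;
  Placeholder a("a", kFloat, {8, 4});
  Tensor* row_sum = Reduce(
      "row_sum",
      {{8, "m"}},
      Sum(),
      [&](const VarHandle& m, const VarHandle& n) { return a.load(m, n); },
      {{4, "n"}});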

View File

@ -325,7 +325,7 @@ void initTensorExprBindings(PyObject* module) {
.def(py::init([](const std::vector<Stmt*>& stmts) {
return tensorexpr::Block::make(stmts);
}))
.def("__str__", [](const Stmt& self) {
.def("__str__", [](Stmt& self) {
std::stringstream ss;
ss << self;
return ss.str();
@ -343,7 +343,7 @@ void initTensorExprBindings(PyObject* module) {
py::class_<For, Stmt, std::unique_ptr<For, py::nodelete>>(te, "For")
.def(
"index_var",
[](const For& self) { return VarHandle(self.var()); },
[](For& self) { return VarHandle(self.var()); },
py::return_value_policy::reference)
.def("body", &For::body, py::return_value_policy::reference)
.def("set_parallel", &For::set_parallel)
@ -393,8 +393,8 @@ void initTensorExprBindings(PyObject* module) {
py::class_<LoopNest>(te, "LoopNest")
.def(py::init<const std::vector<Tensor*>&>())
.def(py::init([](Stmt* s, const std::vector<BufHandle>& bufs) {
std::unordered_set<const Buf*> buf_nodes;
for (const auto& buf : bufs) {
std::unordered_set<Buf*> buf_nodes;
for (auto& buf : bufs) {
buf_nodes.insert(buf.node());
}
return std::make_unique<LoopNest>(s, buf_nodes);
@ -427,7 +427,7 @@ void initTensorExprBindings(PyObject* module) {
py::return_value_policy::reference)
.def(
"get_enclosing_loopnest",
[](const LoopNest& self, const Stmt* s) {
[](const LoopNest& self, Stmt* s) {
return self.getEnclosingLoopNest(s);
},
py::return_value_policy::reference)
@ -451,9 +451,7 @@ void initTensorExprBindings(PyObject* module) {
py::return_value_policy::reference)
.def(
"get_parent_loop",
[](const LoopNest& self, const Stmt* s) {
return self.getParentLoop(s);
},
[](const LoopNest& self, Stmt* s) { return self.getParentLoop(s); },
py::return_value_policy::reference)
.def_static(
"get_loop_stmts_in_loopnest",
@ -566,7 +564,7 @@ void initTensorExprBindings(PyObject* module) {
[](const BufHandle& producer,
const std::string& name,
Stmt* consumer) {
std::pair<const Buf*, Stmt*> ret =
std::pair<Buf*, Stmt*> ret =
LoopNest::cacheAccesses(producer.node(), name, consumer);
return std::make_pair(BufHandle(ret.first), ret.second);
},
@ -679,7 +677,7 @@ void initTensorExprBindings(PyObject* module) {
std::unordered_map<std::string, NNCLoweringFunction>
custom_lowerings_str) {
std::unordered_map<c10::Symbol, NNCLoweringFunction> custom_lowerings;
for (const auto& kv : custom_lowerings_str) {
for (auto& kv : custom_lowerings_str) {
custom_lowerings[c10::Symbol::fromQualString(kv.first)] = kv.second;
}
return std::make_unique<TensorExprKernel>(g, custom_lowerings);
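
The cacheAccesses binding above now unpacks a std::pair<Buf*, Stmt*>. A C++-side sketch (not part of the diff) of the same call, assuming `producer` is a BufHandle and `consumer` a Stmt* from an existing loop nest; the cache name is illustrative:

  // cacheAccesses returns the (mutable) cache buffer and the rewritten stmt.
  std::pair<Buf*, Stmt*> cached =
      LoopNest::cacheAccesses(producer.node(), "p_local", consumer);
  BufHandle cache(cached.first); // wrap for further use, as the binding does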

View File

@ -8,7 +8,7 @@ namespace torch {
namespace jit {
namespace tensorexpr {
const std::string& UniqueNameManager::get_unique_name(const Var* v) {
const std::string& UniqueNameManager::get_unique_name(Var* v) {
// Find if we have already encountered this variable.
auto iter = unique_name_mapping_.find(v);
if (iter != unique_name_mapping_.end()) {

View File

@ -13,7 +13,7 @@ namespace tensorexpr {
class VarHandle;
class Var;
using VarNameMap = std::unordered_map<const Var*, std::string>;
using VarNameMap = std::unordered_map<Var*, std::string>;
// A manager to get unique names from vars.
// It starts with the name hint of the var and appends "_" + $counter until it
@ -23,7 +23,7 @@ class TORCH_API UniqueNameManager {
public:
const std::string& get_unique_name(const VarHandle& v);
const std::string& get_unique_name(const Var* v);
const std::string& get_unique_name(Var* v);
private:
friend class ScopedVarName;
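
A sketch (not part of the diff) of the manager with its new mutable-pointer key type, assuming default construction; the "t_1" suffix follows the counter scheme described in the comment above:

  KernelScope kernel_scope;
  UniqueNameManager names;
  Var* v0 = new Var("t", kInt);
  Var* v1 = new Var("t", kInt); // same name hint, distinct var
  const std::string& n0 = names.get_unique_name(v0); // "t"
  const std::string& n1 = names.get_unique_name(v1); // expected "t_1"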

View File

@ -13,15 +13,15 @@ namespace torch {
namespace jit {
namespace tensorexpr {
using VarMapping = std::vector<std::pair<const Var*, const Expr*>>;
using VarMapping = std::vector<std::pair<Var*, Expr*>>;
class VarSubMutator : public IRMutator {
public:
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
VarSubMutator(const VarMapping& var_mapping) {
for (const auto& entry : var_mapping) {
const Var* key_var = entry.first;
const Expr* value = entry.second;
for (auto& entry : var_mapping) {
Var* key_var = entry.first;
Expr* value = entry.second;
if (!key_var) {
throw malformed_input("missing key in VarSubMutator");
}
@ -29,7 +29,7 @@ class VarSubMutator : public IRMutator {
}
}
const Expr* mutate(const Var* var) override {
Expr* mutate(Var* var) override {
auto iter = var_mapping_.find(var);
if (iter == var_mapping_.end()) {
return var;
@ -37,14 +37,14 @@ class VarSubMutator : public IRMutator {
return iter->second;
}
const Expr* mutate(const ReduceOp* var) override {
Expr* mutate(ReduceOp* var) override {
auto body = var->body()->accept_mutator(this);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::vector<const Var*> new_inner;
std::vector<Var*> new_inner;
for (auto* v : var->reduce_args()) {
const Expr* e = v->accept_mutator(this);
if (const Var* new_var = dynamic_cast<const Var*>(e)) {
Expr* e = v->accept_mutator(this);
if (Var* new_var = dynamic_cast<Var*>(e)) {
new_inner.push_back(new_var);
} else {
VarFinder varFinder;
@ -58,7 +58,7 @@ class VarSubMutator : public IRMutator {
}
private:
std::unordered_map<const Var*, const Expr*> var_mapping_;
std::unordered_map<Var*, Expr*> var_mapping_;
};
} // namespace tensorexpr
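
Finally, a sketch (not part of the diff) of driving VarSubMutator through Substitute with the now-mutable VarMapping pairs; variable names are illustrative:

  KernelScope kernel_scope;
  Var* x = new Var("x", kInt);
  Var* y = new Var("y", kInt);
  Expr* e = new Add(x, new IntImm(1));
  // VarMapping is std::vector<std::pair<Var*, Expr*>> after this change.
  Expr* e2 = Substitute(e, {{x, y}}); // substitutes x -> y, yielding y + 1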