Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[nnc] Removed const from all fields in IR. (#62336)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/62336

This PR was generated by removing `const` from all types of nodes in the NNC IR and fixing the compilation errors that resulted from that change. This is the first step in making all NNC mutations in-place.

Test Plan: Imported from OSS

Reviewed By: iramazanli

Differential Revision: D30049829

Pulled By: navahgar

fbshipit-source-id: ed14e2d2ca0559ffc0b92ac371f405579c85dd63
Committed by: Facebook GitHub Bot
Parent commit: 474d7ec43b
This commit: 59dd12042e
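The change applied throughout the diff below is mechanical: `const T*` becomes `T*` on IR node pointers, visitor signatures, and container element types. As a rough illustration of why this matters — using simplified stand-in types, not the actual NNC class definitions from `torch/csrc/jit/tensorexpr` — a minimal sketch:

```cpp
#include <vector>

// Simplified stand-ins for NNC IR nodes; the real classes have many
// more members and live behind KernelScope-managed allocation.
struct Expr {
  virtual ~Expr() = default;
};

struct IntImm : Expr {
  explicit IntImm(int v) : value(v) {}
  int value;
};

// Before this change, visitors took `const IntImm*`, so a pass could only
// inspect a node and had to rebuild the tree to change anything.
// With const removed, a pass can mutate the node in place.
struct IRVisitor {
  virtual void visit(IntImm* v) {
    v->value *= 2;  // in-place mutation, impossible through a const pointer
  }
};

int main() {
  // was: std::vector<const Expr*> dims = ...
  std::vector<Expr*> dims = {new IntImm(2), new IntImm(3)};
  IRVisitor vis;
  vis.visit(static_cast<IntImm*>(dims[0]));
  for (Expr* e : dims) {
    delete e;  // raw owning pointers, matching the style of the diff
  }
  return 0;
}
```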
@@ -14,8 +14,8 @@ using namespace torch::jit::tensorexpr;

 TEST(CppPrinter, AllocateOnStackThenFree) {
   KernelScope kernel_scope;
-  std::vector<const Expr*> dims = {new IntImm(2), new IntImm(3)};
-  const Buf* buf = new Buf("x", dims, kInt);
+  std::vector<Expr*> dims = {new IntImm(2), new IntImm(3)};
+  Buf* buf = new Buf("x", dims, kInt);
   Allocate* alloc = new Allocate(buf);
   Free* free = new Free(buf);
   Block* block = Block::make({alloc, free});
@@ -33,9 +33,8 @@ TEST(CppPrinter, AllocateOnStackThenFree) {

 TEST(CppPrinter, AllocateOnHeapThenFree) {
   KernelScope kernel_scope;
-  std::vector<const Expr*> dims = {
-      new IntImm(20), new IntImm(50), new IntImm(3)};
-  const Buf* buf = new Buf("y", dims, kLong);
+  std::vector<Expr*> dims = {new IntImm(20), new IntImm(50), new IntImm(3)};
+  Buf* buf = new Buf("y", dims, kLong);
   Allocate* alloc = new Allocate(buf);
   Free* free = new Free(buf);
   Block* block = Block::make({alloc, free});
@@ -705,7 +705,7 @@ TEST(Cuda, SharedMemReduce_1_CUDA) {
   VarHandle n("n", kInt);

   std::vector<Stmt*> block;
-  std::vector<const Expr*> dims;
+  std::vector<Expr*> dims;
   dims.push_back(ExprHandle(N).node());
   BufHandle c{new Buf("c", dims, kFloat)};
   {
@@ -313,13 +313,13 @@ TEST(Expr, IntrinsicsDtypes) {

 TEST(Expr, Substitute01) {
   KernelScope kernel_scope;
-  const Var* x = new Var("x", kFloat);
-  const Var* y = new Var("y", kFloat);
-  const Expr* e = new Mul(new Sub(x, new FloatImm(1.0f)), new Add(x, y));
+  Var* x = new Var("x", kFloat);
+  Var* y = new Var("y", kFloat);
+  Expr* e = new Mul(new Sub(x, new FloatImm(1.0f)), new Add(x, y));

-  const Var* z = new Var("z", kFloat);
-  const Expr* e2 = Substitute(e, {{x, new Add(z, new FloatImm(5.0f))}});
-  const Expr* e2_ref = new Mul(
+  Var* z = new Var("z", kFloat);
+  Expr* e2 = Substitute(e, {{x, new Add(z, new FloatImm(5.0f))}});
+  Expr* e2_ref = new Mul(
       new Sub(new Add(z, new FloatImm(5.0f)), new FloatImm(1.0f)),
       new Add(new Add(z, new FloatImm(5.0f)), y));
   std::ostringstream oss;
@@ -663,7 +663,7 @@ void testStmtClone() {
   // original statement hasn't changed while the cloned one has.
   Stmt* body_addition = a_buf.store({index}, 33);
   Block* cloned_body =
-      static_cast<Block*>(static_cast<const For*>(cloned_loop)->body());
+      static_cast<Block*>(static_cast<For*>(cloned_loop)->body());
   cloned_body->append_stmt(body_addition);

   std::vector<int> orig_loop_results_after_mutation(N);
@@ -18,8 +18,8 @@ using namespace torch::jit::tensorexpr;

 TEST(IRVerifier, BitwiseOps) {
   KernelScope kernel_scope;
-  const Var* X = new Var("x", kInt);
-  const Var* Y = new Var("y", kFloat);
+  Var* X = new Var("x", kInt);
+  Var* Y = new Var("y", kFloat);
   {
     auto a = new And(X, Y);
     // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
@@ -49,8 +49,8 @@ TEST(IRVerifier, BitwiseOps) {

 TEST(IRVerifier, CompareSelect) {
   KernelScope kernel_scope;
-  const Expr* X = new IntImm(1);
-  const Expr* Y = new FloatImm(3.14f);
+  Expr* X = new IntImm(1);
+  Expr* Y = new FloatImm(3.14f);
   {
     auto a = new CompareSelect(X, X, X, Y, kEQ);
     // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
@@ -65,8 +65,8 @@ TEST(IRVerifier, CompareSelect) {

 TEST(IRVerifier, Ramp) {
   KernelScope kernel_scope;
-  const Var* I = new Var("i", kInt);
-  const Var* J = new Var("j", kFloat);
+  Var* I = new Var("i", kInt);
+  Var* J = new Var("j", kFloat);
   {
     auto a = new Ramp(I, J, 4);
     // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto,clang-analyzer-cplusplus.NewDeleteLeaks)
@@ -76,10 +76,10 @@ TEST(IRVerifier, Ramp) {

 TEST(IRVerifier, Load) {
   KernelScope kernel_scope;
-  const Var* I = new Var("i", kInt);
-  const Var* J = new Var("j", kLong);
-  const Var* K = new Var("k", kFloat);
-  const Buf* B = new Buf("b", {new IntImm(10), new IntImm(20)}, kFloat);
+  Var* I = new Var("i", kInt);
+  Var* J = new Var("j", kLong);
+  Var* K = new Var("k", kFloat);
+  Buf* B = new Buf("b", {new IntImm(10), new IntImm(20)}, kFloat);
   {
     // Indices with different int dtypes (kInt, kLong) are ok
     auto a = new Load(B, {I, J});
@@ -103,9 +103,9 @@ TEST(IRVerifier, Load) {

 TEST(IRVerifier, IfThenElse) {
   KernelScope kernel_scope;
-  const Var* I = new Var("i", kInt);
-  const Var* J = new Var("j", kLong);
-  const Var* K = new Var("k", kFloat);
+  Var* I = new Var("i", kInt);
+  Var* J = new Var("j", kLong);
+  Var* K = new Var("k", kFloat);
   {
     // Condition must be integral
     auto a = new IfThenElse(K, I, I);
@@ -128,8 +128,8 @@ TEST(IRVerifier, IfThenElse) {

 TEST(IRVerifier, For) {
   KernelScope kernel_scope;
-  const Var* I = new Var("i", kInt);
-  const Var* J = new Var("j", kInt);
+  Var* I = new Var("i", kInt);
+  Var* J = new Var("j", kInt);
   Stmt* body = new Block({});
   {
     // Can't have nullptr as a Var
@@ -141,8 +141,8 @@ TEST(IRVerifier, For) {

 TEST(IRVerifier, Block) {
   KernelScope kernel_scope;
-  const Var* I = new Var("i", kInt);
-  const Buf* B = new Buf("B", {new IntImm(10)}, kInt);
+  Var* I = new Var("i", kInt);
+  Buf* B = new Buf("B", {new IntImm(10)}, kInt);
   {
     Stmt* store = new Store(B, {I}, I);
     // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
@@ -158,10 +158,10 @@ TEST(IRVerifier, Block) {

 TEST(IRVerifier, Store) {
   KernelScope kernel_scope;
-  const Var* I = new Var("i", kInt);
-  const Var* J = new Var("j", kLong);
-  const Var* K = new Var("k", kFloat);
-  const Buf* B = new Buf("b", {new IntImm(10), new IntImm(20)}, kFloat);
+  Var* I = new Var("i", kInt);
+  Var* J = new Var("j", kLong);
+  Var* K = new Var("k", kFloat);
+  Buf* B = new Buf("b", {new IntImm(10), new IntImm(20)}, kFloat);
   {
     // Indices with different int dtypes (kInt, kLong) are ok
     auto a = new Store(B, {I, J}, K);
@@ -2246,7 +2246,7 @@ class LoopOrderHelper : public IRVisitor {
   }

   // NOLINTNEXTLINE(cppcoreguidelines-explicit-virtual-functions,modernize-use-override)
-  void visit(const For* v) {
+  void visit(For* v) {
     ordering << v->var()->name_hint() << ",";
     IRVisitor::visit(v);
   }
@@ -815,9 +815,9 @@ TEST(MemDependency, MemDependencyCheckerLoopBounds) {
   // much.
   auto history = analyzer.getHistory();
   ASSERT_EQ(history.size(), 10);
-  const Var* aVar = a.node()->base_handle();
-  const Var* bVar = b.node()->base_handle();
-  const Var* cVar = c.node()->base_handle();
+  Var* aVar = a.node()->base_handle();
+  Var* bVar = b.node()->base_handle();
+  Var* cVar = c.node()->base_handle();

   // The first access is the input A.
   ASSERT_EQ(history[0]->type(), AccessType::Input);
@@ -989,8 +989,8 @@ TEST(MemDependency, MemDependencyCheckerLoopBoundsIndexShift) {
   // Now let's look at the bounds of each access.
   auto history = analyzer.getHistory();
   ASSERT_EQ(history.size(), 12);
-  const Var* aVar = a.node()->base_handle();
-  const Var* bVar = b.node()->base_handle();
+  Var* aVar = a.node()->base_handle();
+  Var* bVar = b.node()->base_handle();

   // The first access is the input A.
   ASSERT_EQ(history[0]->type(), AccessType::Input);
@@ -3119,7 +3119,7 @@ TEST(MemDependency, MemDependencyCheckerComputeGEMM) {
           history_before[i]->bounds(), history_after[i]->bounds()));
     } else {
       ASSERT_EQ(history_after[i]->bounds().size(), 1);
-      const Expr* flat_bounds = new IntImm(1);
+      Expr* flat_bounds = new IntImm(1);

       for (auto& b : history_before[i]->bounds()) {
         flat_bounds = new Mul(flat_bounds, new Add(b.end, new IntImm(1)));
@@ -3129,7 +3129,7 @@ TEST(MemDependency, MemDependencyCheckerComputeGEMM) {
       }

       flat_bounds = IRSimplifier::simplify(flat_bounds);
-      const Expr* after_bounds = IRSimplifier::simplify(
+      Expr* after_bounds = IRSimplifier::simplify(
           new Add(history_after[i]->bounds()[0].end, new IntImm(1)));
       ASSERT_TRUE(exprEquals(flat_bounds, after_bounds));
     }
@@ -149,7 +149,7 @@ TEST(Simplify, ConstantFoldWithVar) {
     ExprHandle body = x * (ExprHandle(2) + ExprHandle(4));

     ExprHandle newF = IRSimplifier::simplify(body);
-    const Mul* root = newF.AsNode<Mul>();
+    Mul* root = newF.AsNode<Mul>();
     ASSERT_NE(root, nullptr);
     ASSERT_NE(dynamic_cast<const IntImm*>(root->lhs()), nullptr);

@@ -163,7 +163,7 @@ TEST(Simplify, ConstantFoldWithVar) {
     ExprHandle body = x * (ExprHandle(2.f) + ExprHandle(4.f));

     ExprHandle newF = IRSimplifier::simplify(body);
-    const Mul* root = newF.AsNode<Mul>();
+    Mul* root = newF.AsNode<Mul>();
     ASSERT_NE(root, nullptr);
     ASSERT_NE(dynamic_cast<const FloatImm*>(root->rhs()), nullptr);

@@ -296,7 +296,7 @@ TEST(Simplify, UnFoldableExpr) {
   ExprHandle body = (ExprHandle(3) * x) + (ExprHandle(5) * y);

   ExprHandle newF = IRSimplifier::simplify(body);
-  const Add* root = newF.AsNode<Add>();
+  Add* root = newF.AsNode<Add>();
   ASSERT_NE(root, nullptr);
   ASSERT_EQ(dynamic_cast<const FloatImm*>(root->lhs()), nullptr);
   ASSERT_EQ(dynamic_cast<const FloatImm*>(root->rhs()), nullptr);
@@ -334,7 +334,7 @@ TEST(Simplify, HashEquivalence) {
   VarHandle y("y", kFloat);
   ExprHandle f = (x * y) + (x * y);

-  const Add* root = f.AsNode<Add>();
+  Add* root = f.AsNode<Add>();
   ASSERT_NE(root, nullptr);

   HashProvider hasher;
@@ -370,7 +370,7 @@ TEST(Simplify, HashEquivalenceRand) {
   ExprHandle f =
       Intrinsics::make(kRand, kFloat) + Intrinsics::make(kRand, kInt);

-  const Add* root = f.AsNode<Add>();
+  Add* root = f.AsNode<Add>();
   ASSERT_NE(root, nullptr);

   HashProvider hasher;
@@ -415,7 +415,7 @@ TEST(Simplify, HashDifferenceTypes) {
   KernelScope kernel_scope;

   HashProvider hasher;
-  std::vector<const Expr*> immediates;
+  std::vector<Expr*> immediates;

   immediates.push_back(new DoubleImm(1));
   immediates.push_back(new FloatImm(1));
@@ -546,9 +546,9 @@ TEST(Simplify, SimplifyAdd) {
   ExprHandle body = (ExprHandle(2) + x) + ExprHandle(4);

   ExprHandle simplified = IRSimplifier::simplify(body);
-  const Add* root = simplified.AsNode<Add>();
+  Add* root = simplified.AsNode<Add>();
   ASSERT_NE(root, nullptr);
-  const Var* lhs = dynamic_cast<const Var*>(root->lhs());
+  Var* lhs = dynamic_cast<Var*>(root->lhs());
   ASSERT_NE(lhs, nullptr);
   ASSERT_EQ(lhs->name_hint(), "x");
   const IntImm* rhs = dynamic_cast<const IntImm*>(root->rhs());
@@ -563,12 +563,12 @@ TEST(Simplify, SimplifySub) {
   ExprHandle body = (ExprHandle(2) - x) - ExprHandle(4);

   ExprHandle simplified = IRSimplifier::simplify(body);
-  const Sub* root = simplified.AsNode<Sub>();
+  Sub* root = simplified.AsNode<Sub>();
   ASSERT_NE(root, nullptr);
   const IntImm* lhs = dynamic_cast<const IntImm*>(root->lhs());
   ASSERT_NE(lhs, nullptr);
   ASSERT_EQ(lhs->value(), -2.f);
-  const Var* rhs = dynamic_cast<const Var*>(root->rhs());
+  Var* rhs = dynamic_cast<Var*>(root->rhs());
   ASSERT_NE(rhs, nullptr);
   ASSERT_EQ(rhs->name_hint(), "x");
 }
@@ -594,12 +594,12 @@ TEST(Simplify, SimplifyMultiTerm) {
       (ExprHandle(2) * ((ExprHandle(3) * x)) - (x * ExprHandle(4)));

   ExprHandle simplified = IRSimplifier::simplify(body);
-  const Mul* root = simplified.AsNode<Mul>();
+  Mul* root = simplified.AsNode<Mul>();
   ASSERT_NE(root, nullptr);
   const IntImm* lhs = dynamic_cast<const IntImm*>(root->lhs());
   ASSERT_NE(lhs, nullptr);
   ASSERT_EQ(lhs->value(), 2);
-  const Var* rhs = dynamic_cast<const Var*>(root->rhs());
+  Var* rhs = dynamic_cast<Var*>(root->rhs());
   ASSERT_NE(rhs, nullptr);
   ASSERT_EQ(rhs->name_hint(), "x");
 }
@@ -612,12 +612,12 @@ TEST(Simplify, SimplifyCasts) {
       (ExprHandle(2) * ((ExprHandle(3) * x)) - (x * ExprHandle(4)));

   ExprHandle simplified = IRSimplifier::simplify(body);
-  const Mul* root = simplified.AsNode<Mul>();
+  Mul* root = simplified.AsNode<Mul>();
   ASSERT_NE(root, nullptr);
   const LongImm* lhs = dynamic_cast<const LongImm*>(root->lhs());
   ASSERT_NE(lhs, nullptr);
   ASSERT_EQ(lhs->value(), 2);
-  const Var* rhs = dynamic_cast<const Var*>(root->rhs());
+  Var* rhs = dynamic_cast<Var*>(root->rhs());
   ASSERT_NE(rhs, nullptr);
   ASSERT_EQ(rhs->name_hint(), "x");
 }
@@ -629,7 +629,7 @@ TEST(Simplify, SimplifyEliminatesNoOps) {
   ExprHandle body = (x + ExprHandle(0)) * 1;

   ExprHandle simplified = IRSimplifier::simplify(body);
-  const Var* root = simplified.AsNode<Var>();
+  Var* root = simplified.AsNode<Var>();
   ASSERT_NE(root, nullptr);
   ASSERT_EQ(root->name_hint(), "x");
 }
@@ -643,16 +643,16 @@ TEST(Simplify, SimplifyMultiVar) {

   ExprHandle simplified = IRSimplifier::simplify(body);

-  const Add* root = simplified.AsNode<Add>();
+  Add* root = simplified.AsNode<Add>();
   ASSERT_NE(root, nullptr);
-  const Mul* lhs = dynamic_cast<const Mul*>(root->lhs());
+  Mul* lhs = dynamic_cast<Mul*>(root->lhs());
   ASSERT_NE(lhs, nullptr);
-  const Var* varX = dynamic_cast<const Var*>(lhs->rhs());
+  Var* varX = dynamic_cast<Var*>(lhs->rhs());
   ASSERT_NE(varX, nullptr);
   ASSERT_EQ(varX->name_hint(), "y");
-  const Mul* rhs = dynamic_cast<const Mul*>(root->rhs());
+  Mul* rhs = dynamic_cast<Mul*>(root->rhs());
   ASSERT_NE(rhs, nullptr);
-  const Var* varY = dynamic_cast<const Var*>(rhs->rhs());
+  Var* varY = dynamic_cast<Var*>(rhs->rhs());
   ASSERT_NE(varY, nullptr);
   ASSERT_EQ(varY->name_hint(), "x");
 }
@@ -665,7 +665,7 @@ TEST(Simplify, DISABLED_SimplifyReorderings) {
   ExprHandle body = x + 2 + y;
   ExprHandle simplified = IRSimplifier::simplify(body);

-  const Add* root = simplified.AsNode<Add>();
+  Add* root = simplified.AsNode<Add>();
   ASSERT_NE(root, nullptr);

   IS_NODE_WITH_NAME(Add, root->lhs(), rhs);
@@ -127,13 +127,13 @@ int main(int argc, char* argv[]) {
   // Let's start with defining a domain. We do this by creating a Buf object.

   // First, let's specify the sizes:
-  std::vector<const Expr*> dims = {
+  std::vector<Expr*> dims = {
       new IntImm(64), new IntImm(32)}; // IntImm stands for Integer Immediate
                                        // and represents an integer constant

   // Now we can create a Buf object by providing a name, dimensions, and a
   // data type of the elements:
-  const Buf* buf = new Buf("X", dims, kInt);
+  Buf* buf = new Buf("X", dims, kInt);

   // Next we need to specify the computation. We can do that by either
   // constructing a complete tensor statement for it (statements are
@@ -144,9 +144,9 @@ int main(int argc, char* argv[]) {

   // Let's define two variables, i and j - they will be axes in our
   // computation.
-  const Var* i = new Var("i", kInt);
-  const Var* j = new Var("j", kInt);
-  std::vector<const Var*> args = {i, j};
+  Var* i = new Var("i", kInt);
+  Var* j = new Var("j", kInt);
+  std::vector<Var*> args = {i, j};

   // Now we can define the body of the tensor computation using these
   // variables. What this means is that values in our tensor are:
@@ -19,7 +19,7 @@ class HasRand : public IRVisitor {
   }

  private:
-  void visit(const Intrinsics* v) override {
+  void visit(Intrinsics* v) override {
     if (v->op_type() == IntrinsicsOp::kRand) {
       has_rand_ = true;
     } else {
@@ -34,18 +34,18 @@ template <typename Node>
 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
 class NodeFinder : public IRVisitor {
  public:
-  void visit(const Node* v) override {
+  void visit(Node* v) override {
     nodes.push_back((Node*)v);
     IRVisitor::visit(v);
   }

-  static std::vector<Node*> find(const Stmt* s) {
+  static std::vector<Node*> find(Stmt* s) {
     NodeFinder<Node> nf;
     s->accept(&nf);
     return nf.nodes;
   }

-  static std::vector<Node*> find(const Expr* e) {
+  static std::vector<Node*> find(Expr* e) {
     NodeFinder<Node> nf;
     e->accept(&nf);
     return nf.nodes;
@@ -57,108 +57,108 @@ class NodeFinder : public IRVisitor {
 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
 class VarFinder : public IRVisitor {
  public:
-  void visit(const Var* v) override {
+  void visit(Var* v) override {
     vars_.insert(v);
     IRVisitor::visit(v);
   }

-  static std::unordered_set<const Var*> find(Stmt* s) {
+  static std::unordered_set<Var*> find(Stmt* s) {
     VarFinder nf;
     s->accept(&nf);
     return nf.vars();
   }

-  static std::unordered_set<const Var*> find(const Expr* e) {
+  static std::unordered_set<Var*> find(Expr* e) {
     VarFinder nf;
     e->accept(&nf);
     return nf.vars();
   }

-  const std::unordered_set<const Var*>& vars() {
+  const std::unordered_set<Var*>& vars() {
     return vars_;
   }

  private:
-  std::unordered_set<const Var*> vars_;
+  std::unordered_set<Var*> vars_;
 };

 class BufFinder : public IRVisitor {
  public:
-  void visit(const Buf* v) override {
+  void visit(Buf* v) override {
     bufs_.insert(v);
     IRVisitor::visit(v);
   }

-  static std::unordered_set<const Buf*> find(Stmt* s) {
+  static std::unordered_set<Buf*> find(Stmt* s) {
     BufFinder nf;
     s->accept(&nf);
     return nf.bufs();
   }

-  static std::unordered_set<const Buf*> find(const Expr* e) {
+  static std::unordered_set<Buf*> find(Expr* e) {
     BufFinder nf;
     e->accept(&nf);
     return nf.bufs();
   }

-  const std::unordered_set<const Buf*>& bufs() {
+  const std::unordered_set<Buf*>& bufs() {
     return bufs_;
   }

  private:
-  std::unordered_set<const Buf*> bufs_;
+  std::unordered_set<Buf*> bufs_;
 };

 // Finds all kinds of write operations to the provided Buf.
 class WritesToBuf : public IRVisitor {
  public:
   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
-  WritesToBuf(const Buf* target) : target_(target) {}
+  WritesToBuf(Buf* target) : target_(target) {}

-  std::vector<const Stmt*> writes() {
+  std::vector<Stmt*> writes() {
     return writes_;
   }

-  static std::vector<const Stmt*> find(Stmt* s, const Buf* b) {
+  static std::vector<Stmt*> find(Stmt* s, Buf* b) {
     WritesToBuf finder(b);
     s->accept(&finder);
     return finder.writes();
   }

  private:
-  void visit(const Store* v) override {
+  void visit(Store* v) override {
     if (v->buf() == target_) {
       writes_.push_back(v);
     }
   }

-  void visit(const AtomicAdd* v) override {
+  void visit(AtomicAdd* v) override {
     if (v->buf() == target_) {
       writes_.push_back(v);
     }
   }

-  const Buf* target_;
-  std::vector<const Stmt*> writes_;
+  Buf* target_;
+  std::vector<Stmt*> writes_;
 };

 class StmtsReadingBuf : public IRVisitor {
  public:
   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
-  StmtsReadingBuf(const Buf* target) : target_(target) {}
+  StmtsReadingBuf(Buf* target) : target_(target) {}

-  std::vector<const Stmt*> reads() {
+  std::vector<Stmt*> reads() {
     return reads_;
   }

-  static std::vector<const Stmt*> find(Stmt* s, const Buf* b) {
+  static std::vector<Stmt*> find(Stmt* s, Buf* b) {
     StmtsReadingBuf finder(b);
     s->accept(&finder);
     return finder.reads();
   }

  private:
-  bool readsBuffer(const Stmt* s) {
+  bool readsBuffer(Stmt* s) {
     auto loads = NodeFinder<Load>::find(s);
     for (auto l : loads) {
       if (l->buf() == target_) {
@@ -168,40 +168,40 @@ class StmtsReadingBuf : public IRVisitor {
     return false;
   }

-  void visit(const Store* v) override {
+  void visit(Store* v) override {
     if (readsBuffer(v)) {
       reads_.push_back(v);
     }
   }

-  void visit(const Let* v) override {
+  void visit(Let* v) override {
     if (readsBuffer(v)) {
       reads_.push_back(v);
     }
   }

-  void visit(const Cond* v) override {
+  void visit(Cond* v) override {
     if (readsBuffer(v)) {
       reads_.push_back(v);
     }
   }

-  void visit(const AtomicAdd* v) override {
+  void visit(AtomicAdd* v) override {
     if (readsBuffer(v)) {
       reads_.push_back(v);
     }
   }

-  const Buf* target_;
-  std::vector<const Stmt*> reads_;
+  Buf* target_;
+  std::vector<Stmt*> reads_;
 };

 // Traverses the IR to determine if a particular Var is modified within it.
 class ModifiesVarChecker : public IRVisitor {
  public:
-  ModifiesVarChecker(const Var* v) : var_(v) {}
+  ModifiesVarChecker(Var* v) : var_(v) {}

-  static bool check(const Stmt* s, const Var* v) {
+  static bool check(Stmt* s, Var* v) {
     ModifiesVarChecker checker(v);
     s->accept(&checker);
     return checker.found();
@@ -212,7 +212,7 @@ class ModifiesVarChecker : public IRVisitor {
   }

  private:
-  void visit(const Store* v) override {
+  void visit(Store* v) override {
     if (v->buf()->base_handle() == var_) {
       found_ = true;
       return;
@@ -220,7 +220,7 @@ class ModifiesVarChecker : public IRVisitor {
     IRVisitor::visit(v);
   }

-  void visit(const AtomicAdd* v) override {
+  void visit(AtomicAdd* v) override {
     if (v->buf()->base_handle() == var_) {
       found_ = true;
       return;
@@ -228,7 +228,7 @@ class ModifiesVarChecker : public IRVisitor {
     IRVisitor::visit(v);
   }

-  void visit(const Let* v) override {
+  void visit(Let* v) override {
     if (v->var() == var_) {
       found_ = true;
       return;
@@ -236,7 +236,7 @@ class ModifiesVarChecker : public IRVisitor {
     IRVisitor::visit(v);
   }

-  void visit(const For* v) override {
+  void visit(For* v) override {
     if (v->var() == var_) {
       found_ = true;
       return;
@@ -244,7 +244,7 @@ class ModifiesVarChecker : public IRVisitor {
     IRVisitor::visit(v);
   }

-  const Var* var_;
+  Var* var_;
   bool found_{false};
 };

@@ -252,26 +252,26 @@ class ModifiesVarChecker : public IRVisitor {
 // It creates a map of multi dim buffers and their flat versions
 class CreateBufferMap : public IRVisitor {
  public:
-  const std::unordered_map<std::string, const Buf*>& getBufferMap() const {
+  const std::unordered_map<std::string, Buf*>& getBufferMap() const {
     return map_input_to_tensor_bufs_;
   }

  private:
-  void visit(const Store* v) override {
-    auto load_node = dynamic_cast<const Load*>(v->value());
+  void visit(Store* v) override {
+    auto load_node = dynamic_cast<Load*>(v->value());
     if (load_node) {
       auto t_buf = load_node->buf();
       map_input_to_tensor_bufs_.emplace(t_buf->name_hint(), v->buf());
     } else {
-      auto add_node = dynamic_cast<const Add*>(v->value());
-      auto mul_node = dynamic_cast<const Mul*>(v->value());
+      auto add_node = dynamic_cast<Add*>(v->value());
+      auto mul_node = dynamic_cast<Mul*>(v->value());
       // This means for now, v->value() can be Add or Mul
       TORCH_INTERNAL_ASSERT((add_node || mul_node));
       map_input_to_tensor_bufs_.emplace(v->buf()->name_hint(), v->buf());
     }
     v->value()->accept(this);
   }
-  std::unordered_map<std::string, const Buf*> map_input_to_tensor_bufs_;
+  std::unordered_map<std::string, Buf*> map_input_to_tensor_bufs_;
 };

 } // namespace tensorexpr
@@ -32,8 +32,7 @@ std::string blockDtypeCppString(const Dtype& dtype) {
   }
 }

-bool BlockAnalysis::areBufsInMap(
-    const std::unordered_set<const Buf*>& bufs) const {
+bool BlockAnalysis::areBufsInMap(const std::unordered_set<Buf*>& bufs) const {
   for (auto const& arg : bufs) {
     auto got = map_input_to_tensor_bufs_.find(arg->name_hint());
     if (got == map_input_to_tensor_bufs_.end()) {
@@ -43,7 +42,7 @@ bool BlockAnalysis::areBufsInMap(const std::unordered_set<Buf*>& bufs) const {
   return true;
 }

-const Buf* BlockAnalysis::getMultiDimBuf(const Buf* buf) const {
+Buf* BlockAnalysis::getMultiDimBuf(Buf* buf) const {
   auto input_ = map_input_to_tensor_bufs_.find(buf->name_hint());
   if (input_ != map_input_to_tensor_bufs_.end()) {
     return input_->second;
@@ -52,7 +51,7 @@ const Buf* BlockAnalysis::getMultiDimBuf(const Buf* buf) const {
   }
 }

-std::string BlockAnalysis::getInputName(const Buf* buf) const {
+std::string BlockAnalysis::getInputName(Buf* buf) const {
   auto input_ = map_input_to_tensor_bufs_.find(buf->name_hint());
   if (input_ != map_input_to_tensor_bufs_.end()) {
     return input_->second->name_hint();
@@ -61,23 +60,23 @@ std::string BlockAnalysis::getInputName(const Buf* buf) const {
   }
 }

-void BlockAnalysis::visit(const Store* v) {
+void BlockAnalysis::visit(Store* v) {
   store_targets_.insert(v->buf());
   v->value()->accept(this);
 }

-void BlockAnalysis::visit(const Load* v) {
+void BlockAnalysis::visit(Load* v) {
   loads_.insert(v->buf());
 }

-void BlockAnalysis::visit(const For* v) {
+void BlockAnalysis::visit(For* v) {
   const LoopOptions& loop_options = v->loop_options();
   if (loop_options.is_gpu_block_index()) {
     map_input_to_tensor_bufs_ = loop_options.get_buffer_mapping();
     v->body()->accept(this);
   } else if (loop_options.is_gpu_thread_index()) {
     auto block_size = v->stop();
-    block_size_ = dynamic_cast<const IntImm*>(block_size)->value();
+    block_size_ = dynamic_cast<IntImm*>(block_size)->value();
     v->body()->accept(this);
   } else {
     IRVisitor::visit(v);
@@ -91,26 +90,26 @@ void BlockAnalysis::visit(const For* v) {
 // TODO: When handling fused ops d = a + b + c, the correct
 // way would be to mutate the expression to Block version and print.

-void BlockPrinter::visit(const Add* v) {
+void BlockPrinter::visit(Add* v) {
   emitIndent();
   os() << "add(";
   v->lhs()->accept(this);
   v->rhs()->accept(this);
 }

-void BlockPrinter::visit(const Mul* v) {
+void BlockPrinter::visit(Mul* v) {
   emitIndent();
   os() << "mul(";
   v->lhs()->accept(this);
   v->rhs()->accept(this);
 }

-void BlockPrinter::visit(const For* v) {
+void BlockPrinter::visit(For* v) {
   const LoopOptions& loop_options = v->loop_options();

   auto buf_reads = block_analysis_->loads();
   auto buf_writes = block_analysis_->stores();
-  std::unordered_set<const Buf*> bufs(buf_reads.begin(), buf_reads.end());
+  std::unordered_set<Buf*> bufs(buf_reads.begin(), buf_reads.end());
   bufs.insert(buf_writes.begin(), buf_writes.end());

   if (loop_options.is_gpu_block_index()) {
@@ -146,9 +145,9 @@ void BlockPrinter::visit(const For* v) {
   }
 }

-void BlockPrinter::PrintTensorInfo(const std::unordered_set<const Buf*>& bufs) {
+void BlockPrinter::PrintTensorInfo(const std::unordered_set<Buf*>& bufs) {
   os() << "tensors {";
-  for (const auto& buf : bufs) {
+  for (auto& buf : bufs) {
     os() << std::endl;
     emitIndent();
     emitIndent();
@@ -162,7 +161,7 @@ void BlockPrinter::PrintTensorInfo(const std::unordered_set<const Buf*>& bufs) {
     os() << "}";
   }

-  for (const auto& buf : bufs) {
+  for (auto& buf : bufs) {
     os() << std::endl;
     emitIndent();
     emitIndent();
@@ -179,19 +178,19 @@ void BlockPrinter::PrintTensorInfo(const std::unordered_set<const Buf*>& bufs) {
   os() << "}" << std::endl << std::endl;
 }

-void BlockPrinter::PrintArguments(const std::unordered_set<const Buf*>& bufs) {
-  for (const auto& buf : bufs) {
+void BlockPrinter::PrintArguments(const std::unordered_set<Buf*>& bufs) {
+  for (auto& buf : bufs) {
     auto multidimbuf = block_analysis_->getMultiDimBuf(buf);
     auto num_dims = multidimbuf->dims().size();

     // The dims for the multi-dim tensors
     for (unsigned long d = 0; d < num_dims; d++) {
-      auto dim_val = dynamic_cast<const IntImm*>(multidimbuf->dim(d));
+      auto dim_val = dynamic_cast<IntImm*>(multidimbuf->dim(d));
       this->dim_values_map.emplace(this->dim_names[d], dim_val->value());
     }

     // The dimensions for the flattened tensors
-    auto val = dynamic_cast<const IntImm*>(buf->dim(0));
+    auto val = dynamic_cast<IntImm*>(buf->dim(0));
     if (block_analysis_->is_buf_store_target(buf)) {
       this->dim_values_map.emplace(
           this->flat_dim_names[num_dims - 1], val->value());
@@ -217,10 +216,10 @@ void BlockPrinter::PrintArguments(const std::unordered_set<const Buf*>& bufs) {
   os() << "}" << std::endl << std::endl;
 }

-void BlockPrinter::PrintBufferInfo(const std::unordered_set<const Buf*>& bufs) {
+void BlockPrinter::PrintBufferInfo(const std::unordered_set<Buf*>& bufs) {
   emitIndent();
   os() << "buffers {";
-  for (const auto& read : bufs) {
+  for (auto& read : bufs) {
     os() << std::endl;
     emitIndent();
     emitIndent();
@@ -234,11 +233,10 @@ void BlockPrinter::PrintBufferInfo(const std::unordered_set<const Buf*>& bufs) {
   os() << "}" << std::endl << std::endl;
 }

-void BlockPrinter::PrintDistribution(
-    const std::unordered_set<const Buf*>& bufs) {
+void BlockPrinter::PrintDistribution(const std::unordered_set<Buf*>& bufs) {
   emitIndent();
   os() << "distribution {" << std::endl;
-  for (const auto& buf : bufs) {
+  for (auto& buf : bufs) {
     emitIndent();
     emitIndent();
     auto buf_name = buf->name_hint();
@@ -249,12 +247,12 @@ void BlockPrinter::PrintDistribution(
 }

 void BlockPrinter::PrintLoop(
-    const std::unordered_set<const Buf*>& bufs,
+    const std::unordered_set<Buf*>& bufs,
     bool block_idx) {
   emitIndent();
   os() << "loop (";
   auto trip = 0;
-  for (const auto& buf : bufs) {
+  for (auto& buf : bufs) {
     if (trip > 0) {
       os() << ",";
     }
@@ -267,9 +265,9 @@ void BlockPrinter::PrintLoop(
 }

 void BlockPrinter::PrintReshapeInfo(
-    const std::unordered_set<const Buf*>& bufs,
+    const std::unordered_set<Buf*>& bufs,
     bool reverse) {
-  for (const auto& buf : bufs) {
+  for (auto& buf : bufs) {
     emitIndent();
     os() << "reshape("
          << (reverse ? block_analysis_->getFlatInputName(buf)
@@ -281,17 +279,16 @@ void BlockPrinter::PrintReshapeInfo(
   }
 }

-void BlockPrinter::PrintDMAs(const std::unordered_set<const Buf*>& bufs) {
-  for (const auto& read : bufs) {
+void BlockPrinter::PrintDMAs(const std::unordered_set<Buf*>& bufs) {
+  for (auto& read : bufs) {
     emitIndent();
     os() << "dma_in(";
     os() << block_analysis_->getFlatInputName(read);
     os() << ")" << std::endl;
   }
 }
-void BlockPrinter::PrintAdjustBuffers(
-    const std::unordered_set<const Buf*>& bufs) {
-  for (const auto& read : bufs) {
+void BlockPrinter::PrintAdjustBuffers(const std::unordered_set<Buf*>& bufs) {
+  for (auto& read : bufs) {
     emitIndent();
     os() << "adjust_buffer(";
     os() << block_analysis_->getFlatInputName(read);
@@ -299,16 +296,16 @@ void BlockPrinter::PrintAdjustBuffers(
   }
 }

-void BlockPrinter::visit(const Load* v) {
+void BlockPrinter::visit(Load* v) {
   os() << block_analysis_->getFlatInputName(v->buf()) << ".buffer, ";
 }
-void BlockPrinter::visit(const Store* v) {
+void BlockPrinter::visit(Store* v) {
   emitIndent();
   os() << *v->value() << block_analysis_->getFlatInputName(v->buf())
        << ".tensor)" << std::endl;
 }

-void BlockPrinter::visit(const Block* v) {
+void BlockPrinter::visit(Block* v) {
   os() << "{" << std::endl;
   indent_++;
   for (Stmt* s : v->stmts()) {
@@ -338,7 +335,7 @@ void BlockCodeGen::Initialize() {
   auto buf_reads = block_analysis_->loads();
   auto buf_writes = block_analysis_->stores();
   // Ensure all Bufs in reads/writes are in the map
-  std::unordered_set<const Buf*> bufs(buf_reads.begin(), buf_reads.end());
+  std::unordered_set<Buf*> bufs(buf_reads.begin(), buf_reads.end());
   bufs.insert(buf_writes.begin(), buf_writes.end());
   if (!block_analysis_->areBufsInMap(bufs)) {
     throw std::runtime_error("BlockCodeGen: Entry not in input/Buffer map");
@@ -20,15 +20,15 @@ namespace tensorexpr {
 // A class that analyzes the given program relevant for Block backend.
 class BlockAnalysis : public IRVisitor {
  public:
-  bool is_buf_store_target(const Buf* buf) const {
+  bool is_buf_store_target(Buf* buf) const {
     return store_targets_.count(buf) > 0;
   }

-  const std::unordered_set<const Buf*>& loads() const {
+  const std::unordered_set<Buf*>& loads() const {
     return loads_;
   }

-  const std::unordered_set<const Buf*>& stores() const {
+  const std::unordered_set<Buf*>& stores() const {
     return store_targets_;
   }

@@ -36,64 +36,62 @@ class BlockAnalysis : public IRVisitor {
     return block_size_;
   }

-  bool areBufsInMap(const std::unordered_set<const Buf*>& bufs) const;
+  bool areBufsInMap(const std::unordered_set<Buf*>& bufs) const;

-  const Buf* getMultiDimBuf(const Buf* buf) const;
+  Buf* getMultiDimBuf(Buf* buf) const;

-  std::string getInputName(const Buf* buf) const;
+  std::string getInputName(Buf* buf) const;

-  std::string getFlatInputName(const Buf* buf) const {
+  std::string getFlatInputName(Buf* buf) const {
     return getInputName(buf) + "_flat";
   }

-  std::unordered_map<std::string, const Buf*> getBufferMap() const {
+  std::unordered_map<std::string, Buf*> getBufferMap() const {
     return map_input_to_tensor_bufs_;
   }

  private:
-  void visit(const Store* v) override;
-  void visit(const Load* v) override;
-  void visit(const For* v) override;
+  void visit(Store* v) override;
+  void visit(Load* v) override;
+  void visit(For* v) override;

-  std::unordered_map<std::string, const Buf*> map_input_to_tensor_bufs_;
-  std::unordered_set<const Buf*> store_targets_;
-  std::unordered_set<const Buf*> loads_;
+  std::unordered_map<std::string, Buf*> map_input_to_tensor_bufs_;
+  std::unordered_set<Buf*> store_targets_;
+  std::unordered_set<Buf*> loads_;
   int block_size_ = 32;
 };

 // A class that overrides the underlying IRPrinter to produce Block.
 class BlockPrinter : public IRPrinter {
  public:
-  BlockPrinter(std::ostream* os, const BlockAnalysis* block_analysis)
+  BlockPrinter(std::ostream* os, BlockAnalysis* block_analysis)
       : IRPrinter(*os), block_analysis_(block_analysis) {}

   using IRPrinter::name_manager;
   using IRPrinter::visit;

  private:
-  const BlockAnalysis* block_analysis_;
+  BlockAnalysis* block_analysis_;
   std::unordered_map<std::string, int> dim_values_map;
   std::vector<std::string> dim_names = {"N", "H", "W", "C"};
   std::vector<std::string> flat_dim_names = {"N", "NH", "NHW", "NHWC"};
-  void PrintTensorInfo(const std::unordered_set<const Buf*>& bufs);
-  void PrintArguments(const std::unordered_set<const Buf*>& bufs);
-  void PrintBufferInfo(const std::unordered_set<const Buf*>& bufs);
-  void PrintDistribution(const std::unordered_set<const Buf*>& bufs);
-  void PrintLoop(
-      const std::unordered_set<const Buf*>& bufs,
-      bool block_idx = true);
+  void PrintTensorInfo(const std::unordered_set<Buf*>& bufs);
+  void PrintArguments(const std::unordered_set<Buf*>& bufs);
+  void PrintBufferInfo(const std::unordered_set<Buf*>& bufs);
+  void PrintDistribution(const std::unordered_set<Buf*>& bufs);
+  void PrintLoop(const std::unordered_set<Buf*>& bufs, bool block_idx = true);
   void PrintReshapeInfo(
-      const std::unordered_set<const Buf*>& bufs,
+      const std::unordered_set<Buf*>& bufs,
       bool reverse = false);
-  void PrintDMAs(const std::unordered_set<const Buf*>& bufs);
-  void PrintAdjustBuffers(const std::unordered_set<const Buf*>& bufs);
+  void PrintDMAs(const std::unordered_set<Buf*>& bufs);
+  void PrintAdjustBuffers(const std::unordered_set<Buf*>& bufs);

-  void visit(const For* v) override;
-  void visit(const Load* v) override;
-  void visit(const Store* v) override;
-  void visit(const Block* v) override;
-  void visit(const Add* v) override;
-  void visit(const Mul* v) override;
+  void visit(For* v) override;
+  void visit(Load* v) override;
+  void visit(Store* v) override;
+  void visit(Block* v) override;
+  void visit(Add* v) override;
+  void visit(Mul* v) override;
 };

 class TORCH_API BlockCodeGen : public CodeGen {
@@ -19,7 +19,7 @@ using namespace analysis;
 template <typename Container>
 BoundsInfo mergeTensorAccesses(
     const Container& accesses,
-    const std::unordered_map<const Var*, const Buf*>& varToBuf,
+    const std::unordered_map<Var*, Buf*>& varToBuf,
     bool distinctAccessKinds) {
   BoundsInfo ret;
   for (auto& access : accesses) {
@@ -30,7 +30,7 @@ BoundsInfo mergeTensorAccesses(

     auto vtbIt = varToBuf.find(access->var());
     TORCH_INTERNAL_ASSERT(vtbIt != varToBuf.end());
-    const Buf* buf = vtbIt->second;
+    Buf* buf = vtbIt->second;
     std::vector<TensorAccessBoundsInfo>& infos = ret[buf];

     bool added = false;
@@ -70,20 +70,20 @@ BoundsInfo mergeTensorAccesses(
   return ret;
 }

-std::unordered_map<const Var*, const Buf*> getAllBufs(Stmt* s) {
-  std::unordered_map<const Var*, const Buf*> varToBuf;
+std::unordered_map<Var*, Buf*> getAllBufs(Stmt* s) {
+  std::unordered_map<Var*, Buf*> varToBuf;

-  auto bufs = NodeFinder<const Buf>::find(s);
+  auto bufs = NodeFinder<Buf>::find(s);
   for (auto* b : bufs) {
     varToBuf[b->base_handle()] = b;
   }
   return varToBuf;
 }

-std::unordered_map<const Var*, const Buf*> getAllBufs(Expr* e) {
-  std::unordered_map<const Var*, const Buf*> varToBuf;
+std::unordered_map<Var*, Buf*> getAllBufs(Expr* e) {
+  std::unordered_map<Var*, Buf*> varToBuf;

-  auto bufs = NodeFinder<const Buf>::find(e);
+  auto bufs = NodeFinder<Buf>::find(e);
   for (auto* b : bufs) {
     varToBuf[b->base_handle()] = b;
   }
@@ -121,7 +121,7 @@ void printBoundsInfo(const BoundsInfo& v) {
   for (auto& pair : v) {
     std::cerr << *pair.first << " in [";
     bool first = true;
-    for (const auto& b : pair.second) {
+    for (auto& b : pair.second) {
      if (!first) {
        std::cerr << ", ";
      }
@@ -130,7 +130,7 @@ void printBoundsInfo(const BoundsInfo& v) {
      if (b.start.empty()) {
        std::cerr << "0";
      }
-      for (const auto& s : b.start) {
+      for (auto& s : b.start) {
        if (i != 0) {
          std::cerr << ", ";
        }
@@ -142,7 +142,7 @@ void printBoundsInfo(const BoundsInfo& v) {
      if (b.stop.empty()) {
        std::cerr << "0";
      }
-      for (const auto& s : b.stop) {
+      for (auto& s : b.stop) {
        if (i != 0) {
          std::cerr << ", ";
        }
@@ -157,15 +157,15 @@ void printBoundsInfo(const BoundsInfo& v) {
   std::cerr << "}\n";
 }

-std::vector<const Expr*> getBoundExtents(
+std::vector<Expr*> getBoundExtents(
     const std::vector<TensorAccessBoundsInfo>& infos) {
-  std::vector<const Expr*> starts;
-  std::vector<const Expr*> stops;
+  std::vector<Expr*> starts;
+  std::vector<Expr*> stops;

   // Find the safe size of the temporary buffer by determining the outer
   // extents of a union of all bounds.
   for (const TensorAccessBoundsInfo& p : infos) {
-    for (const auto i : c10::irange(p.start.size())) {
+    for (auto i : c10::irange(p.start.size())) {
       if (starts.size() <= i) {
         starts.push_back(p.start[i]);
       } else {
@@ -181,9 +181,9 @@ std::vector<const Expr*> getBoundExtents(
     }
   }

-  std::vector<const Expr*> extents;
+  std::vector<Expr*> extents;
   for (size_t i = 0; i < starts.size(); ++i) {
-    const Expr* dim = IRSimplifier::simplify(
+    Expr* dim = IRSimplifier::simplify(
         new Add(new Sub(stops[i], starts[i]), new IntImm(1)));

     extents.push_back(dim);
@@ -210,7 +210,7 @@ BoundSet convertBounds(

 BoundSet convertBounds(
     BoundsInfo& bounds,
-    const Buf* buf,
+    Buf* buf,
     TensorAccessKind filter = kMutate) {
   auto it = bounds.find(buf);
   if (it == bounds.end()) {
@@ -231,7 +231,7 @@ HazardKind getPotentialHazards(
   BoundSet aReads;

   for (auto& pair : bBounds) {
-    const Buf* buf = pair.first;
+    Buf* buf = pair.first;
     if (aBounds.find(buf) == aBounds.end()) {
       continue;
     }
@@ -302,18 +302,17 @@ bool hasConflictingOverlap(
     const BoundsInfo& bBounds,
     TensorAccessKind aFilter = kMutate,
     TensorAccessKind bFilter = kMutate) {
-  using IndexBoundsInfo =
-      std::unordered_map<const Buf*, std::vector<IndexBounds>>;
+  using IndexBoundsInfo = std::unordered_map<Buf*, std::vector<IndexBounds>>;
   IndexBoundsInfo aIndexBoundsInfo;
-  for (const auto& aBound : aBounds) {
+  for (auto& aBound : aBounds) {
     aIndexBoundsInfo[aBound.first] = getIndexBounds(aBound.second, aFilter);
   }
   IndexBoundsInfo bIndexBoundsInfo;
-  for (const auto& bBound : bBounds) {
+  for (auto& bBound : bBounds) {
     bIndexBoundsInfo[bBound.first] = getIndexBounds(bBound.second, bFilter);
   }

-  for (const auto& aBound : aBounds) {
+  for (auto& aBound : aBounds) {
     auto bIt = bBounds.find(aBound.first);
     if (bIt == bBounds.end()) {
       continue;
@@ -20,12 +20,12 @@ enum C10_API_ENUM TensorAccessKind { kLoad, kStore, kMutate };
 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
 struct TORCH_API TensorAccessBoundsInfo {
   TensorAccessKind kind;
-  std::vector<const Expr*> start;
-  std::vector<const Expr*> stop;
+  std::vector<Expr*> start;
+  std::vector<Expr*> stop;
 };

 using BoundsInfo =
-    std::unordered_map<const Buf*, std::vector<TensorAccessBoundsInfo>>;
+    std::unordered_map<Buf*, std::vector<TensorAccessBoundsInfo>>;

 TORCH_API BoundsInfo inferBounds(Stmt* s, bool distinctAccessKinds = true);

@@ -42,7 +42,7 @@ TORCH_API BoundsInfo getInferredBounds(

 TORCH_API void printBoundsInfo(const BoundsInfo& v);

-TORCH_API std::vector<const Expr*> getBoundExtents(
+TORCH_API std::vector<Expr*> getBoundExtents(
     const std::vector<TensorAccessBoundsInfo>& infos);

 // The kind of dependency found, in increasing order of exclusivity.
@@ -14,8 +14,8 @@ OverlapKind boundOverlap(Bound a, Bound b) {
     return ContainedOrEqual;
   }

-  const Expr* lowDiff = IRSimplifier::simplify(new Sub(a.start, b.end));
-  const Expr* highDiff = IRSimplifier::simplify(new Sub(b.start, a.end));
+  Expr* lowDiff = IRSimplifier::simplify(new Sub(a.start, b.end));
+  Expr* highDiff = IRSimplifier::simplify(new Sub(b.start, a.end));

   if (lowDiff->isConstant() && highDiff->isConstant()) {
     int low = immediateAs<int>(lowDiff);
@@ -26,8 +26,8 @@ OverlapKind boundOverlap(Bound a, Bound b) {
     }
   }

-  const Expr* diff_start = IRSimplifier::simplify(new Sub(b.start, a.start));
-  const Expr* diff_end = IRSimplifier::simplify(new Sub(b.end, a.end));
+  Expr* diff_start = IRSimplifier::simplify(new Sub(b.start, a.start));
+  Expr* diff_end = IRSimplifier::simplify(new Sub(b.end, a.end));

   // If one side fully encloses the other, they're adjacent.
   if (diff_start->isConstant() && diff_end->isConstant()) {
@@ -122,8 +122,8 @@ std::vector<Bound> subtractBound(Bound a, Bound b, OverlapKind overlap) {
     return {a};
   }

-  const Expr* lowDiff = IRSimplifier::simplify(new Sub(b.start, a.start));
-  const Expr* highDiff = IRSimplifier::simplify(new Sub(b.end, a.end));
+  Expr* lowDiff = IRSimplifier::simplify(new Sub(b.start, a.start));
+  Expr* highDiff = IRSimplifier::simplify(new Sub(b.end, a.end));

   // If the diff has only a single var, we can try to guess sign.
   if (!lowDiff->isConstant()) {
@@ -161,8 +161,7 @@ std::vector<Bound> subtractBound(Bound a, Bound b, OverlapKind overlap) {
   }

   if (hasTail) {
-    const Expr* tailStart =
-        IRSimplifier::simplify(new Add(b.end, new IntImm(1)));
+    Expr* tailStart = IRSimplifier::simplify(new Add(b.end, new IntImm(1)));
     res.emplace_back(tailStart, a.end);
   }

@@ -13,8 +13,8 @@ namespace analysis {

 // A simple class containing the start and end of a range in a single dimension.
 struct TORCH_API Bound {
-  const Expr* start{nullptr};
-  const Expr* end{nullptr};
+  Expr* start{nullptr};
+  Expr* end{nullptr};

   // This stores whether or not the start and end of this Bound have previously
   // been swapped. This occurs when the bound is in a loop with a negative
@@ -22,7 +22,7 @@ struct TORCH_API Bound {
   bool swapped{false};

   Bound() = default;
-  Bound(const Expr* s, const Expr* e) : start(s), end(e) {}
+  Bound(Expr* s, Expr* e) : start(s), end(e) {}

   void print() const {
     std::cout << "(" << *start << ", " << *end << ")";
@@ -44,7 +44,7 @@ struct TORCH_API Bound {

 struct BoundHash {
   size_t operator()(const Bound& b) const {
-    return std::hash<const Expr*>()(b.start) ^ std::hash<const Expr*>()(b.end);
+    return std::hash<Expr*>()(b.start) ^ std::hash<Expr*>()(b.end);
   }
 };

@@ -14,7 +14,7 @@ RegisterCodeGenList::StmtFactoryMethod RegisterCodeGenList::
     oss << "Invalid stmt codegen name: " << name << ". ";
     oss << "Existing codegen names: [";
     int index = 0;
-    for (const auto& entry : stmt_factory_methods_) {
+    for (auto& entry : stmt_factory_methods_) {
       if (index != 0) {
         oss << ", ";
       }
@@ -44,7 +44,7 @@ std::unique_ptr<CodeGen> CreateCodeGen(
   return method(stmt, params, device, kernel_func_name);
 }

-const Expr* GenericIntrinsicsExpander::mutate(const Intrinsics* v) {
+Expr* GenericIntrinsicsExpander::mutate(Intrinsics* v) {
   if (v->op_type() == kSigmoid) {
     auto x = v->param(0)->accept_mutator(this);
     auto one = expr_to_vec(
@@ -108,11 +108,11 @@ class CodeGen::BufferArg {
   BufferArg(const VarHandle& var) : var_(var.node()), isVar_(true) {}
   BufferArg(const BufHandle& buf) : buf_(buf.node()) {}

-  const Var* var() const {
+  Var* var() const {
     return isVar_ ? var_ : buf_->base_handle();
   }

-  const Buf* buf() const {
+  Buf* buf() const {
     return buf_;
   }

@@ -125,8 +125,8 @@ class CodeGen::BufferArg {
   }

  private:
-  const Var* var_ = nullptr;
-  const Buf* buf_ = nullptr;
+  Var* var_ = nullptr;
+  Buf* buf_ = nullptr;
   bool isVar_ = false;
 };

@@ -226,7 +226,7 @@ TORCH_API std::unique_ptr<CodeGen> CreateCodeGen(

 class TORCH_API GenericIntrinsicsExpander : public IRMutator {
  protected:
-  const Expr* mutate(const Intrinsics* v) override;
+  Expr* mutate(Intrinsics* v) override;
 };

 } // namespace tensorexpr
@@ -4,12 +4,12 @@ namespace torch {
 namespace jit {
 namespace tensorexpr {

-void CppPrinter::visit(const Allocate* alloc) {
+void CppPrinter::visit(Allocate* alloc) {
   constexpr size_t kAllocOnStackThresholdSize = 512;

   size_t size = 1;
   for (auto dim : alloc->dims()) {
-    const IntImm* v = dynamic_cast<const IntImm*>(dim);
+    IntImm* v = dynamic_cast<IntImm*>(dim);
     if (v) {
       size *= v->value();
     } else {
@@ -30,8 +30,8 @@ void CppPrinter::visit(const Allocate* alloc) {
   }
 }

-void CppPrinter::visit(const Free* free) {
-  const Var* var = free->buffer_var();
+void CppPrinter::visit(Free* free) {
+  Var* var = free->buffer_var();
   if (allocated_on_heap_.count(var)) {
     emitIndent();
     os() << "free(" << name_manager()->get_unique_name(var) << ");"
@@ -14,11 +14,11 @@ class TORCH_API CppPrinter : public IRPrinter {
   explicit CppPrinter(std::ostream* os) : IRPrinter(*os) {}

   using IRPrinter::visit;
-  void visit(const Allocate*) override;
-  void visit(const Free*) override;
+  void visit(Allocate*) override;
+  void visit(Free*) override;

  private:
-  std::unordered_set<const Var*> allocated_on_heap_;
+  std::unordered_set<Var*> allocated_on_heap_;
 };

 } // namespace tensorexpr
@@ -21,7 +21,7 @@ namespace tensorexpr {
 // TODO: move this to a more shared place.
 class ScopedVarName {
  public:
-  ScopedVarName(VarNameMap* mapping, const Var* var, const std::string& name)
+  ScopedVarName(VarNameMap* mapping, Var* var, const std::string& name)
       : mapping_(mapping), var_(var) {
     auto iter = mapping->find(var);
     if (iter != mapping->end()) {
@@ -30,10 +30,7 @@ class ScopedVarName {
     mapping->insert(std::make_pair(var, name));
   }

-  ScopedVarName(
-      UniqueNameManager* manager,
-      const Var* var,
-      const std::string& name)
+  ScopedVarName(UniqueNameManager* manager, Var* var, const std::string& name)
       : ScopedVarName(&manager->unique_name_mapping_, var, name) {}

   ScopedVarName(const ScopedVarName&) = delete;
@@ -45,11 +42,11 @@ class ScopedVarName {

  private:
   VarNameMap* mapping_ = nullptr;
-  const Var* var_ = nullptr;
+  Var* var_ = nullptr;
 };

-static int as_int(const Expr* expr) {
-  auto v = dynamic_cast<const IntImm*>(expr);
+static int as_int(Expr* expr) {
+  auto v = dynamic_cast<IntImm*>(expr);
   if (!v) {
     throw malformed_input(
         "cuda_codegen: non Int expr interpreted as int", expr);
@@ -58,7 +55,7 @@ static int as_int(const Expr* expr) {
   return v->value();
 }

-static bool is_zero(const Expr* expr) {
+static bool is_zero(Expr* expr) {
   return as_int(expr) == 0;
 }

@@ -123,17 +120,17 @@ std::string CudaPrinter::dtypeToCppString(const Dtype& dtype) {
   }
 }

-void CudaAnalysis::visit(const Free* v) {
+void CudaAnalysis::visit(Free* v) {
   if (thread_local_bufs_.count(v->buffer_var()) == 0 &&
       cross_block_bufs_.count(v->buffer_var()) == 0) {
     throw std::runtime_error("Global free not supported yet");
   }
 }

-void CudaAnalysis::visit(const Allocate* v) {
+void CudaAnalysis::visit(Allocate* v) {
   Stmt* p = v->get_parent();
   while (p) {
-    const For* for_v = dynamic_cast<const For*>(p);
+    For* for_v = dynamic_cast<For*>(p);
     if (for_v) {
       // NOLINTNEXTLINE(bugprone-branch-clone)
       if (for_v->loop_options().is_gpu_block_index()) {
@@ -151,7 +148,7 @@ void CudaAnalysis::visit(const Allocate* v) {
   throw std::runtime_error("Global alloc not supported yet");
 }

-void CudaAnalysis::visit(const For* v) {
+void CudaAnalysis::visit(For* v) {
   // Recurse first.
   v->body()->accept(this);

@@ -161,7 +158,7 @@ void CudaAnalysis::visit(const For* v) {
     if (gpu_block_index >= 3) {
       throw std::runtime_error("support only 3D gpu_block_index");
     }
-    const Expr* prev = nullptr;
+    Expr* prev = nullptr;
     // NOLINTNEXTLINE(clang-diagnostic-sign-compare)
     // NOLINTNEXTLINE(bugprone-branch-clone)
     if (gpu_block_extents_.size() <= gpu_block_index) {
@@ -191,7 +188,7 @@ void CudaAnalysis::visit(const For* v) {
     if (gpu_thread_index >= 3) {
       throw std::runtime_error("support only 3D gpu_thread_index");
     }
-    const Expr* prev = nullptr;
+    Expr* prev = nullptr;
     // NOLINTNEXTLINE(clang-diagnostic-sign-compare)
     // NOLINTNEXTLINE(bugprone-branch-clone)
     if (gpu_thread_extents_.size() <= gpu_thread_index) {
@@ -219,13 +216,13 @@ void CudaAnalysis::visit(const For* v) {
   }
 }

-void CudaPrinter::print_flat_alloc(const Allocate* alloc) {
+void CudaPrinter::print_flat_alloc(Allocate* alloc) {
   // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-  std::vector<const Expr*> dims = alloc->dims();
+  std::vector<Expr*> dims = alloc->dims();
   // TODO: this should be merged with the storage flattener.
   int64_t flat_size = 1;
   for (auto dim : dims) {
-    const IntImm* dim_i = dynamic_cast<const IntImm*>(dim);
+    IntImm* dim_i = dynamic_cast<IntImm*>(dim);
     if (dim_i) {
       flat_size *= dim_i->value();
     } else {
@@ -236,7 +233,7 @@ void CudaPrinter::print_flat_alloc(const Allocate* alloc) {
       << "[" << flat_size << "];" << std::endl;
 }

-void CudaPrinter::visit(const Allocate* v) {
+void CudaPrinter::visit(Allocate* v) {
   // TODO: handle dynamic shapes here.
   if (cuda_analysis_->cross_block_bufs().count(v->buffer_var()) != 0) {
     emitIndent();
@@ -254,15 +251,15 @@ void CudaPrinter::visit(const Allocate* v) {
   throw std::runtime_error("Encountered Alloc not local to block or thread");
 }

-void CudaPrinter::visit(const Free* v) {
+void CudaPrinter::visit(Free* v) {
   // do nothing
 }

-void CudaPrinter::visit(const For* v) {
+void CudaPrinter::visit(For* v) {
   IRPrinter::visit(v);
 }

-void CudaPrinter::visit(const Cast* v) {
+void CudaPrinter::visit(Cast* v) {
   if (v->dtype().scalar_type() == ScalarType::Half) {
     os() << "__float2half(";
     v->src_value()->accept(this);
@@ -281,7 +278,7 @@ void CudaPrinter::visit(const Cast* v) {
   os() << ")";
 }

-void CudaPrinter::visit(const Intrinsics* v) {
+void CudaPrinter::visit(Intrinsics* v) {
   if (v->op_type() == IntrinsicsOp::kRand) {
     os() << "Uint32ToFloat(" << *rand_func_ << "())";
     return;
@@ -308,7 +305,7 @@ void CudaPrinter::visit(const Intrinsics* v) {
   }

   os() << func_name << "(";
-  for (const auto i : c10::irange(v->nparams())) {
+  for (auto i : c10::irange(v->nparams())) {
     if (i > 0) {
       os() << ", ";
     }
@@ -317,11 +314,11 @@ void CudaPrinter::visit(const Intrinsics* v) {
   os() << ")";
 }

-void CudaPrinter::visit(const ExternalCall* v) {
+void CudaPrinter::visit(ExternalCall* v) {
   throw unimplemented_lowering(v);
 }

-void CudaPrinter::visit(const Load* v) {
+void CudaPrinter::visit(Load* v) {
   // TODO: find a better metric in using ldg or not. Support different dtypes.
   // Detects whether the load target is also a store target.
   // TODO: this is currently too wide. It detects whether a store-target
@@ -348,7 +345,7 @@ void CudaPrinter::visit(const Load* v) {
 // TODO: maybe this should be a more shared location?
 // TODO: investigate how "Expr*" can be implicitly converted to "ExprHandle" as
 // a bool.
-static bool CheckEqual(const Expr* lhs, const Expr* rhs) {
+static bool CheckEqual(Expr* lhs, Expr* rhs) {
   // The fast path. Checks if the pointers are the same.
   if (lhs == rhs) {
     return true;
@@ -362,12 +359,11 @@ class AtomicAddFuser : public IRMutator {
  public:
   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
   AtomicAddFuser(
-      const std::unordered_set<const Var*>& thread_local_bufs,
+      const std::unordered_set<Var*>& thread_local_bufs,
       const GPUMetaVarRewriter& metavars)
       : thread_local_bufs_(thread_local_bufs) {
-    const std::vector<const Expr*>& block_extents =
-        metavars.gpu_block_extents();
-    const std::vector<const Var*>& block_vars = metavars.gpu_block_vars();
+    const std::vector<Expr*>& block_extents = metavars.gpu_block_extents();
+    const std::vector<Var*>& block_vars = metavars.gpu_block_vars();
     for (size_t i = 0; i < block_extents.size(); ++i) {
|
||||
MetaVarExtent extent{block_extents[i], false};
|
||||
if (extent.expr->isConstant() && immediateEquals(extent.expr, 1)) {
|
||||
@ -378,9 +374,8 @@ class AtomicAddFuser : public IRMutator {
|
||||
metavars_[block_vars[i]] = extent;
|
||||
}
|
||||
|
||||
const std::vector<const Expr*>& thread_extents =
|
||||
metavars.gpu_thread_extents();
|
||||
const std::vector<const Var*>& thread_vars = metavars.gpu_thread_vars();
|
||||
const std::vector<Expr*>& thread_extents = metavars.gpu_thread_extents();
|
||||
const std::vector<Var*>& thread_vars = metavars.gpu_thread_vars();
|
||||
for (size_t i = 0; i < thread_extents.size(); ++i) {
|
||||
MetaVarExtent extent{thread_extents[i], false};
|
||||
if (extent.expr->isConstant() && immediateEquals(extent.expr, 1)) {
|
||||
@ -392,8 +387,8 @@ class AtomicAddFuser : public IRMutator {
|
||||
}
|
||||
}
|
||||
|
||||
Stmt* mutate(const Store* v) override {
|
||||
const Buf* buf = v->buf();
|
||||
Stmt* mutate(Store* v) override {
|
||||
Buf* buf = v->buf();
|
||||
Store* orig = const_cast<Store*>(v); // NOLINT
|
||||
|
||||
// Thread locals never need to be atomic.
|
||||
@ -405,11 +400,11 @@ class AtomicAddFuser : public IRMutator {
|
||||
if (dtype != ScalarType::Float && dtype != ScalarType::Double) {
|
||||
return orig;
|
||||
}
|
||||
const Add* add_v = dynamic_cast<const Add*>(v->value());
|
||||
Add* add_v = dynamic_cast<Add*>(v->value());
|
||||
if (!add_v) {
|
||||
return orig;
|
||||
}
|
||||
const Load* load_v = dynamic_cast<const Load*>(add_v->lhs());
|
||||
Load* load_v = dynamic_cast<Load*>(add_v->lhs());
|
||||
if (!load_v) {
|
||||
return orig;
|
||||
}
|
||||
@ -427,9 +422,9 @@ class AtomicAddFuser : public IRMutator {
|
||||
// TODO: this checks that the metavars occur directly as an index, but this
|
||||
// is pessimistic, blockIdx.x + 1 is fine too if there is no overlapping.
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||
std::unordered_set<const Var*> vars_to_find = nontrivial_metavars_;
|
||||
for (const Expr* e : v->indices()) {
|
||||
if (const Var* v = dynamic_cast<const Var*>(e)) {
|
||||
std::unordered_set<Var*> vars_to_find = nontrivial_metavars_;
|
||||
for (Expr* e : v->indices()) {
|
||||
if (Var* v = dynamic_cast<Var*>(e)) {
|
||||
vars_to_find.erase(v);
|
||||
}
|
||||
}
|
||||
@ -443,16 +438,16 @@ class AtomicAddFuser : public IRMutator {
|
||||
}
|
||||
|
||||
private:
|
||||
const std::unordered_set<const Var*>& thread_local_bufs_;
|
||||
const std::unordered_set<Var*>& thread_local_bufs_;
|
||||
struct MetaVarExtent {
|
||||
const Expr* expr{nullptr};
|
||||
Expr* expr{nullptr};
|
||||
bool trivial{false};
|
||||
};
|
||||
std::unordered_map<const Var*, MetaVarExtent> metavars_;
|
||||
std::unordered_set<const Var*> nontrivial_metavars_;
|
||||
std::unordered_map<Var*, MetaVarExtent> metavars_;
|
||||
std::unordered_set<Var*> nontrivial_metavars_;
|
||||
};
|
||||
|
||||
void CudaPrinter::visit(const Store* v) {
|
||||
void CudaPrinter::visit(Store* v) {
|
||||
emitIndent();
|
||||
if (v->indices().empty()) {
|
||||
os() << *v->base_handle() << " = ";
|
||||
@ -463,7 +458,7 @@ void CudaPrinter::visit(const Store* v) {
|
||||
os() << std::endl;
|
||||
}
|
||||
|
||||
void CudaPrinter::visit(const AtomicAdd* v) {
|
||||
void CudaPrinter::visit(AtomicAdd* v) {
|
||||
emitIndent();
|
||||
if (cuda_analysis_->thread_local_bufs().count(v->base_handle()) > 0) {
|
||||
// atomicAdd only works on global and shared memory
|
||||
@ -476,7 +471,7 @@ void CudaPrinter::visit(const AtomicAdd* v) {
|
||||
os() << std::endl;
|
||||
}
|
||||
|
||||
void CudaPrinter::visit(const Max* v) {
|
||||
void CudaPrinter::visit(Max* v) {
|
||||
if (v->dtype().is_integral()) {
|
||||
os() << "max(";
|
||||
} else {
|
||||
@ -488,7 +483,7 @@ void CudaPrinter::visit(const Max* v) {
|
||||
os() << ")";
|
||||
}
|
||||
|
||||
void CudaPrinter::visit(const Min* v) {
|
||||
void CudaPrinter::visit(Min* v) {
|
||||
if (v->dtype().is_integral()) {
|
||||
os() << "min(";
|
||||
} else {
|
||||
@ -500,7 +495,7 @@ void CudaPrinter::visit(const Min* v) {
|
||||
os() << ")";
|
||||
}
|
||||
|
||||
void CudaPrinter::visit(const IfThenElse* v) {
|
||||
void CudaPrinter::visit(IfThenElse* v) {
|
||||
os() << "((";
|
||||
v->condition()->accept(this);
|
||||
os() << ") ? ";
|
||||
@ -510,7 +505,7 @@ void CudaPrinter::visit(const IfThenElse* v) {
|
||||
os() << ")";
|
||||
}
|
||||
|
||||
void CudaPrinter::visit(const Block* v) {
|
||||
void CudaPrinter::visit(Block* v) {
|
||||
os() << "{" << std::endl;
|
||||
indent_++;
|
||||
|
||||
@ -523,7 +518,7 @@ void CudaPrinter::visit(const Block* v) {
|
||||
os() << "}";
|
||||
}
|
||||
|
||||
void CudaPrinter::visit(const Let* v) {
|
||||
void CudaPrinter::visit(Let* v) {
|
||||
emitIndent();
|
||||
os() << dtypeToCppString(v->dtype());
|
||||
os() << " " << *v->var() << " = ";
|
||||
@ -534,7 +529,7 @@ void CudaPrinter::visit(const Let* v) {
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
|
||||
class PrioritizeLoad : public IRMutator {
|
||||
public:
|
||||
const Expr* mutate(const Load* v) override {
|
||||
Expr* mutate(Load* v) override {
|
||||
// Look at the declaration of this variable for more details.
|
||||
if (nested_if_then_else_ > 0) {
|
||||
return IRMutator::mutate(v);
|
||||
@ -569,17 +564,17 @@ class PrioritizeLoad : public IRMutator {
|
||||
}
|
||||
|
||||
MemLoadList& load_list = load_stack_.back();
|
||||
const Var* load_new_var = new Var("v", v->dtype());
|
||||
const Expr* new_value = IRMutator::mutate(v);
|
||||
Var* load_new_var = new Var("v", v->dtype());
|
||||
Expr* new_value = IRMutator::mutate(v);
|
||||
load_list.push_back(std::make_pair(load_new_var, new_value));
|
||||
|
||||
return load_new_var;
|
||||
}
|
||||
|
||||
const Expr* mutate(const Cast* v) override {
|
||||
const Load* src_load = dynamic_cast<const Load*>(v->src_value());
|
||||
const Expr* new_src = v->src_value()->accept_mutator(this);
|
||||
const Var* new_var = dynamic_cast<const Var*>(new_src);
|
||||
Expr* mutate(Cast* v) override {
|
||||
Load* src_load = dynamic_cast<Load*>(v->src_value());
|
||||
Expr* new_src = v->src_value()->accept_mutator(this);
|
||||
Var* new_var = dynamic_cast<Var*>(new_src);
|
||||
if (!src_load || !new_var) {
|
||||
return new Cast(v->dtype(), new_src);
|
||||
}
|
||||
@ -593,27 +588,27 @@ class PrioritizeLoad : public IRMutator {
|
||||
|
||||
new_var = new Var("v", v->dtype());
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||
const Expr* new_value = new Cast(v->dtype(), pair.second);
|
||||
Expr* new_value = new Cast(v->dtype(), pair.second);
|
||||
load_list.push_back(std::make_pair(new_var, new_value));
|
||||
return new_var;
|
||||
}
|
||||
|
||||
Stmt* mutate(const Store* v) override {
|
||||
const Store* last = nested_store_;
|
||||
Stmt* mutate(Store* v) override {
|
||||
Store* last = nested_store_;
|
||||
nested_store_ = v;
|
||||
Stmt* s = IRMutator::mutate(v);
|
||||
nested_store_ = last;
|
||||
return s;
|
||||
}
|
||||
|
||||
Stmt* mutate(const Let* v) override {
|
||||
Stmt* mutate(Let* v) override {
|
||||
nested_let_ = true;
|
||||
Stmt* s = IRMutator::mutate(v);
|
||||
nested_let_ = false;
|
||||
return s;
|
||||
}
|
||||
|
||||
Stmt* mutate(const Block* v) override {
|
||||
Stmt* mutate(Block* v) override {
|
||||
Block* v1 = const_cast<Block*>(v); // NOLINT
|
||||
assert(v1);
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||
@ -633,15 +628,15 @@ class PrioritizeLoad : public IRMutator {
|
||||
return v1;
|
||||
}
|
||||
|
||||
const Expr* mutate(const IfThenElse* v) override {
|
||||
Expr* mutate(IfThenElse* v) override {
|
||||
nested_if_then_else_++;
|
||||
const Expr* new_v = IRMutator::mutate(v);
|
||||
Expr* new_v = IRMutator::mutate(v);
|
||||
nested_if_then_else_--;
|
||||
return new_v;
|
||||
}
|
||||
|
||||
private:
|
||||
using MemLoadEntry = std::pair<const Var*, const Expr*>;
|
||||
using MemLoadEntry = std::pair<Var*, Expr*>;
|
||||
using MemLoadList = std::vector<MemLoadEntry>;
|
||||
using MemoryLoadStack = std::vector<MemLoadList>;
|
||||
|
||||
@ -659,7 +654,7 @@ class PrioritizeLoad : public IRMutator {
|
||||
return;
|
||||
}
|
||||
|
||||
for (const auto& pair : load_list) {
|
||||
for (auto& pair : load_list) {
|
||||
Stmt* news = new Let(pair.first, pair.second);
|
||||
block->insert_stmt_before(news, last);
|
||||
}
|
||||
@ -678,9 +673,9 @@ class PrioritizeLoad : public IRMutator {
|
||||
// }
|
||||
// int v2 = v + 2;
|
||||
int nested_if_then_else_{0};
|
||||
const Store* nested_store_{nullptr};
|
||||
Store* nested_store_{nullptr};
|
||||
bool nested_let_{false};
|
||||
std::unordered_set<const Var*> thread_local_bufs_;
|
||||
std::unordered_set<Var*> thread_local_bufs_;
|
||||
};
|
||||
|
||||
std::string CudaCodeGen::GetUniqueFuncName(const std::string& func_prefix) {
|
||||
@ -716,9 +711,9 @@ bool GPUMetaVarRewriter::isFullExtent() {
|
||||
return true;
|
||||
}
|
||||
|
||||
Stmt* GPUMetaVarRewriter::mutate(const For* v) {
|
||||
Stmt* GPUMetaVarRewriter::mutate(For* v) {
|
||||
Stmt* body = v->body();
|
||||
const Expr* old_reach = nullptr;
|
||||
Expr* old_reach = nullptr;
|
||||
const LoopOptions& loop_options = v->loop_options();
|
||||
if (loop_options.is_gpu_block_index()) {
|
||||
int gpu_block_index = loop_options.gpu_block_index();
|
||||
@ -737,7 +732,7 @@ Stmt* GPUMetaVarRewriter::mutate(const For* v) {
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||
const Var* metaVar = gpu_block_vars_[gpu_block_index];
|
||||
Var* metaVar = gpu_block_vars_[gpu_block_index];
|
||||
body = Substitute(Stmt::clone(body), {{v->var(), metaVar}});
|
||||
} else if (loop_options.is_gpu_thread_index()) {
|
||||
int gpu_thread_index = loop_options.gpu_thread_index();
|
||||
@ -756,7 +751,7 @@ Stmt* GPUMetaVarRewriter::mutate(const For* v) {
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||
const Var* metaVar = gpu_thread_vars_[gpu_thread_index];
|
||||
Var* metaVar = gpu_thread_vars_[gpu_thread_index];
|
||||
body = Substitute(Stmt::clone(body), {{v->var(), metaVar}});
|
||||
}
|
||||
|
||||
@ -776,7 +771,7 @@ Stmt* GPUMetaVarRewriter::mutate(const For* v) {
|
||||
return v->cloneWithNewBody(body);
|
||||
}
|
||||
|
||||
Stmt* GPUMetaVarRewriter::mutate(const Block* v) {
|
||||
Stmt* GPUMetaVarRewriter::mutate(Block* v) {
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||
std::vector<Segment> innerSegments;
|
||||
Segment current;
|
||||
@ -891,7 +886,7 @@ Stmt* GPUMetaVarRewriter::mutate(const Block* v) {
|
||||
|
||||
static std::ostream& operator<<(
|
||||
std::ostream& out,
|
||||
const std::vector<const Expr*>& exprs) {
|
||||
const std::vector<Expr*>& exprs) {
|
||||
size_t i = 0;
|
||||
for (auto expr : exprs) {
|
||||
if (i++ > 0) {
|
||||
@ -983,7 +978,7 @@ void CudaCodeGen::Initialize() {
|
||||
os() << ", ";
|
||||
}
|
||||
const BufferArg& buffer_arg = buffer_args[i];
|
||||
const Var* var = buffer_arg.var();
|
||||
Var* var = buffer_arg.var();
|
||||
Dtype dtype = buffer_arg.dtype();
|
||||
|
||||
os() << printer_->dtypeToCppString(dtype)
|
||||
@ -991,9 +986,9 @@ void CudaCodeGen::Initialize() {
|
||||
<< name_manager()->get_unique_name(var);
|
||||
}
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||
const Var* rand_seed;
|
||||
Var* rand_seed;
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||
const Var* rand_offset;
|
||||
Var* rand_offset;
|
||||
if (has_random_) {
|
||||
// TODO: switch to kUint64 when it is available.
|
||||
rand_seed = new Var("rand_seed", kInt);
|
||||
@ -1006,11 +1001,11 @@ void CudaCodeGen::Initialize() {
|
||||
os() << std::endl;
|
||||
|
||||
if (has_random_) {
|
||||
const Var* idx = new Var("idx", kInt);
|
||||
Var* idx = new Var("idx", kInt);
|
||||
os() << "int " << *idx << " = blockIdx.x*blockDim.x + threadIdx.x;"
|
||||
<< std::endl;
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||
const Var* rand_func = printer_->rand_func();
|
||||
Var* rand_func = printer_->rand_func();
|
||||
os() << "Philox " << *rand_func << "(" << *rand_seed << ", " << *idx << ", "
|
||||
<< *rand_offset << ");" << std::endl;
|
||||
os() << std::endl;
|
||||
@ -1041,7 +1036,7 @@ void CudaCodeGen::Initialize() {
|
||||
os() << "}";
|
||||
|
||||
// Check that all block extents had been set.
|
||||
const std::vector<const Expr*>& gpu_block_extents =
|
||||
const std::vector<Expr*>& gpu_block_extents =
|
||||
metavar_rewriter_->gpu_block_extents();
|
||||
for (size_t i = 0; i < gpu_block_extents.size(); i++) {
|
||||
if (!gpu_block_extents[i]) {
|
||||
@ -1067,9 +1062,9 @@ void CudaCodeGen::call_raw(const std::vector<void*>& raw_args) {
|
||||
auto const& buffer_args = this->buffer_args();
|
||||
|
||||
// TODO: move as much of this into the constructors.
|
||||
const std::vector<const Expr*>& gpu_block_extents =
|
||||
const std::vector<Expr*>& gpu_block_extents =
|
||||
metavar_rewriter_->gpu_block_extents();
|
||||
const std::vector<const Expr*>& gpu_thread_extents =
|
||||
const std::vector<Expr*>& gpu_thread_extents =
|
||||
metavar_rewriter_->gpu_thread_extents();
|
||||
if (gpu_block_extents.size() > 3 || gpu_thread_extents.size() > 3) {
|
||||
throw malformed_input(
|
||||
@ -1148,7 +1143,7 @@ void CudaCodeGen::call_raw(const std::vector<void*>& raw_args) {
|
||||
ptr_to_args[buffer_args.size() + 1] = &rand_offset;
|
||||
}
|
||||
|
||||
const auto prior_device = at::cuda::current_device();
|
||||
auto prior_device = at::cuda::current_device();
|
||||
if (prior_device != this->device().index()) {
|
||||
at::cuda::set_device(this->device().index());
|
||||
}
|
||||
@ -1207,7 +1202,7 @@ void CudaCodeGen::CompileToNVRTC(
|
||||
AT_CUDA_DRIVER_CHECK(nvrtc().cuCtxGetCurrent(&pctx));
|
||||
// Note: hacked at::DeviceGuard since at::DeviceGuard was failing to work
|
||||
// properly in some scenarios
|
||||
const auto prior_device = at::cuda::current_device();
|
||||
auto prior_device = at::cuda::current_device();
|
||||
if (prior_device != this->device().index()) {
|
||||
at::cuda::set_device(this->device().index());
|
||||
}
|
||||
@ -1259,8 +1254,7 @@ void CudaCodeGen::CompileToNVRTC(
|
||||
"--std=c++14", compute.c_str(), "-default-device"};
|
||||
#endif
|
||||
|
||||
const auto result =
|
||||
nvrtc().nvrtcCompileProgram(program, args.size(), args.data());
|
||||
auto result = nvrtc().nvrtcCompileProgram(program, args.size(), args.data());
|
||||
if (result != NVRTC_SUCCESS) {
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||
size_t logsize;
|
||||
@ -1284,15 +1278,14 @@ void CudaCodeGen::CompileToNVRTC(
|
||||
#if CUDA_VERSION >= 11010
|
||||
// compile_to_sass determines whether we are generating SASS or PTX, hence
|
||||
// the different API.
|
||||
const auto getSize = compile_to_sass
|
||||
auto getSize = compile_to_sass
|
||||
? at::globalContext().getNVRTC().nvrtcGetCUBINSize
|
||||
: at::globalContext().getNVRTC().nvrtcGetPTXSize;
|
||||
const auto getFunc = compile_to_sass
|
||||
? at::globalContext().getNVRTC().nvrtcGetCUBIN
|
||||
: at::globalContext().getNVRTC().nvrtcGetPTX;
|
||||
auto getFunc = compile_to_sass ? at::globalContext().getNVRTC().nvrtcGetCUBIN
|
||||
: at::globalContext().getNVRTC().nvrtcGetPTX;
|
||||
#else
|
||||
const auto getSize = at::globalContext().getNVRTC().nvrtcGetPTXSize;
|
||||
const auto getFunc = at::globalContext().getNVRTC().nvrtcGetPTX;
|
||||
auto getSize = at::globalContext().getNVRTC().nvrtcGetPTXSize;
|
||||
auto getFunc = at::globalContext().getNVRTC().nvrtcGetPTX;
|
||||
#endif
|
||||
AT_CUDA_NVRTC_CHECK(getSize(program, &ptx_size));
|
||||
ptx.resize(ptx_size);
|
||||
|
@ -26,41 +26,41 @@ class CudaAnalysis : public IRVisitor {
|
||||
gpu_block_extents_ = {new IntImm(1), new IntImm(1), new IntImm(1)};
|
||||
gpu_thread_extents_ = {new IntImm(1), new IntImm(1), new IntImm(1)};
|
||||
}
|
||||
bool is_buf_store_target(const Buf* buf) const {
|
||||
bool is_buf_store_target(Buf* buf) const {
|
||||
return store_targets_.count(buf) > 0;
|
||||
}
|
||||
|
||||
const std::unordered_set<const Var*>& thread_local_bufs() const {
|
||||
const std::unordered_set<Var*>& thread_local_bufs() const {
|
||||
return thread_local_bufs_;
|
||||
}
|
||||
|
||||
const std::unordered_set<const Var*>& cross_block_bufs() const {
|
||||
const std::unordered_set<Var*>& cross_block_bufs() const {
|
||||
return cross_block_bufs_;
|
||||
}
|
||||
|
||||
const std::vector<const Expr*>& gpu_block_extents() const {
|
||||
const std::vector<Expr*>& gpu_block_extents() const {
|
||||
return gpu_block_extents_;
|
||||
}
|
||||
|
||||
const std::vector<const Expr*>& gpu_thread_extents() const {
|
||||
const std::vector<Expr*>& gpu_thread_extents() const {
|
||||
return gpu_thread_extents_;
|
||||
}
|
||||
|
||||
private:
|
||||
void visit(const Store* v) override {
|
||||
void visit(Store* v) override {
|
||||
store_targets_.insert(v->buf());
|
||||
}
|
||||
|
||||
void visit(const Allocate* v) override;
|
||||
void visit(const Free* v) override;
|
||||
void visit(const For* v) override;
|
||||
void visit(Allocate* v) override;
|
||||
void visit(Free* v) override;
|
||||
void visit(For* v) override;
|
||||
|
||||
std::unordered_set<const Buf*> store_targets_;
|
||||
std::unordered_set<const Var*> thread_local_bufs_;
|
||||
std::unordered_set<const Var*> cross_block_bufs_;
|
||||
std::unordered_set<Buf*> store_targets_;
|
||||
std::unordered_set<Var*> thread_local_bufs_;
|
||||
std::unordered_set<Var*> cross_block_bufs_;
|
||||
|
||||
std::vector<const Expr*> gpu_block_extents_;
|
||||
std::vector<const Expr*> gpu_thread_extents_;
|
||||
std::vector<Expr*> gpu_block_extents_;
|
||||
std::vector<Expr*> gpu_thread_extents_;
|
||||
};
|
||||
|
||||
// An IRMutator that replaces binding loop options with Cuda metavars, and masks
|
||||
@ -87,22 +87,22 @@ class GPUMetaVarRewriter : public IRMutator {
|
||||
current_thread_reach_ = {new IntImm(1), new IntImm(1), new IntImm(1)};
|
||||
}
|
||||
|
||||
Stmt* mutate(const For* v) override;
|
||||
Stmt* mutate(const Block* v) override;
|
||||
Stmt* mutate(For* v) override;
|
||||
Stmt* mutate(Block* v) override;
|
||||
|
||||
const std::vector<const Var*>& gpu_block_vars() const {
|
||||
const std::vector<Var*>& gpu_block_vars() const {
|
||||
return gpu_block_vars_;
|
||||
}
|
||||
|
||||
const std::vector<const Var*>& gpu_thread_vars() const {
|
||||
const std::vector<Var*>& gpu_thread_vars() const {
|
||||
return gpu_thread_vars_;
|
||||
}
|
||||
|
||||
const std::vector<const Expr*>& gpu_block_extents() const {
|
||||
const std::vector<Expr*>& gpu_block_extents() const {
|
||||
return cuda_analysis_->gpu_block_extents();
|
||||
}
|
||||
|
||||
const std::vector<const Expr*>& gpu_thread_extents() const {
|
||||
const std::vector<Expr*>& gpu_thread_extents() const {
|
||||
return cuda_analysis_->gpu_thread_extents();
|
||||
}
|
||||
|
||||
@ -136,11 +136,11 @@ class GPUMetaVarRewriter : public IRMutator {
|
||||
// parameters.
|
||||
bool isFullExtent();
|
||||
|
||||
std::vector<const Var*> gpu_block_vars_;
|
||||
std::vector<const Var*> gpu_thread_vars_;
|
||||
std::vector<Var*> gpu_block_vars_;
|
||||
std::vector<Var*> gpu_thread_vars_;
|
||||
|
||||
std::vector<const Expr*> current_block_reach_;
|
||||
std::vector<const Expr*> current_thread_reach_;
|
||||
std::vector<Expr*> current_block_reach_;
|
||||
std::vector<Expr*> current_thread_reach_;
|
||||
|
||||
const CudaAnalysis* cuda_analysis_;
|
||||
};
|
||||
@ -158,24 +158,24 @@ class CudaPrinter : public IRPrinter {
|
||||
}
|
||||
}
|
||||
|
||||
void visit(const Cast* v) override;
|
||||
void visit(const Intrinsics* v) override;
|
||||
void visit(const For* v) override;
|
||||
void visit(Cast* v) override;
|
||||
void visit(Intrinsics* v) override;
|
||||
void visit(For* v) override;
|
||||
|
||||
void visit(const Load* v) override;
|
||||
void visit(const Store* v) override;
|
||||
void visit(const AtomicAdd* v) override;
|
||||
void visit(const Max* v) override;
|
||||
void visit(const Min* v) override;
|
||||
void visit(const IfThenElse* v) override;
|
||||
void visit(const Block* v) override;
|
||||
void visit(const Allocate* v) override;
|
||||
void visit(const Free* v) override;
|
||||
void visit(const Let* v) override;
|
||||
void visit(Load* v) override;
|
||||
void visit(Store* v) override;
|
||||
void visit(AtomicAdd* v) override;
|
||||
void visit(Max* v) override;
|
||||
void visit(Min* v) override;
|
||||
void visit(IfThenElse* v) override;
|
||||
void visit(Block* v) override;
|
||||
void visit(Allocate* v) override;
|
||||
void visit(Free* v) override;
|
||||
void visit(Let* v) override;
|
||||
|
||||
void visit(const ExternalCall* v) override;
|
||||
void visit(ExternalCall* v) override;
|
||||
|
||||
const Var* rand_func() const {
|
||||
Var* rand_func() const {
|
||||
return rand_func_;
|
||||
}
|
||||
|
||||
@ -185,10 +185,10 @@ class CudaPrinter : public IRPrinter {
|
||||
using IRPrinter::visit;
|
||||
|
||||
private:
|
||||
const Var* rand_func_;
|
||||
Var* rand_func_;
|
||||
const CudaAnalysis* cuda_analysis_;
|
||||
|
||||
void print_flat_alloc(const Allocate* alloc);
|
||||
void print_flat_alloc(Allocate* alloc);
|
||||
};
|
||||
|
||||
// Construct Cuda C from the buffer and tensor input, and invoke the kernel
|
||||
@ -233,11 +233,11 @@ class TORCH_CUDA_CU_API CudaCodeGen : public CodeGen {
|
||||
c10::optional<c10::Device> device_opt,
|
||||
c10::optional<bool> pin_memory_opt) override;
|
||||
|
||||
const std::vector<const Expr*>& gpu_block_extents() const {
|
||||
const std::vector<Expr*>& gpu_block_extents() const {
|
||||
return cuda_analysis_->gpu_block_extents();
|
||||
}
|
||||
|
||||
const std::vector<const Expr*>& gpu_thread_extents() const {
|
||||
const std::vector<Expr*>& gpu_thread_extents() const {
|
||||
return cuda_analysis_->gpu_thread_extents();
|
||||
}
|
||||
|
||||
|
@ -59,14 +59,14 @@ class SimpleIREvaluatorImpl : public IRVisitor {
|
||||
|
||||
~SimpleIREvaluatorImpl() override = default;
|
||||
|
||||
void bindBuf(const Buf* buf, void* ptr) {
|
||||
void bindBuf(Buf* buf, void* ptr) {
|
||||
buffer_mapping_[buf] = ptr;
|
||||
}
|
||||
void bindVar(const Var* var, const Value& val) {
|
||||
void bindVar(Var* var, const Value& val) {
|
||||
eval_context_[var] = val;
|
||||
}
|
||||
|
||||
Value evaluateExpr(const Expr* e) {
|
||||
Value evaluateExpr(Expr* e) {
|
||||
e->accept(this);
|
||||
return value_;
|
||||
}
|
||||
@ -81,45 +81,45 @@ class SimpleIREvaluatorImpl : public IRVisitor {
|
||||
internal_buffers_.clear();
|
||||
}
|
||||
|
||||
TORCH_API void visit(const Add* v) override {
|
||||
TORCH_API void visit(Add* v) override {
|
||||
visit_binary_op(v);
|
||||
}
|
||||
TORCH_API void visit(const Sub* v) override {
|
||||
TORCH_API void visit(Sub* v) override {
|
||||
visit_binary_op(v);
|
||||
}
|
||||
TORCH_API void visit(const Mul* v) override {
|
||||
TORCH_API void visit(Mul* v) override {
|
||||
visit_binary_op(v);
|
||||
}
|
||||
TORCH_API void visit(const Div* v) override {
|
||||
TORCH_API void visit(Div* v) override {
|
||||
visit_binary_op(v);
|
||||
}
|
||||
TORCH_API void visit(const Mod* v) override {
|
||||
TORCH_API void visit(Mod* v) override {
|
||||
visit_binary_op(v);
|
||||
}
|
||||
TORCH_API void visit(const Max* v) override {
|
||||
TORCH_API void visit(Max* v) override {
|
||||
visit_binary_op(v, v->propagate_nans());
|
||||
}
|
||||
TORCH_API void visit(const Min* v) override {
|
||||
TORCH_API void visit(Min* v) override {
|
||||
visit_binary_op(v, v->propagate_nans());
|
||||
}
|
||||
|
||||
TORCH_API void visit(const And* v) override {
|
||||
TORCH_API void visit(And* v) override {
|
||||
visit_binary_op(v);
|
||||
}
|
||||
TORCH_API void visit(const Or* v) override {
|
||||
TORCH_API void visit(Or* v) override {
|
||||
visit_binary_op(v);
|
||||
}
|
||||
TORCH_API void visit(const Xor* v) override {
|
||||
TORCH_API void visit(Xor* v) override {
|
||||
visit_binary_op(v);
|
||||
}
|
||||
TORCH_API void visit(const Lshift* v) override {
|
||||
TORCH_API void visit(Lshift* v) override {
|
||||
visit_binary_op(v);
|
||||
}
|
||||
TORCH_API void visit(const Rshift* v) override {
|
||||
TORCH_API void visit(Rshift* v) override {
|
||||
visit_binary_op(v);
|
||||
}
|
||||
|
||||
void visit(const CompareSelect* v) override {
|
||||
void visit(CompareSelect* v) override {
|
||||
visit_compare_select_op(v, v->compare_select_op());
|
||||
}
|
||||
|
||||
@ -156,7 +156,7 @@ class SimpleIREvaluatorImpl : public IRVisitor {
|
||||
std::vector<T> lhs_v = lhs.as_vec<T>();
|
||||
std::vector<T> rhs_v = rhs.as_vec<T>();
|
||||
std::vector<T> result_v(lhs_v.size());
|
||||
for (const auto i : c10::irange(lhs_v.size())) {
|
||||
for (auto i : c10::irange(lhs_v.size())) {
|
||||
switch (op_type) {
|
||||
case IRNodeType::kAdd:
|
||||
result_v[i] = lhs_v[i] + rhs_v[i];
|
||||
@ -195,7 +195,7 @@ class SimpleIREvaluatorImpl : public IRVisitor {
|
||||
std::vector<T> lhs_v = lhs.as_vec<T>();
|
||||
std::vector<T> rhs_v = rhs.as_vec<T>();
|
||||
std::vector<T> result_v(lhs_v.size());
|
||||
for (const auto i : c10::irange(lhs_v.size())) {
|
||||
for (auto i : c10::irange(lhs_v.size())) {
|
||||
switch (op_type) {
|
||||
case IRNodeType::kAnd:
|
||||
result_v[i] = lhs_v[i] & rhs_v[i];
|
||||
@ -222,7 +222,7 @@ class SimpleIREvaluatorImpl : public IRVisitor {
|
||||
std::vector<T> lhs_v = lhs.as_vec<T>();
|
||||
std::vector<T> rhs_v = rhs.as_vec<T>();
|
||||
std::vector<T> result_v(lhs_v.size());
|
||||
for (const auto i : c10::irange(lhs_v.size())) {
|
||||
for (auto i : c10::irange(lhs_v.size())) {
|
||||
switch (op_type) {
|
||||
case IRNodeType::kLshift: {
|
||||
typename std::make_unsigned<T>::type a =
|
||||
@ -253,7 +253,7 @@ class SimpleIREvaluatorImpl : public IRVisitor {
|
||||
std::vector<R> ret_val1_v = retval1.as_vec<R>();
|
||||
std::vector<R> ret_val2_v = retval2.as_vec<R>();
|
||||
std::vector<R> result_v(lhs_v.size());
|
||||
for (const auto i : c10::irange(lhs_v.size())) {
|
||||
for (auto i : c10::irange(lhs_v.size())) {
|
||||
switch (cmp_op) {
|
||||
case CompareSelectOperation::kEQ:
|
||||
result_v[i] = (lhs_v[i] == rhs_v[i]) ? ret_val1_v[i] : ret_val2_v[i];
|
||||
@ -282,7 +282,7 @@ class SimpleIREvaluatorImpl : public IRVisitor {
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void visit_binary_op(const BinaryOpNode<Op>* v, bool option = false) {
|
||||
void visit_binary_op(BinaryOpNode<Op>* v, bool option = false) {
|
||||
v->lhs()->accept(this);
|
||||
Value lhs_v = value_;
|
||||
v->rhs()->accept(this);
|
||||
@ -365,7 +365,7 @@ class SimpleIREvaluatorImpl : public IRVisitor {
|
||||
}
|
||||
|
||||
void visit_compare_select_op(
|
||||
const CompareSelect* v,
|
||||
CompareSelect* v,
|
||||
CompareSelectOperation cmp_op) {
|
||||
v->lhs()->accept(this);
|
||||
Value lhs_v = value_;
|
||||
@ -401,8 +401,8 @@ class SimpleIREvaluatorImpl : public IRVisitor {
|
||||
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_VISIT);
|
||||
#undef IMM_VISIT
|
||||
|
||||
TORCH_API void visit(const Block* v) override {
|
||||
const Block* last = scope_;
|
||||
TORCH_API void visit(Block* v) override {
|
||||
Block* last = scope_;
|
||||
scope_ = v;
|
||||
for (Stmt* s : v->stmts()) {
|
||||
s->accept(this);
|
||||
@ -410,7 +410,7 @@ class SimpleIREvaluatorImpl : public IRVisitor {
|
||||
|
||||
auto it = var_by_scope_.find(v);
|
||||
if (it != var_by_scope_.end()) {
|
||||
for (const Expr* v : it->second) {
|
||||
for (Expr* v : it->second) {
|
||||
eval_context_.erase(v);
|
||||
}
|
||||
var_by_scope_.erase(it);
|
||||
@ -419,7 +419,7 @@ class SimpleIREvaluatorImpl : public IRVisitor {
|
||||
scope_ = last;
|
||||
}
|
||||
|
||||
TORCH_API void visit(const Var* v) override {
|
||||
TORCH_API void visit(Var* v) override {
|
||||
auto iter = eval_context_.find(v);
|
||||
if (iter == eval_context_.end()) {
|
||||
throw malformed_input("could not find Var in context", v);
|
||||
@ -456,8 +456,8 @@ class SimpleIREvaluatorImpl : public IRVisitor {
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_API void visit(const Cast* v) override {
|
||||
const Expr* src_value = v->src_value();
|
||||
TORCH_API void visit(Cast* v) override {
|
||||
Expr* src_value = v->src_value();
|
||||
src_value->accept(this);
|
||||
Dtype dst_dtype = v->dtype();
|
||||
Dtype src_dtype = src_value->dtype();
|
||||
@ -507,8 +507,8 @@ class SimpleIREvaluatorImpl : public IRVisitor {
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_API void visit(const BitCast* v) override {
|
||||
const Expr* src_value = v->src_value();
|
||||
TORCH_API void visit(BitCast* v) override {
|
||||
Expr* src_value = v->src_value();
|
||||
src_value->accept(this);
|
||||
Dtype dst_dtype = v->dtype();
|
||||
Dtype src_dtype = src_value->dtype();
|
||||
@ -530,8 +530,8 @@ class SimpleIREvaluatorImpl : public IRVisitor {
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_API void visit(const For* v) override {
|
||||
const Expr* var_node = v->var();
|
||||
TORCH_API void visit(For* v) override {
|
||||
Expr* var_node = v->var();
|
||||
v->start()->accept(this);
|
||||
int start = value_.as<int>();
|
||||
v->stop()->accept(this);
|
||||
@ -549,7 +549,7 @@ class SimpleIREvaluatorImpl : public IRVisitor {
|
||||
eval_context_.erase(var_node);
|
||||
}
|
||||
|
||||
TORCH_API void visit(const Ramp* v) override {
|
||||
TORCH_API void visit(Ramp* v) override {
|
||||
v->base()->accept(this);
|
||||
int base = value().as<int>();
|
||||
v->stride()->accept(this);
|
||||
@ -557,14 +557,14 @@ class SimpleIREvaluatorImpl : public IRVisitor {
|
||||
int lanes = v->lanes();
|
||||
|
||||
std::vector<int> values(lanes);
|
||||
for (const auto i : c10::irange(lanes)) {
|
||||
for (auto i : c10::irange(lanes)) {
|
||||
values[i] = base + i * stride;
|
||||
}
|
||||
|
||||
value_ = Value(values);
|
||||
}
|
||||
|
||||
TORCH_API void visit(const Broadcast* v) override {
|
||||
TORCH_API void visit(Broadcast* v) override {
|
||||
v->value()->accept(this);
|
||||
Value value = this->value();
|
||||
int lanes = v->lanes();
|
||||
@ -581,7 +581,7 @@ class SimpleIREvaluatorImpl : public IRVisitor {
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_API void visit(const IfThenElse* v) override {
|
||||
TORCH_API void visit(IfThenElse* v) override {
|
||||
v->condition()->accept(this);
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||
bool cond_v;
|
||||
@ -605,26 +605,26 @@ class SimpleIREvaluatorImpl : public IRVisitor {
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_API void visit(const Load* v) override {
|
||||
TORCH_API void visit(Load* v) override {
|
||||
auto iter = buffer_mapping_.find(v->buf());
|
||||
if (iter == buffer_mapping_.end()) {
|
||||
throw malformed_input("could not find base node in Load", v);
|
||||
}
|
||||
void* ptr = iter->second;
|
||||
|
||||
const Expr* flat_idx = flatten_index(v->buf()->dims(), v->indices());
|
||||
Expr* flat_idx = flatten_index(v->buf()->dims(), v->indices());
|
||||
flat_idx->accept(this);
|
||||
std::vector<int> index = value().as_vec<int>();
|
||||
ScalarType v_sdtype = v->dtype().scalar_type();
|
||||
switch (v_sdtype) {
|
||||
#define TYPE_CASE(Type, Name) \
|
||||
case ScalarType::Name: { \
|
||||
Type* ptr##Name = static_cast<Type*>(ptr); \
|
||||
std::vector<Type> v(index.size()); \
|
||||
for (const auto i : c10::irange(index.size())) { \
|
||||
v[i] = ptr##Name[index[i]]; \
|
||||
} \
|
||||
value_ = Value(v); \
|
||||
#define TYPE_CASE(Type, Name) \
|
||||
case ScalarType::Name: { \
|
||||
Type* ptr##Name = static_cast<Type*>(ptr); \
|
||||
std::vector<Type> v(index.size()); \
|
||||
for (auto i : c10::irange(index.size())) { \
|
||||
v[i] = ptr##Name[index[i]]; \
|
||||
} \
|
||||
value_ = Value(v); \
|
||||
} break;
|
||||
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE);
|
||||
#undef TYPE_CASE
|
||||
@ -633,7 +633,7 @@ class SimpleIREvaluatorImpl : public IRVisitor {
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_API void visit(const Store* v) override {
|
||||
TORCH_API void visit(Store* v) override {
|
||||
auto iter = buffer_mapping_.find(v->buf());
|
||||
if (iter == buffer_mapping_.end()) {
|
||||
throw malformed_input("could not find base node in Store", v);
|
||||
@ -641,7 +641,7 @@ class SimpleIREvaluatorImpl : public IRVisitor {
|
||||
|
||||
void* ptr = iter->second;
|
||||
|
||||
const Expr* flat_idx = flatten_index(v->buf()->dims(), v->indices());
|
||||
Expr* flat_idx = flatten_index(v->buf()->dims(), v->indices());
|
||||
flat_idx->accept(this);
|
||||
std::vector<int> index = value().as_vec<int>();
|
||||
ScalarType v_sdtype = v->value()->dtype().scalar_type();
|
||||
@ -655,7 +655,7 @@ class SimpleIREvaluatorImpl : public IRVisitor {
|
||||
throw malformed_input("value size mismatch in Store", v); \
|
||||
} \
|
||||
Type* ptr##Name = static_cast<Type*>(ptr); \
|
||||
for (const auto i : c10::irange(index.size())) { \
|
||||
for (auto i : c10::irange(index.size())) { \
|
||||
ptr##Name[index[i]] = value[i]; \
|
||||
} \
|
||||
} break;
|
||||
@ -666,13 +666,13 @@ class SimpleIREvaluatorImpl : public IRVisitor {
|
||||
}
|
||||
}
|
||||
|
||||
void visit(const ExternalCall* v) override {
|
||||
void visit(ExternalCall* v) override {
|
||||
auto& func_registry = getNNCFunctionRegistry();
|
||||
if (!func_registry.count(v->func_name())) {
|
||||
throw unimplemented_lowering(v);
|
||||
}
|
||||
|
||||
std::vector<const Buf*> bufs(v->buf_args());
|
||||
std::vector<Buf*> bufs(v->buf_args());
|
||||
bufs.insert(bufs.begin(), v->buf());
|
||||
|
||||
std::vector<void*> buf_ptrs;
|
||||
@ -681,7 +681,7 @@ class SimpleIREvaluatorImpl : public IRVisitor {
|
||||
std::vector<int8_t> buf_dtypes;
|
||||
std::vector<int64_t> extra_args;
|
||||
|
||||
for (const Buf* b : bufs) {
|
||||
for (Buf* b : bufs) {
|
||||
auto iter = buffer_mapping_.find(b);
|
||||
if (iter == buffer_mapping_.end()) {
|
||||
throw malformed_input("could not find buf", v);
|
||||
@ -690,12 +690,12 @@ class SimpleIREvaluatorImpl : public IRVisitor {
|
||||
buf_ptrs.push_back(iter->second);
|
||||
buf_ranks.push_back(b->dims().size());
|
||||
buf_dtypes.push_back((int8_t)b->dtype().scalar_type());
|
||||
for (const Expr* dim_expr : b->dims()) {
|
||||
for (Expr* dim_expr : b->dims()) {
|
||||
dim_expr->accept(this);
|
||||
buf_dims.push_back(value().as<int>());
|
||||
}
|
||||
}
|
||||
for (const Expr* a : v->args()) {
|
||||
for (Expr* a : v->args()) {
|
||||
a->accept(this);
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||
int64_t val;
|
||||
@ -722,9 +722,9 @@ class SimpleIREvaluatorImpl : public IRVisitor {
|
||||
}
|
||||
|
||||
template <typename TReturn, typename TInput>
|
||||
void visit_intrinsics_helper(const Intrinsics* v) {
|
||||
void visit_intrinsics_helper(Intrinsics* v) {
|
||||
std::vector<Value> values(v->nparams());
|
||||
for (const auto i : c10::irange(v->nparams())) {
|
||||
for (auto i : c10::irange(v->nparams())) {
|
||||
v->param(i)->accept(this);
|
||||
values[i] = this->value();
|
||||
}
|
||||
@ -746,18 +746,18 @@ class SimpleIREvaluatorImpl : public IRVisitor {
|
||||
|
||||
std::vector<TReturn> result(v1.size(), -1);
|
||||
if (values.size() == 1ULL) {
|
||||
for (const auto i : c10::irange(v1.size())) {
|
||||
for (auto i : c10::irange(v1.size())) {
|
||||
result[i] = compute_intrinsics<TReturn>(v->op_type(), v1[i]);
|
||||
}
|
||||
} else {
|
||||
for (const auto i : c10::irange(v1.size())) {
|
||||
for (auto i : c10::irange(v1.size())) {
|
||||
result[i] = compute_intrinsics<TReturn>(v->op_type(), v1[i], v2[i]);
|
||||
}
|
||||
}
|
||||
value_ = Value(result);
|
||||
}
|
||||
|
||||
TORCH_API void visit(const Intrinsics* v) override {
|
||||
TORCH_API void visit(Intrinsics* v) override {
|
||||
auto ty = v->dtype().scalar_type();
|
||||
if (v->op_type() == kIsNan) {
|
||||
auto inp_dtype = v->params().at(0)->dtype().scalar_type();
|
||||
@ -782,15 +782,15 @@ class SimpleIREvaluatorImpl : public IRVisitor {
|
||||
}
|
||||
}
|
||||
|
||||
void visit(const Allocate* v) override {
|
||||
const Buf* b = v->buf();
|
||||
std::vector<const Expr*> dims = b->dims();
|
||||
void visit(Allocate* v) override {
|
||||
Buf* b = v->buf();
|
||||
std::vector<Expr*> dims = b->dims();
|
||||
int total_byte_size = b->dtype().byte_size();
|
||||
for (auto& dim : dims) {
|
||||
dim->accept(this);
|
||||
total_byte_size *= value_.as<int>();
|
||||
}
|
||||
const auto int_count = (total_byte_size + sizeof(int) - 1) / sizeof(int);
|
||||
auto int_count = (total_byte_size + sizeof(int) - 1) / sizeof(int);
|
||||
std::unique_ptr<std::vector<int>> buffer(new std::vector<int>(int_count));
|
||||
auto iter = buffer_mapping_.find(b);
|
||||
if (iter != buffer_mapping_.end() && iter->second != nullptr) {
|
||||
@ -802,8 +802,8 @@ class SimpleIREvaluatorImpl : public IRVisitor {
|
||||
internal_buffers_.insert(std::make_pair(b, std::move(buffer)));
|
||||
}
|
||||
|
||||
void visit(const Free* v) override {
|
||||
const Buf* b = v->buf();
|
||||
void visit(Free* v) override {
|
||||
Buf* b = v->buf();
|
||||
int count = internal_buffers_.erase(b);
|
||||
if (count == 0) {
|
||||
throw std::runtime_error(
|
||||
@ -813,12 +813,12 @@ class SimpleIREvaluatorImpl : public IRVisitor {
|
||||
buffer_mapping_.erase(b);
|
||||
}
|
||||
|
||||
void visit(const Let* v) override {
|
||||
void visit(Let* v) override {
|
||||
var_by_scope_[scope_].push_back(v->var());
|
||||
bindVar(v->var(), evaluateExpr(v->value()));
|
||||
}
|
||||
|
||||
void visit(const Cond* v) override {
|
||||
void visit(Cond* v) override {
|
||||
v->condition()->accept(this);
|
||||
if (value().as<int>()) {
|
||||
if (v->true_stmt()) {
|
||||
@ -950,12 +950,11 @@ class SimpleIREvaluatorImpl : public IRVisitor {
|
||||
}
|
||||
|
||||
Value value_;
|
||||
const Block* scope_;
|
||||
std::unordered_map<const Expr*, Value> eval_context_;
|
||||
std::unordered_map<const Block*, std::vector<const Expr*>> var_by_scope_;
|
||||
std::unordered_map<const Buf*, void*> buffer_mapping_;
|
||||
std::unordered_map<const Buf*, std::unique_ptr<std::vector<int>>>
|
||||
internal_buffers_;
|
||||
Block* scope_;
|
||||
std::unordered_map<Expr*, Value> eval_context_;
|
||||
std::unordered_map<Block*, std::vector<Expr*>> var_by_scope_;
|
||||
std::unordered_map<Buf*, void*> buffer_mapping_;
|
||||
std::unordered_map<Buf*, std::unique_ptr<std::vector<int>>> internal_buffers_;
|
||||
};
|
||||
|
||||
SimpleIREvaluator::SimpleIREvaluator(
|
||||
@ -984,7 +983,7 @@ void SimpleIREvaluator::call_raw(const std::vector<void*>& args) {
|
||||
if (args.size() != buffer_args().size()) {
|
||||
throw malformed_input("bad args in IREvaluator call");
|
||||
}
|
||||
for (const auto i : c10::irange(args.size())) {
|
||||
for (auto i : c10::irange(args.size())) {
|
||||
bindArg(buffer_args()[i], args[i]);
|
||||
}
|
||||
stmt()->accept(&*impl_);
|
||||
@ -1012,7 +1011,7 @@ void SimpleIREvaluator::bindArg(const BufferArg& bufArg, void* data) {
|
||||
}
|
||||
}
|
||||
|
||||
void SimpleIREvaluator::bindVar(const Var* v, const Expr* e) {
|
||||
void SimpleIREvaluator::bindVar(Var* v, Expr* e) {
|
||||
impl_->bindVar(v, impl_->evaluateExpr(e));
|
||||
}
|
||||
|
||||
|
@ -114,7 +114,7 @@ class TORCH_API SimpleIREvaluator : public CodeGen {
|
||||
call(args);
|
||||
}
|
||||
|
||||
void bindVar(const Var* v, const Expr* e);
|
||||
void bindVar(Var* v, Expr* e);
|
||||
Value value() const;
|
||||
|
||||
private:
|
||||
@ -145,8 +145,8 @@ class ExprEval {
|
||||
std::vector<BufferArg> buffer_args_extended = buffer_args;
|
||||
Placeholder ret_buf("ret_val", dtype_, {1});
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||
std::vector<const Expr*> indices;
|
||||
const Expr* zero = new IntImm(0);
|
||||
std::vector<Expr*> indices;
|
||||
Expr* zero = new IntImm(0);
|
||||
for (size_t i = 0; i < ret_buf.data()->ndim(); i++) {
|
||||
indices.push_back(zero);
|
||||
}
|
||||
@ -167,7 +167,7 @@ class ExprEval {
|
||||
call(call_args);
|
||||
}
|
||||
|
||||
void bindVar(const Var* v, const Expr* e) {
|
||||
void bindVar(Var* v, Expr* e) {
|
||||
codegen_->bindVar(v, e);
|
||||
}
|
||||
|
||||
@ -253,7 +253,7 @@ class ExprEval {
|
||||
Value ret_value_;
|
||||
};
|
||||
|
||||
inline const Expr* Substitute(const Expr* expr, const VarMapping& var_mapping) {
|
||||
inline Expr* Substitute(Expr* expr, const VarMapping& var_mapping) {
|
||||
VarSubMutator var_sub(var_mapping);
|
||||
return expr->accept_mutator(&var_sub);
|
||||
}
|
||||
|
@ -43,9 +43,9 @@ class unimplemented_lowering : public std::runtime_error {
|
||||
public:
|
||||
explicit unimplemented_lowering()
|
||||
: std::runtime_error("UNIMPLEMENTED LOWERING") {}
|
||||
explicit unimplemented_lowering(const Expr* expr)
|
||||
explicit unimplemented_lowering(Expr* expr)
|
||||
: std::runtime_error("UNIMPLEMENTED LOWERING: " + std::to_string(expr)) {}
|
||||
explicit unimplemented_lowering(const Stmt* stmt)
|
||||
explicit unimplemented_lowering(Stmt* stmt)
|
||||
: std::runtime_error("UNIMPLEMENTED LOWERING: " + std::to_string(stmt)) {}
|
||||
};
|
||||
|
||||
@ -54,14 +54,14 @@ class malformed_input : public std::runtime_error {
|
||||
explicit malformed_input() : std::runtime_error("MALFORMED INPUT") {}
|
||||
explicit malformed_input(const std::string& err)
|
||||
: std::runtime_error("MALFORMED INPUT: " + err) {}
|
||||
explicit malformed_input(const Expr* expr)
|
||||
explicit malformed_input(Expr* expr)
|
||||
: std::runtime_error("MALFORMED INPUT: " + std::to_string(expr)) {}
|
||||
explicit malformed_input(const std::string& err, const Expr* expr)
|
||||
explicit malformed_input(const std::string& err, Expr* expr)
|
||||
: std::runtime_error(
|
||||
"MALFORMED INPUT: " + err + " - " + std::to_string(expr)) {}
|
||||
explicit malformed_input(const Stmt* stmt)
|
||||
explicit malformed_input(Stmt* stmt)
|
||||
: std::runtime_error("MALFORMED INPUT: " + std::to_string(stmt)) {}
|
||||
explicit malformed_input(const std::string& err, const Stmt* stmt)
|
||||
explicit malformed_input(const std::string& err, Stmt* stmt)
|
||||
: std::runtime_error(
|
||||
"MALFORMED INPUT: " + err + " - " + std::to_string(stmt)) {}
|
||||
};
|
||||
@ -71,14 +71,14 @@ class malformed_ir : public std::runtime_error {
|
||||
explicit malformed_ir() : std::runtime_error("MALFORMED IR") {}
|
||||
explicit malformed_ir(const std::string& err)
|
||||
: std::runtime_error("MALFORMED IR: " + err) {}
|
||||
explicit malformed_ir(const Expr* expr)
|
||||
explicit malformed_ir(Expr* expr)
|
||||
: std::runtime_error("MALFORMED IR: " + std::to_string(expr)) {}
|
||||
explicit malformed_ir(const std::string& err, const Expr* expr)
|
||||
explicit malformed_ir(const std::string& err, Expr* expr)
|
||||
: std::runtime_error(
|
||||
"MALFORMED IR: " + err + " - " + std::to_string(expr)) {}
|
||||
explicit malformed_ir(const Stmt* stmt)
|
||||
explicit malformed_ir(Stmt* stmt)
|
||||
: std::runtime_error("MALFORMED IR: " + std::to_string(stmt)) {}
|
||||
explicit malformed_ir(const std::string& err, const Stmt* stmt)
|
||||
explicit malformed_ir(const std::string& err, Stmt* stmt)
|
||||
: std::runtime_error(
|
||||
"MALFORMED IR: " + err + " - " + std::to_string(stmt)) {}
|
||||
};
|
||||
|
@ -42,8 +42,8 @@ class TORCH_API Expr : public KernelScopedObject {
|
||||
Dtype dtype() const {
|
||||
return dtype_;
|
||||
}
|
||||
virtual void accept(IRVisitor* visitor) const = 0;
|
||||
virtual const Expr* accept_mutator(IRMutator* mutator) const = 0;
|
||||
virtual void accept(IRVisitor* visitor) = 0;
|
||||
virtual Expr* accept_mutator(IRMutator* mutator) = 0;
|
||||
|
||||
IRNodeType expr_type() const {
|
||||
return expr_type_;
|
||||
@ -64,10 +64,10 @@ template <class Op, class Base = Expr>
|
||||
class ExprNode : public Base {
|
||||
public:
|
||||
using ExprNodeBase = ExprNode<Op>;
|
||||
void accept(IRVisitor* visitor) const override {
|
||||
visitor->visit(static_cast<const Op*>(this));
|
||||
void accept(IRVisitor* visitor) override {
|
||||
visitor->visit(static_cast<Op*>(this));
|
||||
}
|
||||
const Expr* accept_mutator(IRMutator* mutator) const override;
|
||||
Expr* accept_mutator(IRMutator* mutator) override;
|
||||
// pass the constructor to the base class
|
||||
using Base::Base;
|
||||
};
|
||||
@ -77,7 +77,7 @@ class ExprNode : public Base {
|
||||
class TORCH_API ExprHandle {
|
||||
public:
|
||||
ExprHandle() = default;
|
||||
explicit ExprHandle(const Expr* node)
|
||||
explicit ExprHandle(Expr* node)
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
|
||||
: base_expr_node_(const_cast<Expr*>(node)) {}
|
||||
|
||||
@ -85,7 +85,7 @@ class TORCH_API ExprHandle {
|
||||
return base_expr_node_;
|
||||
}
|
||||
|
||||
const Expr* node() const {
|
||||
Expr* node() const {
|
||||
return base_expr_node_;
|
||||
}
|
||||
|
||||
@ -103,7 +103,7 @@ class TORCH_API ExprHandle {
|
||||
}
|
||||
|
||||
template <class Op>
|
||||
const Op* AsNode() const {
|
||||
Op* AsNode() const {
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
|
||||
return const_cast<ExprHandle*>(this)->AsNode<Op>();
|
||||
}
|
||||
@ -173,7 +173,7 @@ class TORCH_API Buf : public ExprNode<Buf> {
|
||||
static ExprHandle make(const std::vector<ExprHandle>& dims, Dtype dtype);
|
||||
|
||||
// TODO: unique_name
|
||||
const Var* base_handle() const {
|
||||
Var* base_handle() const {
|
||||
return base_handle_;
|
||||
}
|
||||
void set_base_handle(Var* base_handle) {
|
||||
@ -189,16 +189,16 @@ class TORCH_API Buf : public ExprNode<Buf> {
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
|
||||
Buf(const std::string& name_hint,
|
||||
const std::vector<const Expr*>& dims,
|
||||
const std::vector<Expr*>& dims,
|
||||
Dtype dtype,
|
||||
const Expr* initializer = nullptr)
|
||||
Expr* initializer = nullptr)
|
||||
: Buf(new Var(name_hint, kHandle), dims, dtype, initializer) {}
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
|
||||
Buf(Var* var,
|
||||
std::vector<const Expr*> dims,
|
||||
std::vector<Expr*> dims,
|
||||
Dtype dtype,
|
||||
const Expr* initializer = nullptr)
|
||||
Expr* initializer = nullptr)
|
||||
: ExprNodeBase(dtype, kPrimitive),
|
||||
base_handle_(var),
|
||||
dims_(std::move(dims)),
|
||||
@ -209,20 +209,20 @@ class TORCH_API Buf : public ExprNode<Buf> {
|
||||
size_t ndim() const {
|
||||
return dims_.size();
|
||||
}
|
||||
const Expr* dim(size_t index) const {
|
||||
Expr* dim(size_t index) const {
|
||||
if (index >= ndim()) {
|
||||
throw out_of_range_index();
|
||||
}
|
||||
return dims_[index];
|
||||
}
|
||||
std::vector<const Expr*> dims() const {
|
||||
std::vector<Expr*> dims() const {
|
||||
return dims_;
|
||||
}
|
||||
void set_dims(std::vector<const Expr*> dims) {
|
||||
void set_dims(std::vector<Expr*> dims) {
|
||||
dims_ = dims;
|
||||
};
|
||||
|
||||
const Expr* initializer() const {
|
||||
Expr* initializer() const {
|
||||
return initializer_;
|
||||
};
|
||||
|
||||
@ -237,8 +237,8 @@ class TORCH_API Buf : public ExprNode<Buf> {
|
||||
|
||||
private:
|
||||
Var* base_handle_;
|
||||
std::vector<const Expr*> dims_;
|
||||
const Expr* initializer_;
|
||||
std::vector<Expr*> dims_;
|
||||
Expr* initializer_;
|
||||
};
|
||||
|
||||
class TORCH_API BufHandle : public ExprHandle {
|
||||
@ -254,9 +254,9 @@ class TORCH_API BufHandle : public ExprHandle {
|
||||
|
||||
explicit BufHandle(Dtype dtype) : ExprHandle(Buf::make("_", {}, dtype)) {}
|
||||
|
||||
explicit BufHandle(const Buf* node) : ExprHandle(node) {}
|
||||
const Buf* node() const {
|
||||
return static_cast<const Buf*>(ExprHandle::node());
|
||||
explicit BufHandle(Buf* node) : ExprHandle(node) {}
|
||||
Buf* node() const {
|
||||
return static_cast<Buf*>(ExprHandle::node());
|
||||
}
|
||||
Buf* node() {
|
||||
return static_cast<Buf*>(ExprHandle::node());
|
||||
@ -303,9 +303,9 @@ class TORCH_API VarHandle : public ExprHandle {
|
||||
explicit VarHandle(Dtype dtype) : ExprHandle(Var::make(dtype)) {}
|
||||
VarHandle(const std::string& name_hint, Dtype dtype)
|
||||
: ExprHandle(Var::make(name_hint, dtype)) {}
|
||||
explicit VarHandle(const Var* node) : ExprHandle(node) {}
|
||||
const Var* node() const {
|
||||
return static_cast<const Var*>(ExprHandle::node());
|
||||
explicit VarHandle(Var* node) : ExprHandle(node) {}
|
||||
Var* node() const {
|
||||
return static_cast<Var*>(ExprHandle::node());
|
||||
}
|
||||
bool operator==(const VarHandle& other) const {
|
||||
return this->node() == other.node();
|
||||
@ -323,7 +323,7 @@ class TORCH_API VarHandle : public ExprHandle {
|
||||
};
|
||||
|
||||
template <class Op, class Base>
|
||||
const Expr* ExprNode<Op, Base>::accept_mutator(IRMutator* mutator) const {
|
||||
Expr* ExprNode<Op, Base>::accept_mutator(IRMutator* mutator) {
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
|
||||
ExprNode* this_mutable = const_cast<ExprNode*>(this);
|
||||
return mutator->mutate(static_cast<Op*>(this_mutable));
|
||||
|
@ -21,18 +21,18 @@ std::vector<at::Tensor> constructTensors(
|
||||
std::vector<std::vector<int64_t>> buf_dims_vec;
|
||||
std::vector<c10::ScalarType> buf_dtypes_vec;
|
||||
int64_t buf_dims_idx = 0;
|
||||
for (const auto i : c10::irange(bufs_num)) {
|
||||
for (auto i : c10::irange(bufs_num)) {
|
||||
buf_data_vec.push_back(buf_data[i]);
|
||||
buf_dims_vec.emplace_back();
|
||||
// NOLINTNEXTLINE(clang-diagnostic-unused-variable,clang-analyzer-deadcode.DeadStores)
|
||||
for (const auto dim : c10::irange(buf_ranks[i])) {
|
||||
for (auto dim : c10::irange(buf_ranks[i])) {
|
||||
buf_dims_vec[i].push_back(buf_dims[buf_dims_idx++]);
|
||||
}
|
||||
buf_dtypes_vec.push_back(static_cast<c10::ScalarType>(buf_dtypes[i]));
|
||||
}
|
||||
|
||||
std::vector<at::Tensor> tensors;
|
||||
for (const auto i : c10::irange(buf_data_vec.size())) {
|
||||
for (auto i : c10::irange(buf_data_vec.size())) {
|
||||
auto options = at::TensorOptions()
|
||||
.dtype(buf_dtypes_vec[i])
|
||||
.layout(at::kStrided)
|
||||
@ -208,7 +208,7 @@ void nnc_prepacked_linear_clamp_run(
|
||||
constructTensors(bufs_num - 1, buf_data, buf_ranks, buf_dims, buf_dtypes);
|
||||
|
||||
const at::Tensor& x = tensors[1];
|
||||
const auto context = reinterpret_cast<LinearOpContext*>(buf_data[2]);
|
||||
auto context = reinterpret_cast<LinearOpContext*>(buf_data[2]);
|
||||
at::Tensor output = context->run(x);
|
||||
memcpy(
|
||||
buf_data[0], output.data_ptr(), output.element_size() * output.numel());
|
||||
@ -228,7 +228,7 @@ void nnc_prepacked_conv2d_clamp_run(
|
||||
constructTensors(bufs_num - 1, buf_data, buf_ranks, buf_dims, buf_dtypes);
|
||||
|
||||
const at::Tensor& x = tensors[1];
|
||||
const auto context = reinterpret_cast<Conv2dOpContext*>(buf_data[2]);
|
||||
auto context = reinterpret_cast<Conv2dOpContext*>(buf_data[2]);
|
||||
at::Tensor output = context->run(x);
|
||||
memcpy(
|
||||
buf_data[0], output.data_ptr(), output.element_size() * output.numel());
|
||||
|
@ -22,12 +22,12 @@ class HalfChecker : public IRVisitor {
return hasHalf_;
}

void visit(const Load* v) override {
void visit(Load* v) override {
hasHalf_ |= v->dtype().scalar_type() == ScalarType::Half;
IRVisitor::visit(v);
}

void visit(const Store* v) override {
void visit(Store* v) override {
hasHalf_ |= v->buf()->dtype().scalar_type() == ScalarType::Half;
IRVisitor::visit(v);
}
@ -36,7 +36,7 @@ class HalfChecker : public IRVisitor {
hasHalf_ = true;
}

void visit(const Cast* v) override {
void visit(Cast* v) override {
hasHalf_ |= v->dtype().scalar_type() == ScalarType::Half;
IRVisitor::visit(v);
}
@ -47,21 +47,21 @@ class HalfChecker : public IRVisitor {

// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
class HalfRewriter : public IRMutator {
const Expr* mutate(const Load* v) override {
const Expr* child = IRMutator::mutate(v);
Expr* mutate(Load* v) override {
Expr* child = IRMutator::mutate(v);
if (child->dtype().scalar_type() != ScalarType::Half) {
return child;
}

const Expr* ret =
Expr* ret =
new Cast(child->dtype().cloneWithScalarType(ScalarType::Float), child);

inserted_half_casts_.insert(ret);
return ret;
}

Stmt* mutate(const Store* v) override {
const Expr* new_val = v->value()->accept_mutator(this);
Stmt* mutate(Store* v) override {
Expr* new_val = v->value()->accept_mutator(this);

Dtype newType = v->value()->dtype();
if (newType.scalar_type() == ScalarType::Half) {
@ -73,12 +73,12 @@ class HalfRewriter : public IRMutator {
return new Store(v->buf(), v->indices(), new_val);
}

const Expr* mutate(const HalfImm* v) override {
Expr* mutate(HalfImm* v) override {
return new Cast(kFloat, v);
}

const Expr* mutate(const Cast* v) override {
const Expr* child = v->src_value()->accept_mutator(this);
Expr* mutate(Cast* v) override {
Expr* child = v->src_value()->accept_mutator(this);

// just don't allow half casts we didn't insert.
if (v->dtype().scalar_type() == ScalarType::Half) {
@ -88,7 +88,7 @@ class HalfRewriter : public IRMutator {
}

// Remove Half(Float()) and friends.
const Cast* cast_child = dynamic_cast<const Cast*>(child);
Cast* cast_child = dynamic_cast<Cast*>(child);
if (cast_child) {
if (v->dtype().is_floating_point() &&
cast_child->dtype().is_floating_point()) {
@ -102,10 +102,10 @@ class HalfRewriter : public IRMutator {

return new Cast(v->dtype(), child);
}
Stmt* mutate(const Let* v) override {
Stmt* mutate(Let* v) override {
if (v->dtype().scalar_type() == ScalarType::Half) {
const Var* load_new_var = new Var(v->var()->name_hint(), kFloat);
const Expr* new_value = new Cast(
Var* load_new_var = new Var(v->var()->name_hint(), kFloat);
Expr* new_value = new Cast(
v->dtype().cloneWithScalarType(ScalarType::Float),
v->value()->accept_mutator(this));
var_map[v->var()] = load_new_var;
@ -116,7 +116,7 @@ class HalfRewriter : public IRMutator {
return IRMutator::mutate(v);
}

const Expr* mutate(const Var* v) override {
Expr* mutate(Var* v) override {
auto it = var_map.find(v);
if (it != var_map.end()) {
return it->second;
@ -126,8 +126,8 @@ class HalfRewriter : public IRMutator {
}

private:
std::unordered_set<const Expr*> inserted_half_casts_;
std::unordered_map<const Var*, const Var*> var_map;
std::unordered_set<Expr*> inserted_half_casts_;
std::unordered_map<Var*, Var*> var_map;
};

} // namespace tensorexpr

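To make the HalfRewriter's job concrete: it wraps every Half-typed load in a cast to Float, remembers the casts it inserted, and later strips redundant Half(Float(...)) pairs. A minimal sketch of the wrapping step using the IR constructors from this patch; the buffer name and shape are illustrative only:

// Assumes the tensorexpr headers touched by this patch.
KernelScope kernel_scope;
std::vector<Expr*> dims = {new IntImm(8)};
Buf* buf = new Buf("x", dims, kHalf); // assumed Half-typed buffer
Expr* load = new Load(buf, {new IntImm(0)});
// The rewriter replaces `load` with the same value widened to Float:
Expr* widened =
    new Cast(load->dtype().cloneWithScalarType(ScalarType::Float), load);
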
@ -28,91 +28,91 @@ bool SimplifierHashType::operator!=(const size_t other) const {
return _h != other;
}

void HashProvider::visit(const Add* v) {
void HashProvider::visit(Add* v) {
CACHE_GUARD();
v->lhs()->accept(this);
v->rhs()->accept(this);
putHash(v, hash_combine(hashOf(v->lhs()), "+", hashOf(v->rhs())));
}

void HashProvider::visit(const Sub* v) {
void HashProvider::visit(Sub* v) {
CACHE_GUARD();
v->lhs()->accept(this);
v->rhs()->accept(this);
putHash(v, hash_combine(hashOf(v->lhs()), "-", hashOf(v->rhs())));
}

void HashProvider::visit(const Mul* v) {
void HashProvider::visit(Mul* v) {
CACHE_GUARD();
v->lhs()->accept(this);
v->rhs()->accept(this);
putHash(v, hash_combine(hashOf(v->lhs()), "*", hashOf(v->rhs())));
}

void HashProvider::visit(const Div* v) {
void HashProvider::visit(Div* v) {
CACHE_GUARD();
v->lhs()->accept(this);
v->rhs()->accept(this);
putHash(v, hash_combine(hashOf(v->lhs()), "/", hashOf(v->rhs())));
}

void HashProvider::visit(const Mod* v) {
void HashProvider::visit(Mod* v) {
CACHE_GUARD();
v->lhs()->accept(this);
v->rhs()->accept(this);
putHash(v, hash_combine(hashOf(v->lhs()), "%", hashOf(v->rhs())));
}

void HashProvider::visit(const Max* v) {
void HashProvider::visit(Max* v) {
CACHE_GUARD();
v->lhs()->accept(this);
v->rhs()->accept(this);
putHash(v, hash_combine(hashOf(v->lhs()), "Mx", hashOf(v->rhs())));
}

void HashProvider::visit(const Min* v) {
void HashProvider::visit(Min* v) {
CACHE_GUARD();
v->lhs()->accept(this);
v->rhs()->accept(this);
putHash(v, hash_combine(hashOf(v->lhs()), "Mn", hashOf(v->rhs())));
}

void HashProvider::visit(const And* v) {
void HashProvider::visit(And* v) {
CACHE_GUARD();
v->lhs()->accept(this);
v->rhs()->accept(this);
putHash(v, hash_combine(hashOf(v->lhs()), "&", hashOf(v->rhs())));
}

void HashProvider::visit(const Or* v) {
void HashProvider::visit(Or* v) {
CACHE_GUARD();
v->lhs()->accept(this);
v->rhs()->accept(this);
putHash(v, hash_combine(hashOf(v->lhs()), "|", hashOf(v->rhs())));
}

void HashProvider::visit(const Xor* v) {
void HashProvider::visit(Xor* v) {
CACHE_GUARD();
v->lhs()->accept(this);
v->rhs()->accept(this);
putHash(v, hash_combine(hashOf(v->lhs()), "^", hashOf(v->rhs())));
}

void HashProvider::visit(const Lshift* v) {
void HashProvider::visit(Lshift* v) {
CACHE_GUARD();
v->lhs()->accept(this);
v->rhs()->accept(this);
putHash(v, hash_combine(hashOf(v->lhs()), "<<", hashOf(v->rhs())));
}

void HashProvider::visit(const Rshift* v) {
void HashProvider::visit(Rshift* v) {
CACHE_GUARD();
v->lhs()->accept(this);
v->rhs()->accept(this);
putHash(v, hash_combine(hashOf(v->lhs()), ">>", hashOf(v->rhs())));
}

void HashProvider::visit(const CompareSelect* v) {
void HashProvider::visit(CompareSelect* v) {
CACHE_GUARD();
v->lhs()->accept(this);
v->rhs()->accept(this);
@ -128,18 +128,18 @@ void HashProvider::visit(const CompareSelect* v) {
hashOf(v->ret_val2())));
}

void HashProvider::visit(const Cast* v) {
void HashProvider::visit(Cast* v) {
CACHE_GUARD();
v->src_value()->accept(this);
putHash(v, hash_combine("cast", v->dtype(), hashOf(v->src_value())));
}

void HashProvider::visit(const Var* v) {
void HashProvider::visit(Var* v) {
CACHE_GUARD();
putHash(v, hash_combine("var", name_manager_.get_unique_name(v)));
}

void HashProvider::visit(const Ramp* v) {
void HashProvider::visit(Ramp* v) {
CACHE_GUARD();
v->base()->accept(this);
v->stride()->accept(this);
@ -148,22 +148,22 @@ void HashProvider::visit(const Ramp* v) {
hash_combine("ramp", hashOf(v->base()), hashOf(v->stride()), v->lanes()));
}

void HashProvider::visit(const Load* v) {
void HashProvider::visit(Load* v) {
CACHE_GUARD();
v->base_handle()->accept(this);
SimplifierHashType indices_hash;
for (const Expr* ind : v->indices()) {
for (Expr* ind : v->indices()) {
ind->accept(this);
indices_hash = hash_combine(indices_hash, hashOf(ind));
}
putHash(v, hash_combine("load", hashOf(v->base_handle()), indices_hash));
}

void HashProvider::visit(const Store* v) {
void HashProvider::visit(Store* v) {
CACHE_GUARD();
v->base_handle()->accept(this);
SimplifierHashType indices_hash;
for (const Expr* ind : v->indices()) {
for (Expr* ind : v->indices()) {
ind->accept(this);
indices_hash = hash_combine(indices_hash, hashOf(ind));
}
@ -174,7 +174,7 @@ void HashProvider::visit(const Store* v) {
"store", hashOf(v->base_handle()), indices_hash, hashOf(v->value())));
}

void HashProvider::visit(const Block* v) {
void HashProvider::visit(Block* v) {
CACHE_GUARD();
SimplifierHashType hash;

@ -185,7 +185,7 @@ void HashProvider::visit(const Block* v) {
putHash(v, hash);
}

void HashProvider::visit(const For* v) {
void HashProvider::visit(For* v) {
CACHE_GUARD();
v->var()->accept(this);
v->start()->accept(this);
@ -202,13 +202,13 @@ void HashProvider::visit(const For* v) {
putHash(v, hash);
}

void HashProvider::visit(const Broadcast* v) {
void HashProvider::visit(Broadcast* v) {
CACHE_GUARD();
v->value()->accept(this);
putHash(v, hash_combine("broadcast", hashOf(v->value()), v->lanes()));
}

void HashProvider::visit(const IfThenElse* v) {
void HashProvider::visit(IfThenElse* v) {
CACHE_GUARD();
v->condition()->accept(this);
v->true_value()->accept(this);
@ -223,7 +223,7 @@ void HashProvider::visit(const IfThenElse* v) {
hashOf(v->false_value())));
}

void HashProvider::visit(const Intrinsics* v) {
void HashProvider::visit(Intrinsics* v) {
CACHE_GUARD();
// calls to rand are not symbolic and have a different value each time, they
// should not hash to anything and this is the best we can do.
@ -234,7 +234,7 @@ void HashProvider::visit(const Intrinsics* v) {
}

SimplifierHashType hash(te_hash(v->func_name()));
for (const auto i : c10::irange(v->nparams())) {
for (auto i : c10::irange(v->nparams())) {
v->param(i)->accept(this);
hash = hash_combine(hash, hashOf(v->param(i)));
}
@ -242,33 +242,33 @@ void HashProvider::visit(const Intrinsics* v) {
putHash(v, hash);
}

void HashProvider::visit(const Allocate* v) {
void HashProvider::visit(Allocate* v) {
CACHE_GUARD();
const Var* buffer_var = v->buffer_var();
Var* buffer_var = v->buffer_var();
buffer_var->accept(this);

SimplifierHashType hash =
hash_combine("allocate", hashOf(buffer_var), v->dtype());

std::vector<const Expr*> dims = v->dims();
for (const Expr* dim : dims) {
std::vector<Expr*> dims = v->dims();
for (Expr* dim : dims) {
dim->accept(this);
hash = hash_combine(hash, hashOf(dim));
}
putHash(v, hash);
}

void HashProvider::visit(const Free* v) {
void HashProvider::visit(Free* v) {
CACHE_GUARD();
const Var* buffer_var = v->buffer_var();
Var* buffer_var = v->buffer_var();
buffer_var->accept(this);

putHash(v, hash_combine("free", hashOf(buffer_var)));
}

void HashProvider::visit(const Cond* v) {
void HashProvider::visit(Cond* v) {
CACHE_GUARD();
const Expr* condition = v->condition();
Expr* condition = v->condition();
Stmt* true_stmt = v->true_stmt();
Stmt* false_stmt = v->false_stmt();
condition->accept(this);
@ -286,7 +286,7 @@ void HashProvider::visit(const Cond* v) {
putHash(v, hash);
}

void HashProvider::visit(const Term* v) {
void HashProvider::visit(Term* v) {
CACHE_GUARD();
v->scalar()->accept(this);

@ -299,7 +299,7 @@ void HashProvider::visit(const Term* v) {
putHash(v, hash);
}

void HashProvider::visit(const Polynomial* v) {
void HashProvider::visit(Polynomial* v) {
CACHE_GUARD();
v->scalar()->accept(this);

@ -312,7 +312,7 @@ void HashProvider::visit(const Polynomial* v) {
putHash(v, hash);
}

void HashProvider::visit(const MaxTerm* v) {
void HashProvider::visit(MaxTerm* v) {
CACHE_GUARD();
SimplifierHashType hash = hash_combine("maxterm");
if (v->scalar()) {
@ -328,7 +328,7 @@ void HashProvider::visit(const MaxTerm* v) {
putHash(v, hash);
}

void HashProvider::visit(const MinTerm* v) {
void HashProvider::visit(MinTerm* v) {
CACHE_GUARD();
SimplifierHashType hash = hash_combine("minterm");
if (v->scalar()) {

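The visits above all follow one recipe: hash the children first, then combine the child hashes with a tag that identifies the operator ("+" for Add, "Mx" for Max, and so on), so two trees hash equal exactly when they have the same shape and the same ops. A toy standalone version of that recipe, not the real HashProvider (which also caches results via CACHE_GUARD):

#include <functional>
#include <string>

// Toy structural hash over a binary tree; mirrors the combine-with-op-tag
// pattern above without the caching machinery.
struct Node {
  std::string op; // "+", "-", "var:x", ...
  Node* lhs = nullptr;
  Node* rhs = nullptr;
};

size_t structuralHash(const Node* n) {
  size_t h = std::hash<std::string>()(n->op);
  if (n->lhs) h = h * 31 + structuralHash(n->lhs); // child hashes feed in
  if (n->rhs) h = h * 31 + structuralHash(n->rhs); // order matters: a-b != b-a
  return h;
}
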
@ -53,7 +53,7 @@ class Polynomial;
class TORCH_API HashProvider : public IRVisitor {
public:
template <class T>
SimplifierHashType hash(const T* e) {
SimplifierHashType hash(T* e) {
// NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
e->accept(this);
return hashOf(e);
@ -67,19 +67,19 @@ class TORCH_API HashProvider : public IRVisitor {
exprToHash_.clear();
}

void visit(const Add* v) override;
void visit(const Sub* v) override;
void visit(const Mul* v) override;
void visit(const Div* v) override;
void visit(const Mod* v) override;
void visit(const Max* v) override;
void visit(const Min* v) override;
void visit(const And* v) override;
void visit(const Or* v) override;
void visit(const Xor* v) override;
void visit(const Lshift* v) override;
void visit(const Rshift* v) override;
void visit(const CompareSelect* v) override;
void visit(Add* v) override;
void visit(Sub* v) override;
void visit(Mul* v) override;
void visit(Div* v) override;
void visit(Mod* v) override;
void visit(Max* v) override;
void visit(Min* v) override;
void visit(And* v) override;
void visit(Or* v) override;
void visit(Xor* v) override;
void visit(Lshift* v) override;
void visit(Rshift* v) override;
void visit(CompareSelect* v) override;

// NOLINTNEXTLINE
#define IMM_VISIT(Type, Name) \
@ -90,23 +90,23 @@ class TORCH_API HashProvider : public IRVisitor {
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_VISIT);
#undef IMM_VISIT

void visit(const Cast* v) override;
void visit(const Var* v) override;
void visit(const Ramp* v) override;
void visit(const Load* v) override;
void visit(const Store* v) override;
void visit(const Block* v) override;
void visit(const For* v) override;
void visit(const Broadcast* v) override;
void visit(const IfThenElse* v) override;
void visit(const Intrinsics* v) override;
void visit(const Allocate* v) override;
void visit(const Free* v) override;
void visit(const Cond* v) override;
void visit(const Term* v) override;
void visit(const Polynomial* v) override;
void visit(const MaxTerm* v) override;
void visit(const MinTerm* v) override;
void visit(Cast* v) override;
void visit(Var* v) override;
void visit(Ramp* v) override;
void visit(Load* v) override;
void visit(Store* v) override;
void visit(Block* v) override;
void visit(For* v) override;
void visit(Broadcast* v) override;
void visit(IfThenElse* v) override;
void visit(Intrinsics* v) override;
void visit(Allocate* v) override;
void visit(Free* v) override;
void visit(Cond* v) override;
void visit(Term* v) override;
void visit(Polynomial* v) override;
void visit(MaxTerm* v) override;
void visit(MinTerm* v) override;

template <typename... Types>
SimplifierHashType hash_combine(const Types&... args) {
@ -116,7 +116,7 @@ class TORCH_API HashProvider : public IRVisitor {
}

private:
SimplifierHashType hashOf(const Expr* e) {
SimplifierHashType hashOf(Expr* e) {
auto it = exprToHash_.find(e);
if (it != exprToHash_.end()) {
return it->second;
@ -132,7 +132,7 @@ class TORCH_API HashProvider : public IRVisitor {
return hash;
}

SimplifierHashType hashOf(const Stmt* s) {
SimplifierHashType hashOf(Stmt* s) {
auto it = exprToHash_.find(s);
if (it != exprToHash_.end()) {
return it->second;
@ -169,7 +169,7 @@ class TORCH_API HashProvider : public IRVisitor {
(seed._h >> 4);
}

void _hash_combine(SimplifierHashType& seed, const Expr* e) {
void _hash_combine(SimplifierHashType& seed, Expr* e) {
_hash_combine(seed, hash(e));
}

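In practice the provider is used to test expressions for structural equality, for example by the simplifier. A hedged usage sketch with made-up expressions; two independently built copies of x + 1 should hash the same:

// Sketch only; assumes the tensorexpr headers from this patch.
KernelScope kernel_scope;
Var* x = new Var("x", kInt);
Expr* e1 = new Add(x, new IntImm(1));
Expr* e2 = new Add(x, new IntImm(1));

HashProvider hasher;
bool same = hasher.hash(e1) == hasher.hash(e2); // true: identical structure
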
@ -12,7 +12,7 @@ static Dtype ChooseDtype(const Dtype& buffer_dtype, const Dtype& index_dtype) {
return Dtype(buffer_dtype, index_dtype.lanes());
}

static Dtype dtypeOfIndices(const std::vector<const Expr*>& indices) {
static Dtype dtypeOfIndices(const std::vector<Expr*>& indices) {
if (!indices.size()) {
// Return something so we can handle scalar buffers.
return kInt;
@ -20,7 +20,7 @@ static Dtype dtypeOfIndices(const std::vector<const Expr*>& indices) {
return indices.at(0)->dtype();
}

void castIndicesToInts(std::vector<const Expr*>& indices) {
void castIndicesToInts(std::vector<Expr*>& indices) {
// Cast all indices to either Int or Long
auto index_dtype = ScalarType::Int;
for (auto& index : indices) {
@ -40,12 +40,12 @@ void castIndicesToInts(std::vector<const Expr*>& indices) {
}
}

Load::Load(Dtype dtype, const Buf* buf, std::vector<const Expr*> indices)
Load::Load(Dtype dtype, Buf* buf, std::vector<Expr*> indices)
: ExprNodeBase(dtype), buf_(buf), indices_(std::move(indices)) {
castIndicesToInts(indices_);
}

Load::Load(const Buf* buf, const std::vector<const Expr*>& indices)
Load::Load(Buf* buf, const std::vector<Expr*>& indices)
: Load(ChooseDtype(buf->dtype(), dtypeOfIndices(indices)), buf, indices) {}

ExprHandle Load::make(
@ -62,10 +62,7 @@ ExprHandle Load::make(
return Load::make(buf.dtype(), buf, indices);
}

Store::Store(
const Buf* buf,
std::vector<const Expr*> indices,
const Expr* value)
Store::Store(Buf* buf, std::vector<Expr*> indices, Expr* value)
: buf_(buf), indices_(std::move(indices)), value_(value) {
castIndicesToInts(indices_);
}
@ -78,9 +75,9 @@ Store* Store::make(
buf.node(), ExprHandleVectorToExprVector(indices), value.node());
}

const Expr* flatten_index(
const std::vector<const Expr*>& dims,
const std::vector<const Expr*>& indices) {
Expr* flatten_index(
const std::vector<Expr*>& dims,
const std::vector<Expr*>& indices) {
// Handle already flattened indices first
if (indices.size() == 1) {
return indices[0];
@ -93,7 +90,7 @@ const Expr* flatten_index(
if (ndim == 0) {
return new IntImm(0);
}
std::vector<const Expr*> strides(ndim);
std::vector<Expr*> strides(ndim);
// stride[i] = stride[i+1]*dims[i+1], i < ndim-1
// stride[i] = 1, i = ndim-1
strides[ndim - 1] = new IntImm(1);
@ -101,8 +98,8 @@ const Expr* flatten_index(
strides[ndim - 1 - i] = new Mul(strides[ndim - i], dims[ndim - i]);
}

const Expr* total_index = new IntImm(0);
for (const auto i : c10::irange(ndim)) {
Expr* total_index = new IntImm(0);
for (auto i : c10::irange(ndim)) {
total_index = new Add(total_index, new Mul(indices[i], strides[i]));
}
return total_index;
@ -123,7 +120,7 @@ Dtype Intrinsics::IntrinsicsDtype(IntrinsicsOp op_type, Dtype dt1, Dtype dt2) {

Dtype Intrinsics::IntrinsicsDtype(
IntrinsicsOp op_type,
const std::vector<const Expr*>& params) {
const std::vector<Expr*>& params) {
// TODO: check the op_type and make a real decision
// Doesnt this fail with kRand?
if (params.size() == 0) {
@ -184,7 +181,7 @@ ExternalCall* ExternalCall::make(
const std::string& func_name,
const std::vector<BufHandle>& buf_args,
const std::vector<ExprHandle>& args) {
std::vector<const Buf*> buf_arg_nodes;
std::vector<Buf*> buf_arg_nodes;
buf_arg_nodes.reserve(buf_args.size());
for (const BufHandle& buf_arg : buf_args) {
buf_arg_nodes.push_back(buf_arg.node());
@ -193,37 +190,35 @@ ExternalCall* ExternalCall::make(
buf.node(), func_name, buf_arg_nodes, ExprHandleVectorToExprVector(args));
}

std::vector<const Expr*> ExprHandleVectorToExprVector(
std::vector<Expr*> ExprHandleVectorToExprVector(
const std::vector<ExprHandle>& v) {
std::vector<const Expr*> result(v.size());
for (const auto i : c10::irange(v.size())) {
std::vector<Expr*> result(v.size());
for (auto i : c10::irange(v.size())) {
result[i] = v[i].node();
}
return result;
}

std::vector<ExprHandle> ExprVectorToExprHandleVector(
const std::vector<const Expr*>& v) {
const std::vector<Expr*>& v) {
std::vector<ExprHandle> result(v.size());
for (const auto i : c10::irange(v.size())) {
for (auto i : c10::irange(v.size())) {
result[i] = ExprHandle(v[i]);
}
return result;
}

std::vector<const Var*> VarHandleVectorToVarVector(
const std::vector<VarHandle>& v) {
std::vector<const Var*> result(v.size());
for (const auto i : c10::irange(v.size())) {
std::vector<Var*> VarHandleVectorToVarVector(const std::vector<VarHandle>& v) {
std::vector<Var*> result(v.size());
for (auto i : c10::irange(v.size())) {
result[i] = v[i].node();
}
return result;
}

std::vector<VarHandle> VarVectorToVarHandleVector(
const std::vector<const Var*>& v) {
std::vector<VarHandle> VarVectorToVarHandleVector(const std::vector<Var*>& v) {
std::vector<VarHandle> result(v.size());
for (const auto i : c10::irange(v.size())) {
for (auto i : c10::irange(v.size())) {
result[i] = VarHandle(v[i]);
}
return result;

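The stride computation in flatten_index above is ordinary row-major linearization: the last dimension has stride 1 and each earlier stride is the product of the later dims. With dims {2, 3, 4} the strides are {12, 4, 1}, so indices {i, j, k} flatten to i*12 + j*4 + k. A small arithmetic check of that rule in plain C++, independent of the IR classes:

#include <cassert>
#include <vector>

// Row-major flattening: the same arithmetic flatten_index builds as IR nodes.
int flatten(const std::vector<int>& dims, const std::vector<int>& idx) {
  int stride = 1, flat = 0;
  for (int i = static_cast<int>(dims.size()) - 1; i >= 0; --i) {
    flat += idx[i] * stride; // stride[i] applied to index[i]
    stride *= dims[i];       // stride[i-1] = stride[i] * dims[i]
  }
  return flat;
}

int main() {
  assert(flatten({2, 3, 4}, {1, 2, 3}) == 1 * 12 + 2 * 4 + 3); // == 23
}
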
@ -68,13 +68,13 @@ class Placeholder;

class TORCH_API Cast : public ExprNode<Cast> {
public:
const Expr* src_value() const {
Expr* src_value() const {
return src_value_;
}
static ExprHandle make(Dtype dtype, const ExprHandle& src_value) {
return ExprHandle(new Cast(dtype, src_value.node()));
}
Cast(Dtype dtype, const Expr* src_value)
Cast(Dtype dtype, Expr* src_value)
: ExprNodeBase(dtype, kCast), src_value_(src_value) {}

bool isConstant() const override {
@ -82,7 +82,7 @@ class TORCH_API Cast : public ExprNode<Cast> {
}

private:
const Expr* src_value_;
Expr* src_value_;
};

template <typename T>
@ -93,13 +93,13 @@ ExprHandle cast(const ExprHandle& src_value) {
// This is a bitwise cast, akin to bitcast in LLVM
class TORCH_API BitCast : public ExprNode<BitCast> {
public:
const Expr* src_value() const {
Expr* src_value() const {
return src_value_;
}
static ExprHandle make(Dtype dtype, const ExprHandle& src_value) {
return ExprHandle(new BitCast(dtype, src_value.node()));
}
BitCast(Dtype dtype, const Expr* src_value)
BitCast(Dtype dtype, Expr* src_value)
: ExprNodeBase(dtype, kBitCast), src_value_(src_value) {
TORCH_CHECK(src_value_->dtype().byte_size() == dtype.byte_size());
}
@ -109,7 +109,7 @@ class TORCH_API BitCast : public ExprNode<BitCast> {
}

private:
const Expr* src_value_;
Expr* src_value_;
};

template <typename T>
@ -123,10 +123,10 @@ ExprHandle bitcast(const ExprHandle& src_value) {
template <typename Op>
class BinaryOpNode : public ExprNode<Op> {
public:
const Expr* lhs() const {
Expr* lhs() const {
return this->lhs_;
}
const Expr* rhs() const {
Expr* rhs() const {
return this->rhs_;
}

@ -136,8 +136,8 @@ class BinaryOpNode : public ExprNode<Op> {

// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
BinaryOpNode(
const Expr* lhs_v,
const Expr* rhs_v,
Expr* lhs_v,
Expr* rhs_v,
IRNodeType expr_type,
ScalarType ret_type = ScalarType::Undefined)
: ExprNode<Op>(
@ -148,51 +148,46 @@ class BinaryOpNode : public ExprNode<Op> {
rhs_(CastIfNeeded(rhs_v, ExprNode<Op>::dtype())) {}

private:
static const Expr* CastIfNeeded(const Expr* expr, Dtype dst_dtype) {
static Expr* CastIfNeeded(Expr* expr, Dtype dst_dtype) {
if (expr->dtype() == dst_dtype) {
return expr;
}
return Cast::make(dst_dtype, ExprHandle(expr)).node();
}

const Expr* lhs_;
const Expr* rhs_;
Expr* lhs_;
Expr* rhs_;
};

class TORCH_API Add : public BinaryOpNode<Add> {
public:
Add(const Expr* lhs, const Expr* rhs)
: BinaryOpNode(lhs, rhs, IRNodeType::kAdd) {}
Add(Expr* lhs, Expr* rhs) : BinaryOpNode(lhs, rhs, IRNodeType::kAdd) {}
};

class TORCH_API Sub : public BinaryOpNode<Sub> {
public:
Sub(const Expr* lhs, const Expr* rhs)
: BinaryOpNode(lhs, rhs, IRNodeType::kSub) {}
Sub(Expr* lhs, Expr* rhs) : BinaryOpNode(lhs, rhs, IRNodeType::kSub) {}
};

class TORCH_API Mul : public BinaryOpNode<Mul> {
public:
Mul(const Expr* lhs, const Expr* rhs)
: BinaryOpNode(lhs, rhs, IRNodeType::kMul) {}
Mul(Expr* lhs, Expr* rhs) : BinaryOpNode(lhs, rhs, IRNodeType::kMul) {}
};

class TORCH_API Div : public BinaryOpNode<Div> {
public:
Div(const Expr* lhs, const Expr* rhs)
: BinaryOpNode(lhs, rhs, IRNodeType::kDiv) {}
Div(Expr* lhs, Expr* rhs) : BinaryOpNode(lhs, rhs, IRNodeType::kDiv) {}
};

class TORCH_API Mod : public BinaryOpNode<Mod> {
public:
Mod(const Expr* lhs, const Expr* rhs)
: BinaryOpNode(lhs, rhs, IRNodeType::kMod) {}
Mod(Expr* lhs, Expr* rhs) : BinaryOpNode(lhs, rhs, IRNodeType::kMod) {}
};

template <typename Op>
class BitwiseOpNode : public BinaryOpNode<Op> {
public:
BitwiseOpNode(const Expr* lhs, const Expr* rhs, IRNodeType type)
BitwiseOpNode(Expr* lhs, Expr* rhs, IRNodeType type)
: BinaryOpNode<Op>(lhs, rhs, type) {}

static ExprHandle make(const ExprHandle& lhs, const ExprHandle& rhs) {
@ -208,32 +203,27 @@ class BitwiseOpNode : public BinaryOpNode<Op> {

class TORCH_API And : public BitwiseOpNode<And> {
public:
And(const Expr* lhs, const Expr* rhs)
: BitwiseOpNode(lhs, rhs, IRNodeType::kAnd) {}
And(Expr* lhs, Expr* rhs) : BitwiseOpNode(lhs, rhs, IRNodeType::kAnd) {}
};

class TORCH_API Or : public BitwiseOpNode<Or> {
public:
Or(const Expr* lhs, const Expr* rhs)
: BitwiseOpNode(lhs, rhs, IRNodeType::kOr) {}
Or(Expr* lhs, Expr* rhs) : BitwiseOpNode(lhs, rhs, IRNodeType::kOr) {}
};

class TORCH_API Xor : public BitwiseOpNode<Xor> {
public:
Xor(const Expr* lhs, const Expr* rhs)
: BitwiseOpNode(lhs, rhs, IRNodeType::kXor) {}
Xor(Expr* lhs, Expr* rhs) : BitwiseOpNode(lhs, rhs, IRNodeType::kXor) {}
};

class TORCH_API Lshift : public BitwiseOpNode<Lshift> {
public:
Lshift(const Expr* lhs, const Expr* rhs)
: BitwiseOpNode(lhs, rhs, IRNodeType::kLshift) {}
Lshift(Expr* lhs, Expr* rhs) : BitwiseOpNode(lhs, rhs, IRNodeType::kLshift) {}
};

class TORCH_API Rshift : public BitwiseOpNode<Rshift> {
public:
Rshift(const Expr* lhs, const Expr* rhs)
: BitwiseOpNode(lhs, rhs, IRNodeType::kRshift) {}
Rshift(Expr* lhs, Expr* rhs) : BitwiseOpNode(lhs, rhs, IRNodeType::kRshift) {}
};

// TODO: add TORCH_API
@ -243,7 +233,7 @@ class Max : public BinaryOpNode<Max> {
bool propagate_nans_;

public:
Max(const Expr* lhs, const Expr* rhs, bool propagate_nans)
Max(Expr* lhs, Expr* rhs, bool propagate_nans)
: BinaryOpNode(lhs, rhs, IRNodeType::kMax),
propagate_nans_(propagate_nans) {}

@ -267,7 +257,7 @@ class Min : public BinaryOpNode<Min> {
bool propagate_nans_;

public:
Min(const Expr* lhs, const Expr* rhs, bool propagate_nans)
Min(Expr* lhs, Expr* rhs, bool propagate_nans)
: BinaryOpNode(lhs, rhs, IRNodeType::kMin),
propagate_nans_(propagate_nans) {}

@ -328,7 +318,7 @@ Expr* getImmediateByType(Dtype dtype, T initialVal) {
}

template <typename T>
T immediateAs(const Expr* e) {
T immediateAs(Expr* e) {
#define TYPE_CASE(Type, Name) \
if (const Name##Imm* imm = dynamic_cast<const Name##Imm*>(e)) { \
return imm->value(); \
@ -345,7 +335,7 @@ T immediateAs(ExprHandle e) {
}

template <typename T>
bool immediateEquals(const Expr* e, T val) {
bool immediateEquals(Expr* e, T val) {
#define TYPE_CASE(Type, Name) \
if (const Name##Imm* imm = dynamic_cast<const Name##Imm*>(e)) { \
return imm->value() == val; \
@ -371,10 +361,10 @@ bool immediateIsNegative(const T* e) {
// [base, base + 1 * stride, ... , base + (lanes - 1) * stride]
class TORCH_API Ramp : public ExprNode<Ramp> {
public:
const Expr* base() const {
Expr* base() const {
return base_;
}
const Expr* stride() const {
Expr* stride() const {
return stride_;
}
static ExprHandle make(
@ -390,31 +380,31 @@ class TORCH_API Ramp : public ExprNode<Ramp> {
return lanes_;
}

Ramp(const Expr* base, const Expr* stride, int lanes)
Ramp(Expr* base, Expr* stride, int lanes)
: ExprNodeBase(Dtype(base->dtype(), lanes)),
base_(base),
stride_(stride),
lanes_(lanes) {}

private:
const Expr* base_;
const Expr* stride_;
Expr* base_;
Expr* stride_;
int lanes_;
};

class TORCH_API Load : public ExprNode<Load> {
public:
const Var* base_handle() const {
Var* base_handle() const {
return buf_->base_handle();
}
std::vector<const Expr*> indices() const {
std::vector<Expr*> indices() const {
return indices_;
}
const Expr* flat_index() const {
Expr* flat_index() const {
TORCH_CHECK(indices_.size() == 1, "Indices haven't been flattened.");
return indices_[0];
}
const Buf* buf() const {
Buf* buf() const {
return buf_;
}
static ExprHandle make(
@ -425,21 +415,21 @@ class TORCH_API Load : public ExprNode<Load> {
const BufHandle& buf,
const std::vector<ExprHandle>& indices);

Load(Dtype dtype, const Buf* base_handle, std::vector<const Expr*> indices);
Load(const Buf* base_handle, const std::vector<const Expr*>& indices);
Load(Dtype dtype, Buf* base_handle, std::vector<Expr*> indices);
Load(Buf* base_handle, const std::vector<Expr*>& indices);

void set_indices(std::vector<const Expr*> indices) {
void set_indices(std::vector<Expr*> indices) {
indices_ = indices;
};

private:
const Buf* buf_;
std::vector<const Expr*> indices_;
Buf* buf_;
std::vector<Expr*> indices_;
};

class TORCH_API Broadcast : public ExprNode<Broadcast> {
public:
const Expr* value() const {
Expr* value() const {
return value_;
}
int lanes() const {
@ -448,29 +438,29 @@ class TORCH_API Broadcast : public ExprNode<Broadcast> {
static ExprHandle make(const ExprHandle& value, int lanes) {
return ExprHandle(new Broadcast(value.node(), lanes));
}
Broadcast(const Expr* value, int lanes)
Broadcast(Expr* value, int lanes)
: ExprNodeBase(Dtype(value->dtype(), lanes)),
value_(value),
lanes_(lanes) {}

private:
const Expr* value_;
Expr* value_;
int lanes_;
};

class TORCH_API IfThenElse : public ExprNode<IfThenElse> {
public:
const Expr* condition() const {
Expr* condition() const {
return condition_;
}

// Lazily evaluated only if condition is true
const Expr* true_value() const {
Expr* true_value() const {
return true_;
}

// Lazily evaluated only if condition is false
const Expr* false_value() const {
Expr* false_value() const {
return false_;
}

@ -490,13 +480,13 @@ class TORCH_API IfThenElse : public ExprNode<IfThenElse> {
return ExprHandle(new IfThenElse(c.node(), t.node(), f.node()));
}

IfThenElse(const Expr* c, const Expr* t, const Expr* f)
IfThenElse(Expr* c, Expr* t, Expr* f)
: ExprNodeBase(t->dtype()), condition_(c), true_(t), false_(f) {}

private:
const Expr* condition_;
const Expr* true_;
const Expr* false_;
Expr* condition_;
Expr* true_;
Expr* false_;
};

class TORCH_API CompareSelect : public ExprNode<CompareSelect> {
@ -504,16 +494,16 @@ class TORCH_API CompareSelect : public ExprNode<CompareSelect> {
CompareSelectOperation compare_select_op() const {
return compare_op_;
}
const Expr* lhs() const {
Expr* lhs() const {
return this->lhs_;
}
const Expr* rhs() const {
Expr* rhs() const {
return this->rhs_;
}
const Expr* ret_val1() const {
Expr* ret_val1() const {
return this->ret_val1_;
}
const Expr* ret_val2() const {
Expr* ret_val2() const {
return this->ret_val2_;
}
CompareSelectBias bias() const {
@ -557,10 +547,10 @@ class TORCH_API CompareSelect : public ExprNode<CompareSelect> {
}

CompareSelect(
const Expr* lhs,
const Expr* rhs,
const Expr* ret_val1,
const Expr* ret_val2,
Expr* lhs,
Expr* rhs,
Expr* ret_val1,
Expr* ret_val2,
CompareSelectOperation cmp_op,
CompareSelectBias bias = kUnbiased)
: ExprNodeBase(ret_val1->dtype()),
@ -573,8 +563,8 @@ class TORCH_API CompareSelect : public ExprNode<CompareSelect> {

// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
CompareSelect(
const Expr* lhs,
const Expr* rhs,
Expr* lhs,
Expr* rhs,
CompareSelectOperation cmp_op,
CompareSelectBias bias = kUnbiased)
: ExprNodeBase(kInt),
@ -586,10 +576,10 @@ class TORCH_API CompareSelect : public ExprNode<CompareSelect> {
bias_(bias) {}

private:
const Expr* lhs_;
const Expr* rhs_;
const Expr* ret_val1_;
const Expr* ret_val2_;
Expr* lhs_;
Expr* rhs_;
Expr* ret_val1_;
Expr* ret_val2_;
CompareSelectOperation compare_op_;
CompareSelectBias bias_;
};
@ -647,7 +637,7 @@ class TORCH_API Intrinsics : public ExprNode<Intrinsics> {
IntrinsicsOp op_type,
const std::vector<ExprHandle>& params) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::vector<const Expr*> params_nodes(params.size());
std::vector<Expr*> params_nodes(params.size());
for (size_t i = 0; i < params.size(); i++) {
params_nodes[i] = params[i].node();
}
@ -747,7 +737,7 @@ class TORCH_API Intrinsics : public ExprNode<Intrinsics> {
}

// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
Intrinsics(IntrinsicsOp op_type, const Expr* v1)
Intrinsics(IntrinsicsOp op_type, Expr* v1)
: ExprNodeBase(IntrinsicsDtype(op_type, v1->dtype())),
params_({v1}),
op_type_(op_type) {
@ -757,7 +747,7 @@ class TORCH_API Intrinsics : public ExprNode<Intrinsics> {
}

// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
Intrinsics(IntrinsicsOp op_type, const Expr* v1, const Expr* v2)
Intrinsics(IntrinsicsOp op_type, Expr* v1, Expr* v2)
: ExprNodeBase(IntrinsicsDtype(op_type, v1->dtype(), v2->dtype())),
params_({v1, v2}),
op_type_(op_type) {
@ -767,7 +757,7 @@ class TORCH_API Intrinsics : public ExprNode<Intrinsics> {
}

// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
Intrinsics(IntrinsicsOp op_type, const std::vector<const Expr*>& params)
Intrinsics(IntrinsicsOp op_type, const std::vector<Expr*>& params)
: ExprNodeBase(IntrinsicsDtype(op_type, params)),
params_(params),
op_type_(op_type) {
@ -784,10 +774,10 @@ class TORCH_API Intrinsics : public ExprNode<Intrinsics> {
return params_.size();
}

const Expr* param(int index) const {
Expr* param(int index) const {
return params_[index];
}
const std::vector<const Expr*>& params() const {
const std::vector<Expr*>& params() const {
return params_;
}

@ -797,9 +787,9 @@ class TORCH_API Intrinsics : public ExprNode<Intrinsics> {
static Dtype IntrinsicsDtype(IntrinsicsOp op_type, Dtype dt1, Dtype dt2);
static Dtype IntrinsicsDtype(
IntrinsicsOp op_type,
const std::vector<const Expr*>& params);
const std::vector<Expr*>& params);

std::vector<const Expr*> params_;
std::vector<Expr*> params_;
IntrinsicsOp op_type_;
};

@ -808,17 +798,17 @@ class Term;
class MaxTerm;
class MinTerm;

TORCH_API std::vector<const Expr*> ExprHandleVectorToExprVector(
TORCH_API std::vector<Expr*> ExprHandleVectorToExprVector(
const std::vector<ExprHandle>&);
TORCH_API std::vector<ExprHandle> ExprVectorToExprHandleVector(
const std::vector<const Expr*>&);
TORCH_API std::vector<const Var*> VarHandleVectorToVarVector(
const std::vector<Expr*>&);
TORCH_API std::vector<Var*> VarHandleVectorToVarVector(
const std::vector<VarHandle>&);
TORCH_API std::vector<VarHandle> VarVectorToVarHandleVector(
const std::vector<const Var*>&);
TORCH_API const Expr* flatten_index(
const std::vector<const Expr*>& dims,
const std::vector<const Expr*>& indices);
const std::vector<Var*>&);
TORCH_API Expr* flatten_index(
const std::vector<Expr*>& dims,
const std::vector<Expr*>& indices);

} // namespace tensorexpr
} // namespace jit

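With const gone from the node types, expression trees are built and held through mutable pointers throughout. A minimal sketch of constructing a * 2 + b with the constructors declared above; the variable names are illustrative:

// Assumes the tensorexpr headers from this patch.
KernelScope kernel_scope; // nodes live in the per-scope arena
Var* a = new Var("a", kInt);
Var* b = new Var("b", kInt);
Expr* e = new Add(new Mul(a, new IntImm(2)), b); // a * 2 + b
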
@ -12,14 +12,14 @@ namespace jit {
namespace tensorexpr {

template <typename Op>
static const Expr* mutate_binary_op(
const BinaryOpNode<Op>* v,
static Expr* mutate_binary_op(
BinaryOpNode<Op>* v,
IRMutator* mutator,
bool option = false) {
const Expr* lhs = v->lhs();
const Expr* rhs = v->rhs();
const Expr* lhs_new = lhs->accept_mutator(mutator);
const Expr* rhs_new = rhs->accept_mutator(mutator);
Expr* lhs = v->lhs();
Expr* rhs = v->rhs();
Expr* lhs_new = lhs->accept_mutator(mutator);
Expr* rhs_new = rhs->accept_mutator(mutator);
if (lhs == lhs_new && rhs == rhs_new) {
return v;
}
@ -54,63 +54,63 @@ static const Expr* mutate_binary_op(
}
}

const Expr* IRMutator::mutate(const Add* v) {
Expr* IRMutator::mutate(Add* v) {
return mutate_binary_op(v, this);
}

const Expr* IRMutator::mutate(const Sub* v) {
Expr* IRMutator::mutate(Sub* v) {
return mutate_binary_op(v, this);
}

const Expr* IRMutator::mutate(const Mul* v) {
Expr* IRMutator::mutate(Mul* v) {
return mutate_binary_op(v, this);
}

const Expr* IRMutator::mutate(const Div* v) {
Expr* IRMutator::mutate(Div* v) {
return mutate_binary_op(v, this);
}

const Expr* IRMutator::mutate(const Mod* v) {
Expr* IRMutator::mutate(Mod* v) {
return mutate_binary_op(v, this);
}

const Expr* IRMutator::mutate(const And* v) {
Expr* IRMutator::mutate(And* v) {
return mutate_binary_op(v, this);
}

const Expr* IRMutator::mutate(const Or* v) {
Expr* IRMutator::mutate(Or* v) {
return mutate_binary_op(v, this);
}

const Expr* IRMutator::mutate(const Xor* v) {
Expr* IRMutator::mutate(Xor* v) {
return mutate_binary_op(v, this);
}

const Expr* IRMutator::mutate(const Lshift* v) {
Expr* IRMutator::mutate(Lshift* v) {
return mutate_binary_op(v, this);
}

const Expr* IRMutator::mutate(const Rshift* v) {
Expr* IRMutator::mutate(Rshift* v) {
return mutate_binary_op(v, this);
}

const Expr* IRMutator::mutate(const Max* v) {
Expr* IRMutator::mutate(Max* v) {
return mutate_binary_op(v, this, v->propagate_nans());
}

const Expr* IRMutator::mutate(const Min* v) {
Expr* IRMutator::mutate(Min* v) {
return mutate_binary_op(v, this, v->propagate_nans());
}

const Expr* IRMutator::mutate(const CompareSelect* v) {
const Expr* lhs = v->lhs();
const Expr* rhs = v->rhs();
const Expr* retval1 = v->ret_val1();
const Expr* retval2 = v->ret_val2();
const Expr* lhs_new = lhs->accept_mutator(this);
const Expr* rhs_new = rhs->accept_mutator(this);
const Expr* retval1_new = retval1->accept_mutator(this);
const Expr* retval2_new = retval2->accept_mutator(this);
Expr* IRMutator::mutate(CompareSelect* v) {
Expr* lhs = v->lhs();
Expr* rhs = v->rhs();
Expr* retval1 = v->ret_val1();
Expr* retval2 = v->ret_val2();
Expr* lhs_new = lhs->accept_mutator(this);
Expr* rhs_new = rhs->accept_mutator(this);
Expr* retval1_new = retval1->accept_mutator(this);
Expr* retval2_new = retval2->accept_mutator(this);
if (lhs == lhs_new && rhs == rhs_new && retval1 == retval1_new &&
retval2 == retval2_new) {
return v;
@ -126,68 +126,68 @@ const Expr* IRMutator::mutate(const CompareSelect* v) {
}

// NOLINTNEXTLINE
#define IMM_MUTATE_DEFINE(_1, Name) \
const Expr* IRMutator::mutate(const Name##Imm* v) { \
return v; \
#define IMM_MUTATE_DEFINE(_1, Name) \
Expr* IRMutator::mutate(Name##Imm* v) { \
return v; \
}
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_MUTATE_DEFINE);
#undef IMM_MUTATE_DEFINE

const Expr* IRMutator::mutate(const Cast* v) {
const Expr* src_value = v->src_value();
const Expr* src_value_new = src_value->accept_mutator(this);
Expr* IRMutator::mutate(Cast* v) {
Expr* src_value = v->src_value();
Expr* src_value_new = src_value->accept_mutator(this);
if (src_value_new == v->src_value()) {
return v;
}
return new Cast(v->dtype(), src_value_new);
}

const Expr* IRMutator::mutate(const BitCast* v) {
const Expr* src_value = v->src_value();
const Expr* src_value_new = src_value->accept_mutator(this);
Expr* IRMutator::mutate(BitCast* v) {
Expr* src_value = v->src_value();
Expr* src_value_new = src_value->accept_mutator(this);
if (src_value_new == v->src_value()) {
return v;
}
return new BitCast(v->dtype(), src_value_new);
}

const Expr* IRMutator::mutate(const Var* v) {
Expr* IRMutator::mutate(Var* v) {
return v;
}

const Expr* IRMutator::mutate(const Ramp* v) {
const Expr* base = v->base();
const Expr* stride = v->stride();
const Expr* base_new = base->accept_mutator(this);
const Expr* stride_new = stride->accept_mutator(this);
Expr* IRMutator::mutate(Ramp* v) {
Expr* base = v->base();
Expr* stride = v->stride();
Expr* base_new = base->accept_mutator(this);
Expr* stride_new = stride->accept_mutator(this);
if (base == base_new && stride == stride_new) {
return v;
}
return new Ramp(base_new, stride_new, v->lanes());
}

const Expr* IRMutator::mutate(const Load* v) {
Expr* IRMutator::mutate(Load* v) {
Dtype dtype = v->dtype();
const Buf* buf = v->buf();
Buf* buf = v->buf();

bool any_index_changed = false;
std::vector<const Expr*> indices_new;
for (const Expr* ind : v->indices()) {
const Expr* new_ind = ind->accept_mutator(this);
std::vector<Expr*> indices_new;
for (Expr* ind : v->indices()) {
Expr* new_ind = ind->accept_mutator(this);
if (new_ind != ind) {
any_index_changed = true;
}
indices_new.push_back(new_ind);
}
const Buf* buf_new = dynamic_cast<const Buf*>(buf->accept_mutator(this));
Buf* buf_new = dynamic_cast<Buf*>(buf->accept_mutator(this));
if (buf == buf_new && !any_index_changed) {
return v;
}
return new Load(dtype, buf_new, indices_new);
}

const Expr* IRMutator::mutate(Buf* v) {
const Var* var = v->base_handle();
Expr* IRMutator::mutate(Buf* v) {
Var* var = v->base_handle();
Var* var_new =
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
dynamic_cast<Var*>(const_cast<Expr*>(var->accept_mutator(this)));
@ -196,9 +196,9 @@ const Expr* IRMutator::mutate(Buf* v) {
}
bool any_change = var_new != var;

std::vector<const Expr*> dims_old = v->dims();
std::vector<const Expr*> dims_new(dims_old.size());
for (const auto i : c10::irange(dims_old.size())) {
std::vector<Expr*> dims_old = v->dims();
std::vector<Expr*> dims_new(dims_old.size());
for (auto i : c10::irange(dims_old.size())) {
dims_new[i] = dims_old[i]->accept_mutator(this);
any_change |= (dims_new[i] != dims_old[i]);
}
@ -212,23 +212,23 @@ const Expr* IRMutator::mutate(Buf* v) {
return v;
}

const Expr* IRMutator::mutate(const Broadcast* v) {
const Expr* value = v->value();
Expr* IRMutator::mutate(Broadcast* v) {
Expr* value = v->value();
int lanes = v->lanes();
const Expr* value_new = value->accept_mutator(this);
Expr* value_new = value->accept_mutator(this);
if (value == value_new) {
return v;
}
return new Broadcast(value_new, lanes);
}

const Expr* IRMutator::mutate(const IfThenElse* v) {
const Expr* condition = v->condition();
const Expr* true_value = v->true_value();
const Expr* false_value = v->false_value();
const Expr* condition_new = condition->accept_mutator(this);
const Expr* true_value_new = true_value->accept_mutator(this);
const Expr* false_value_new = false_value->accept_mutator(this);
Expr* IRMutator::mutate(IfThenElse* v) {
Expr* condition = v->condition();
Expr* true_value = v->true_value();
Expr* false_value = v->false_value();
Expr* condition_new = condition->accept_mutator(this);
Expr* true_value_new = true_value->accept_mutator(this);
Expr* false_value_new = false_value->accept_mutator(this);

if (condition == condition_new && true_value == true_value_new &&
false_value == false_value_new) {
@ -238,12 +238,12 @@ const Expr* IRMutator::mutate(const IfThenElse* v) {
return new IfThenElse(condition_new, true_value_new, false_value_new);
}

const Expr* IRMutator::mutate(const Intrinsics* v) {
std::vector<const Expr*> params(v->nparams());
Expr* IRMutator::mutate(Intrinsics* v) {
std::vector<Expr*> params(v->nparams());
bool any_change = false;
for (int i = 0; i < v->nparams(); i++) {
const Expr* value = v->param(i);
const Expr* value_new = value->accept_mutator(this);
Expr* value = v->param(i);
Expr* value_new = value->accept_mutator(this);
if (value != value_new) {
any_change = true;
}
@ -255,78 +255,78 @@ const Expr* IRMutator::mutate(const Intrinsics* v) {
return new Intrinsics(v->op_type(), params);
}

const Expr* IRMutator::mutate(const Term* v) {
const Expr* newScalar = v->scalar()->accept_mutator(this);
Expr* IRMutator::mutate(Term* v) {
Expr* newScalar = v->scalar()->accept_mutator(this);

std::vector<const Expr*> variables;
for (const auto* t : v->variables()) {
std::vector<Expr*> variables;
for (auto* t : v->variables()) {
variables.push_back(t->accept_mutator(this));
}
return new Term(v->hasher(), newScalar, variables);
}

const Expr* IRMutator::mutate(const Polynomial* v) {
const Expr* newScalar = v->scalar()->accept_mutator(this);
Expr* IRMutator::mutate(Polynomial* v) {
Expr* newScalar = v->scalar()->accept_mutator(this);

std::vector<const Term*> variables;
for (const auto* t : v->variables()) {
variables.push_back(static_cast<const Term*>(t->accept_mutator(this)));
std::vector<Term*> variables;
for (auto* t : v->variables()) {
variables.push_back(static_cast<Term*>(t->accept_mutator(this)));
}
return new Polynomial(v->hasher(), newScalar, variables);
}

const Expr* IRMutator::mutate(const RoundOff* v) {
Expr* IRMutator::mutate(RoundOff* v) {
return new RoundOff(
v->lhs()->accept_mutator(this), v->rhs()->accept_mutator(this));
}

const Expr* IRMutator::mutate(const MaxTerm* v) {
const Expr* newScalar = nullptr;
Expr* IRMutator::mutate(MaxTerm* v) {
Expr* newScalar = nullptr;
if (v->scalar()) {
newScalar = v->scalar()->accept_mutator(this);
}

std::vector<const Expr*> variables;
for (const auto* t : v->variables()) {
std::vector<Expr*> variables;
for (auto* t : v->variables()) {
variables.push_back(t->accept_mutator(this));
}
return new MaxTerm(v->hasher(), newScalar, v->propagate_nans(), variables);
}

const Expr* IRMutator::mutate(const MinTerm* v) {
const Expr* newScalar = nullptr;
Expr* IRMutator::mutate(MinTerm* v) {
Expr* newScalar = nullptr;
if (v->scalar()) {
newScalar = v->scalar()->accept_mutator(this);
}

std::vector<const Expr*> variables;
for (const auto* t : v->variables()) {
std::vector<Expr*> variables;
for (auto* t : v->variables()) {
variables.push_back(t->accept_mutator(this));
}
return new MinTerm(v->hasher(), newScalar, v->propagate_nans(), variables);
}

const Expr* IRMutator::mutate(const ReduceOp* v) {
const Expr* body_new = v->body()->accept_mutator(this);
Expr* IRMutator::mutate(ReduceOp* v) {
Expr* body_new = v->body()->accept_mutator(this);

std::vector<const Var*> new_reduce_args;
std::vector<Var*> new_reduce_args;
for (auto* r : v->reduce_args()) {
new_reduce_args.push_back(static_cast<const Var*>(r->accept_mutator(this)));
new_reduce_args.push_back(static_cast<Var*>(r->accept_mutator(this)));
}

return new ReduceOp(body_new, new_reduce_args, v->reducer());
}

Stmt* IRMutator::mutate(const For* v) {
const Expr* var = v->var();
const Expr* start = v->start();
const Expr* stop = v->stop();
Stmt* IRMutator::mutate(For* v) {
Expr* var = v->var();
Expr* start = v->start();
Expr* stop = v->stop();
Stmt* body = v->body();
LoopOptions loop_options = v->loop_options();
const Expr* var_new_expr = var->accept_mutator(this);
const Var* var_new = dynamic_cast<const Var*>(var_new_expr);
const Expr* start_new = start->accept_mutator(this);
const Expr* stop_new = stop->accept_mutator(this);
Expr* var_new_expr = var->accept_mutator(this);
Var* var_new = dynamic_cast<Var*>(var_new_expr);
Expr* start_new = start->accept_mutator(this);
Expr* stop_new = stop->accept_mutator(this);
Stmt* body_new = body->accept_mutator(this);
if (!body_new) {
return nullptr;
@ -341,7 +341,7 @@ Stmt* IRMutator::mutate(const For* v) {
return new For(var_new, start_new, stop_new, body_new, loop_options);
}

Stmt* IRMutator::mutate(const Block* v) {
Stmt* IRMutator::mutate(Block* v) {
bool any_change = false;

std::vector<Stmt*> stmts;
@ -362,69 +362,68 @@ Stmt* IRMutator::mutate(const Block* v) {
return Block::make(stmts);
}

Stmt* IRMutator::mutate(const Store* v) {
const Buf* buf = v->buf();
Stmt* IRMutator::mutate(Store* v) {
Buf* buf = v->buf();

bool any_index_changed = false;
std::vector<const Expr*> indices_new;
for (const Expr* ind : v->indices()) {
const Expr* new_ind = ind->accept_mutator(this);
std::vector<Expr*> indices_new;
for (Expr* ind : v->indices()) {
Expr* new_ind = ind->accept_mutator(this);
if (new_ind != ind) {
any_index_changed = true;
}
indices_new.push_back(new_ind);
}
const Expr* value = v->value();
const Buf* buf_new = dynamic_cast<const Buf*>(buf->accept_mutator(this));
const Expr* value_new = value->accept_mutator(this);
Expr* value = v->value();
Buf* buf_new = dynamic_cast<Buf*>(buf->accept_mutator(this));
Expr* value_new = value->accept_mutator(this);
if (buf == buf_new && !any_index_changed && value == value_new) {
return (Stmt*)v;
}
return new Store(buf_new, indices_new, value_new);
}

Stmt* IRMutator::mutate(const AtomicAdd* v) {
const Buf* buf = v->buf();
Stmt* IRMutator::mutate(AtomicAdd* v) {
Buf* buf = v->buf();

bool any_index_changed = false;
std::vector<const Expr*> indices_new;
for (const Expr* ind : v->indices()) {
const Expr* new_ind = ind->accept_mutator(this);
std::vector<Expr*> indices_new;
for (Expr* ind : v->indices()) {
Expr* new_ind = ind->accept_mutator(this);
if (new_ind != ind) {
any_index_changed = true;
}
indices_new.push_back(new_ind);
}
const Expr* value = v->value();
const Buf* buf_new = dynamic_cast<const Buf*>(buf->accept_mutator(this));
const Expr* value_new = value->accept_mutator(this);
Expr* value = v->value();
Buf* buf_new = dynamic_cast<Buf*>(buf->accept_mutator(this));
Expr* value_new = value->accept_mutator(this);
if (buf == buf_new && !any_index_changed && value == value_new) {
return (Stmt*)v;
}
return new AtomicAdd(buf_new, indices_new, value_new);
}

Stmt* IRMutator::mutate(const SyncThreads* v) {
Stmt* IRMutator::mutate(SyncThreads* v) {
return new SyncThreads();
}

Stmt* IRMutator::mutate(const ExternalCall* v) {
Stmt* IRMutator::mutate(ExternalCall* v) {
bool changed = false;
const Buf* new_buf = dynamic_cast<const Buf*>(v->buf()->accept_mutator(this));
Buf* new_buf = dynamic_cast<Buf*>(v->buf()->accept_mutator(this));
TORCH_INTERNAL_ASSERT(new_buf);
changed |= new_buf != v->buf();

std::vector<const Buf*> new_buf_args;
for (const Buf* buf_arg : v->buf_args()) {
const Buf* new_buf_arg =
dynamic_cast<const Buf*>(buf_arg->accept_mutator(this));
std::vector<Buf*> new_buf_args;
for (Buf* buf_arg : v->buf_args()) {
Buf* new_buf_arg = dynamic_cast<Buf*>(buf_arg->accept_mutator(this));
TORCH_INTERNAL_ASSERT(new_buf_arg);
new_buf_args.push_back(new_buf_arg);
changed |= new_buf_arg != buf_arg;
}
std::vector<const Expr*> new_args;
for (const Expr* arg : v->args()) {
const Expr* new_arg = arg->accept_mutator(this);
std::vector<Expr*> new_args;
for (Expr* arg : v->args()) {
Expr* new_arg = arg->accept_mutator(this);
new_args.push_back(new_arg);
changed |= new_arg != arg;
}
@ -433,9 +432,9 @@ Stmt* IRMutator::mutate(const ExternalCall* v) {
: (Stmt*)v;
}

Stmt* IRMutator::mutate(const Allocate* v) {
const Buf* buf = v->buf();
const Buf* buf_new = dynamic_cast<const Buf*>(buf->accept_mutator(this));
Stmt* IRMutator::mutate(Allocate* v) {
Buf* buf = v->buf();
Buf* buf_new = dynamic_cast<Buf*>(buf->accept_mutator(this));
TORCH_INTERNAL_ASSERT(buf_new);
if (buf_new == buf) {
return (Stmt*)v;
@ -443,9 +442,9 @@ Stmt* IRMutator::mutate(const Allocate* v) {
return new Allocate(buf_new);
}

Stmt* IRMutator::mutate(const Free* v) {
const Buf* buf = v->buf();
const Buf* buf_new = dynamic_cast<const Buf*>(buf->accept_mutator(this));
Stmt* IRMutator::mutate(Free* v) {
Buf* buf = v->buf();
Buf* buf_new = dynamic_cast<Buf*>(buf->accept_mutator(this));
TORCH_INTERNAL_ASSERT(buf_new);
if (buf_new == buf) {
return (Stmt*)v;
@ -454,12 +453,12 @@ Stmt* IRMutator::mutate(const Free* v) {
return new Free(buf_new);
}

Stmt* IRMutator::mutate(const Let* v) {
const Var* var_old = v->var();
const Var* var_new = dynamic_cast<const Var*>(var_old->accept_mutator(this));
Stmt* IRMutator::mutate(Let* v) {
Var* var_old = v->var();
Var* var_new = dynamic_cast<Var*>(var_old->accept_mutator(this));

const Expr* val_old = v->value();
const Expr* val_new = val_old->accept_mutator(this);
Expr* val_old = v->value();
Expr* val_new = val_old->accept_mutator(this);

if (var_new == var_old && val_old == val_new) {
return (Stmt*)v;
@ -468,12 +467,12 @@ Stmt* IRMutator::mutate(const Let* v) {
return new Let(var_new, val_new);
}

Stmt* IRMutator::mutate(const Cond* v) {
const Expr* cond_old = v->condition();
Stmt* IRMutator::mutate(Cond* v) {
Expr* cond_old = v->condition();
Stmt* true_old = v->true_stmt();
Stmt* false_old = v->false_stmt();

const Expr* cond_new = cond_old->accept_mutator(this);
Expr* cond_new = cond_old->accept_mutator(this);
Stmt* true_new = true_old ? true_old->accept_mutator(this) : true_old;
Stmt* false_new = false_old ? false_old->accept_mutator(this) : false_old;

@ -493,24 +492,24 @@ Stmt* IRMutator::mutate(const Cond* v) {
|
||||
|
||||
class StmtClone : public IRMutator {
|
||||
public:
|
||||
Stmt* mutate(const For* v) override;
|
||||
Stmt* mutate(const Block* v) override;
|
||||
Stmt* mutate(const Store* v) override;
|
||||
Stmt* mutate(const Allocate* v) override;
|
||||
Stmt* mutate(const Free* v) override;
|
||||
Stmt* mutate(const Let* v) override;
|
||||
Stmt* mutate(const Cond* v) override;
|
||||
Stmt* mutate(const AtomicAdd* v) override;
|
||||
Stmt* mutate(For* v) override;
|
||||
Stmt* mutate(Block* v) override;
|
||||
Stmt* mutate(Store* v) override;
|
||||
Stmt* mutate(Allocate* v) override;
|
||||
Stmt* mutate(Free* v) override;
|
||||
Stmt* mutate(Let* v) override;
|
||||
Stmt* mutate(Cond* v) override;
|
||||
Stmt* mutate(AtomicAdd* v) override;
|
||||
};
|
||||
|
||||
Stmt* StmtClone::mutate(const For* v) {
|
||||
Stmt* StmtClone::mutate(For* v) {
|
||||
// Only body needs to be cloned as only statements are mutable
|
||||
Stmt* body_new = v->body()->accept_mutator(this);
|
||||
|
||||
return new For(v->var(), v->start(), v->stop(), body_new, v->loop_options());
|
||||
}
|
||||
|
||||
Stmt* StmtClone::mutate(const Block* v) {
|
||||
Stmt* StmtClone::mutate(Block* v) {
|
||||
std::vector<Stmt*> stmts;
|
||||
for (Stmt* stmt : *v) {
|
||||
stmts.push_back(stmt->accept_mutator(this));
|
||||
@ -518,27 +517,27 @@ Stmt* StmtClone::mutate(const Block* v) {
|
||||
return new Block(stmts);
|
||||
}
|
||||
|
||||
Stmt* StmtClone::mutate(const Store* v) {
|
||||
Stmt* StmtClone::mutate(Store* v) {
|
||||
return new Store(v->buf(), v->indices(), v->value());
|
||||
}
|
||||
|
||||
Stmt* StmtClone::mutate(const AtomicAdd* v) {
|
||||
Stmt* StmtClone::mutate(AtomicAdd* v) {
|
||||
return new AtomicAdd(v->buf(), v->indices(), v->value());
|
||||
}
|
||||
|
||||
Stmt* StmtClone::mutate(const Allocate* v) {
|
||||
Stmt* StmtClone::mutate(Allocate* v) {
|
||||
return new Allocate(v->buf());
|
||||
}
|
||||
|
||||
Stmt* StmtClone::mutate(const Free* v) {
|
||||
Stmt* StmtClone::mutate(Free* v) {
|
||||
return new Free(v->buf());
|
||||
}
|
||||
|
||||
Stmt* StmtClone::mutate(const Let* v) {
|
||||
Stmt* StmtClone::mutate(Let* v) {
|
||||
return new Let(v->var(), v->value());
|
||||
}
|
||||
|
||||
Stmt* StmtClone::mutate(const Cond* v) {
|
||||
Stmt* StmtClone::mutate(Cond* v) {
|
||||
Stmt* true_old = v->true_stmt();
|
||||
Stmt* false_old = v->false_stmt();
|
||||
|
||||
|
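All of the mutate methods above follow one rewrite idiom: mutate the children first, return the original node when nothing below it changed, and only otherwise allocate a replacement, so unchanged subtrees are shared between input and output. A minimal sketch of that idiom with toy types (not the NNC classes):

// Toy tree, just enough structure to show the reuse-or-rebuild pattern.
struct Node {
  int value = 0;
  Node* lhs = nullptr;  // leaves have null children
  Node* rhs = nullptr;
};

// Rewrites every leaf holding `from` into a fresh leaf holding `to`,
// sharing all untouched subtrees with the input (cf. Store/AtomicAdd above).
Node* replace_leaf(Node* n, int from, int to) {
  if (n == nullptr) return nullptr;
  Node* l = replace_leaf(n->lhs, from, to);
  Node* r = replace_leaf(n->rhs, from, to);
  bool is_leaf = (n->lhs == nullptr && n->rhs == nullptr);
  if (is_leaf && n->value == from) {
    return new Node{to, nullptr, nullptr};
  }
  if (l == n->lhs && r == n->rhs) {
    return n;  // nothing below changed: reuse the original node
  }
  return new Node{n->value, l, r};  // rebuild only along the changed spine
}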
@ -57,52 +57,51 @@ class ExternalCall;
class TORCH_API IRMutator {
public:
virtual ~IRMutator() = default;
virtual const Expr* mutate(const Add* v);
virtual const Expr* mutate(const Sub* v);
virtual const Expr* mutate(const Mul* v);
virtual const Expr* mutate(const Div* v);
virtual const Expr* mutate(const Mod* v);
virtual const Expr* mutate(const Max* v);
virtual const Expr* mutate(const Min* v);
virtual const Expr* mutate(const And* v);
virtual const Expr* mutate(const Or* v);
virtual const Expr* mutate(const Xor* v);
virtual const Expr* mutate(const Lshift* v);
virtual const Expr* mutate(const Rshift* v);
virtual const Expr* mutate(const CompareSelect* v);
#define IMM_MUTATE_DECLARE(Type, Name) \
virtual const Expr* mutate(const Name##Imm* v);
virtual Expr* mutate(Add* v);
virtual Expr* mutate(Sub* v);
virtual Expr* mutate(Mul* v);
virtual Expr* mutate(Div* v);
virtual Expr* mutate(Mod* v);
virtual Expr* mutate(Max* v);
virtual Expr* mutate(Min* v);
virtual Expr* mutate(And* v);
virtual Expr* mutate(Or* v);
virtual Expr* mutate(Xor* v);
virtual Expr* mutate(Lshift* v);
virtual Expr* mutate(Rshift* v);
virtual Expr* mutate(CompareSelect* v);
#define IMM_MUTATE_DECLARE(Type, Name) virtual Expr* mutate(Name##Imm* v);
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_MUTATE_DECLARE);
#undef IMM_MUTATE_DECLARE
virtual const Expr* mutate(const Cast* v);
virtual const Expr* mutate(const BitCast* v);
virtual const Expr* mutate(const Var* v);
virtual const Expr* mutate(Buf* v);
virtual const Expr* mutate(const Ramp* v);
virtual const Expr* mutate(const Load* v);
virtual const Expr* mutate(const Broadcast* v);
virtual const Expr* mutate(const IfThenElse* v);
virtual const Expr* mutate(const Intrinsics* v);
virtual Expr* mutate(Cast* v);
virtual Expr* mutate(BitCast* v);
virtual Expr* mutate(Var* v);
virtual Expr* mutate(Buf* v);
virtual Expr* mutate(Ramp* v);
virtual Expr* mutate(Load* v);
virtual Expr* mutate(Broadcast* v);
virtual Expr* mutate(IfThenElse* v);
virtual Expr* mutate(Intrinsics* v);

virtual const Expr* mutate(const Term* v);
virtual const Expr* mutate(const Polynomial* v);
virtual const Expr* mutate(const RoundOff* v);
virtual const Expr* mutate(const MaxTerm* v);
virtual const Expr* mutate(const MinTerm* v);
virtual Expr* mutate(Term* v);
virtual Expr* mutate(Polynomial* v);
virtual Expr* mutate(RoundOff* v);
virtual Expr* mutate(MaxTerm* v);
virtual Expr* mutate(MinTerm* v);

virtual const Expr* mutate(const ReduceOp* v);
virtual Expr* mutate(ReduceOp* v);

virtual Stmt* mutate(const For* v);
virtual Stmt* mutate(const Block* v);
virtual Stmt* mutate(const Store* v);
virtual Stmt* mutate(const AtomicAdd* v);
virtual Stmt* mutate(const SyncThreads* v);
virtual Stmt* mutate(const ExternalCall* v);
virtual Stmt* mutate(For* v);
virtual Stmt* mutate(Block* v);
virtual Stmt* mutate(Store* v);
virtual Stmt* mutate(AtomicAdd* v);
virtual Stmt* mutate(SyncThreads* v);
virtual Stmt* mutate(ExternalCall* v);

virtual Stmt* mutate(const Allocate* v);
virtual Stmt* mutate(const Free* v);
virtual Stmt* mutate(const Let* v);
virtual Stmt* mutate(const Cond* v);
virtual Stmt* mutate(Allocate* v);
virtual Stmt* mutate(Free* v);
virtual Stmt* mutate(Let* v);
virtual Stmt* mutate(Cond* v);
};

} // namespace tensorexpr
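With the const qualifiers dropped from these declarations, a transformation pass overrides the mutate overloads it cares about and works with plain Expr* throughout. A hypothetical pass sketched against the post-change signatures above; the rewrite rule itself (x + x becomes x * 2) is illustrative, not part of this commit:

// Assumes the tensorexpr headers from this repo; IRMutator, Add, Mul and
// IntImm are the classes shown in the diff above.
class AddToMul : public IRMutator {
 public:
  Expr* mutate(Add* v) override {
    Expr* lhs = v->lhs()->accept_mutator(this);
    Expr* rhs = v->rhs()->accept_mutator(this);
    if (lhs == rhs) {
      return new Mul(lhs, new IntImm(2));  // x + x -> x * 2
    }
    // Same reuse-or-rebuild idiom as the base class.
    return (lhs == v->lhs() && rhs == v->rhs()) ? (Expr*)v : new Add(lhs, rhs);
  }
};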
@ -18,11 +18,11 @@ void IRPrinter::print(ExprHandle expr) {
expr.node()->accept(this);
}

void IRPrinter::print(const Expr& expr) {
void IRPrinter::print(Expr& expr) {
expr.accept(this);
}

void IRPrinter::print(const Stmt& stmt) {
void IRPrinter::print(Stmt& stmt) {
stmt.accept(this);
}

@ -30,7 +30,7 @@ void IRPrinter::print(const Stmt& stmt) {
// we need to look at the operator precedence to make the output simpler.
template <typename Op>
void visitBinaryOp(
const BinaryOpNode<Op>* v,
BinaryOpNode<Op>* v,
const std::string& op_str,
IRPrinter* printer,
bool parens = true) {
@ -58,43 +58,43 @@ void visitBinaryOp(
}
}

void IRPrinter::visit(const Add* v) {
void IRPrinter::visit(Add* v) {
visitBinaryOp(v, "+", this);
}

void IRPrinter::visit(const Sub* v) {
void IRPrinter::visit(Sub* v) {
visitBinaryOp(v, "-", this);
}

void IRPrinter::visit(const Mul* v) {
void IRPrinter::visit(Mul* v) {
visitBinaryOp(v, "*", this);
}

void IRPrinter::visit(const Div* v) {
void IRPrinter::visit(Div* v) {
visitBinaryOp(v, "/", this);
}

void IRPrinter::visit(const And* v) {
void IRPrinter::visit(And* v) {
visitBinaryOp(v, "&", this);
}

void IRPrinter::visit(const Or* v) {
void IRPrinter::visit(Or* v) {
visitBinaryOp(v, "|", this);
}

void IRPrinter::visit(const Xor* v) {
void IRPrinter::visit(Xor* v) {
visitBinaryOp(v, "^", this);
}

void IRPrinter::visit(const Lshift* v) {
void IRPrinter::visit(Lshift* v) {
visitBinaryOp(v, "<<", this);
}

void IRPrinter::visit(const Rshift* v) {
void IRPrinter::visit(Rshift* v) {
visitBinaryOp(v, ">>", this);
}

void IRPrinter::visit(const Mod* v) {
void IRPrinter::visit(Mod* v) {
if (v->dtype().is_integral()) {
visitBinaryOp(v, "%", this);
} else if (v->dtype().is_floating_point()) {
@ -104,7 +104,7 @@ void IRPrinter::visit(const Mod* v) {
}
}

void IRPrinter::visit(const Max* v) {
void IRPrinter::visit(Max* v) {
os() << "Max(";
v->lhs()->accept(this);
os() << ", ";
@ -112,7 +112,7 @@ void IRPrinter::visit(const Max* v) {
os() << ", " << (unsigned int)v->propagate_nans() << ")";
}

void IRPrinter::visit(const Min* v) {
void IRPrinter::visit(Min* v) {
os() << "Min(";
v->lhs()->accept(this);
os() << ", ";
@ -120,7 +120,7 @@ void IRPrinter::visit(const Min* v) {
os() << ", " << (unsigned int)v->propagate_nans() << ")";
}

void IRPrinter::visit(const CompareSelect* v) {
void IRPrinter::visit(CompareSelect* v) {
CompareSelectOperation cmp_op = v->compare_select_op();
int self_prec = getPrecedence(v->expr_type());
int lhs_prec = getPrecedence(v->lhs()->expr_type());
@ -165,7 +165,7 @@ void IRPrinter::visit(const CompareSelect* v) {
}
os() << " ? ";

auto withParens = [&](const Expr* e) {
auto withParens = [&](Expr* e) {
auto prec = getPrecedence(e->expr_type());
if (prec >= self_prec) {
os() << "(";
@ -219,30 +219,30 @@ static void formatImm(std::ostream& os, T v) {
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_PRINT_VISIT);
#undef IMM_PRINT_VISIT

void IRPrinter::visit(const Cast* v) {
void IRPrinter::visit(Cast* v) {
auto dtype = v->dtype();
os() << dtypeToCppString(dtype) << "(";
v->src_value()->accept(this);
os() << ")";
}

void IRPrinter::visit(const Var* v) {
void IRPrinter::visit(Var* v) {
os() << name_manager_.get_unique_name(v);
}

void IRPrinter::visit(const Ramp* v) {
void IRPrinter::visit(Ramp* v) {
os() << "Ramp(" << *v->base() << ", " << *v->stride() << ", " << v->lanes()
<< ")";
}

void IRPrinter::visit(const Load* v) {
void IRPrinter::visit(Load* v) {
// TODO: support the mask case
if (v->indices().size() == 0) {
os() << *v->base_handle();
} else {
os() << *v->base_handle() << "[";
size_t i = 0;
for (const Expr* ind : v->indices()) {
for (Expr* ind : v->indices()) {
if (i++) {
os() << ", ";
}
@ -255,18 +255,18 @@ void IRPrinter::visit(const Load* v) {
}
}

void IRPrinter::visit(const Broadcast* v) {
void IRPrinter::visit(Broadcast* v) {
os() << "Broadcast(" << *v->value() << ", " << v->lanes() << ")";
}

void IRPrinter::visit(const IfThenElse* v) {
void IRPrinter::visit(IfThenElse* v) {
os() << "IfThenElse(" << *v->condition() << ", " << *v->true_value() << ", "
<< *v->false_value() << ")";
}

void IRPrinter::visit(const Intrinsics* v) {
void IRPrinter::visit(Intrinsics* v) {
os() << v->func_name() << "(";
for (const auto i : c10::irange(v->nparams())) {
for (auto i : c10::irange(v->nparams())) {
if (i > 0) {
os() << ", ";
}
@ -275,7 +275,7 @@ void IRPrinter::visit(const Intrinsics* v) {
os() << ")";
}

void IRPrinter::visit(const Term* v) {
void IRPrinter::visit(Term* v) {
os() << "Term(";
v->scalar()->accept(this);
for (auto* t : v->variables()) {
@ -285,7 +285,7 @@ void IRPrinter::visit(const Term* v) {
os() << ")";
}

void IRPrinter::visit(const Polynomial* v) {
void IRPrinter::visit(Polynomial* v) {
bool first = true;
os() << "Polynomial(";
for (auto* t : v->variables()) {
@ -303,7 +303,7 @@ void IRPrinter::visit(const Polynomial* v) {
os() << ")";
}

void IRPrinter::visit(const RoundOff* v) {
void IRPrinter::visit(RoundOff* v) {
os() << "RoundOff(";
v->lhs()->accept(this);
os() << ", ";
@ -311,7 +311,7 @@ void IRPrinter::visit(const RoundOff* v) {
os() << ")";
}

void IRPrinter::visit(const MaxTerm* v) {
void IRPrinter::visit(MaxTerm* v) {
os() << "MaxTerm(";
if (v->scalar()) {
v->scalar()->accept(this);
@ -326,7 +326,7 @@ void IRPrinter::visit(const MaxTerm* v) {
os() << ")";
}

void IRPrinter::visit(const MinTerm* v) {
void IRPrinter::visit(MinTerm* v) {
os() << "MinTerm(";
if (v->scalar()) {
v->scalar()->accept(this);
@ -341,7 +341,7 @@ void IRPrinter::visit(const MinTerm* v) {
os() << ")";
}

void IRPrinter::visit(const ReduceOp* v) {
void IRPrinter::visit(ReduceOp* v) {
os() << "ReduceOp(";
os() << *v->body() << ", ";

@ -363,7 +363,7 @@ void IRPrinter::visit(const ReduceOp* v) {
// each statement in a `Block` the printer will insert indentation before
// the statement and a newline after the statement.

void IRPrinter::visit(const Store* v) {
void IRPrinter::visit(Store* v) {
// TODO: handle the mask
if (v->indices().size() == 0) {
os() << *v->base_handle() << " = " << *v->value() << ";";
@ -372,7 +372,7 @@ void IRPrinter::visit(const Store* v) {

os() << *v->base_handle() << "[";
size_t i = 0;
for (const Expr* ind : v->indices()) {
for (Expr* ind : v->indices()) {
if (i++) {
os() << ", ";
}
@ -384,8 +384,8 @@ void IRPrinter::visit(const Store* v) {
os() << "] = " << *v->value() << ";";
}

void IRPrinter::visit(const For* v) {
const Var* var = v->var();
void IRPrinter::visit(For* v) {
Var* var = v->var();
VarHandle vv(var);
os() << "for (" << dtypeToCppString(var->dtype()) << " " << vv << " = "
<< ExprHandle(v->start()) << "; " << vv << " < " << ExprHandle(v->stop())
@ -401,7 +401,7 @@ void IRPrinter::visit(const For* v) {
}
}

void IRPrinter::visit(const Block* v) {
void IRPrinter::visit(Block* v) {
os() << "{\n";
indent_++;

@ -414,12 +414,12 @@ void IRPrinter::visit(const Block* v) {
os() << "}";
}

void IRPrinter::visit(const Allocate* v) {
void IRPrinter::visit(Allocate* v) {
os() << "Allocate(" << *v->buffer_var()
<< "); // dtype=" << dtypeToCppString(v->dtype());
os() << ", dims=[";
const std::vector<const Expr*>& dims = v->dims();
for (const auto i : c10::irange(dims.size())) {
const std::vector<Expr*>& dims = v->dims();
for (auto i : c10::irange(dims.size())) {
if (i != 0) {
os() << ", ";
}
@ -428,18 +428,18 @@ void IRPrinter::visit(const Allocate* v) {
os() << "]";
}

void IRPrinter::visit(const Free* v) {
void IRPrinter::visit(Free* v) {
os() << "Free(" << *v->buffer_var() << ");";
}

void IRPrinter::visit(const Let* v) {
void IRPrinter::visit(Let* v) {
os() << dtypeToCppString(v->dtype()) << " " << *v->var();
os() << " = " << *v->value();
os() << ";";
}

void IRPrinter::visit(const Cond* v) {
const Expr* cond = v->condition();
void IRPrinter::visit(Cond* v) {
Expr* cond = v->condition();
Stmt* true_stmt = v->true_stmt();
Stmt* false_stmt = v->false_stmt();
if (!true_stmt) {
@ -455,10 +455,10 @@ void IRPrinter::visit(const Cond* v) {
}
}

void IRPrinter::visit(const AtomicAdd* v) {
void IRPrinter::visit(AtomicAdd* v) {
os() << "atomicAdd(&" << *v->base_handle() << "[";
size_t i = 0;
for (const Expr* ind : v->indices()) {
for (Expr* ind : v->indices()) {
if (i++) {
os() << ", ";
}
@ -470,16 +470,16 @@ void IRPrinter::visit(const AtomicAdd* v) {
os() << "], " << *v->value() << ");";
}

void IRPrinter::visit(const SyncThreads* v) {
void IRPrinter::visit(SyncThreads* v) {
os() << "__syncthreads();";
}

void IRPrinter::visit(const ExternalCall* v) {
void IRPrinter::visit(ExternalCall* v) {
os() << *v->buf() << " = " << v->func_name() << "(";

os() << "buf_args={";
int i = 0;
for (const Buf* buf_arg : v->buf_args()) {
for (Buf* buf_arg : v->buf_args()) {
if (i++ > 0) {
os() << ", ";
}
@ -488,7 +488,7 @@ void IRPrinter::visit(const ExternalCall* v) {

os() << "}, args={";
i = 0;
for (const Expr* arg : v->args()) {
for (Expr* arg : v->args()) {
if (i++ > 0) {
os() << ", ";
}
@ -504,11 +504,12 @@ void IRPrinter::emitIndent() {
std::ostream& operator<<(std::ostream& stream, const ExprHandle& expr) {
IRPrinter::PrinterStream* printer_stream =
dynamic_cast<IRPrinter::PrinterStream*>(&stream);
ExprHandle& mutable_expr = const_cast<ExprHandle&>(expr);
if (printer_stream != nullptr) {
expr.node()->accept(printer_stream->printer());
mutable_expr.node()->accept(printer_stream->printer());
} else {
IRPrinter p(stream);
p.print(expr);
p.print(mutable_expr);
}
return stream;
}
@ -516,11 +517,12 @@ std::ostream& operator<<(std::ostream& stream, const ExprHandle& expr) {
std::ostream& operator<<(std::ostream& stream, const Expr& expr) {
IRPrinter::PrinterStream* printer_stream =
dynamic_cast<IRPrinter::PrinterStream*>(&stream);
Expr& mutable_expr = const_cast<Expr&>(expr);
if (printer_stream != nullptr) {
expr.accept(printer_stream->printer());
mutable_expr.accept(printer_stream->printer());
} else {
IRPrinter p(stream);
p.print(expr);
p.print(mutable_expr);
}
return stream;
}
@ -528,11 +530,12 @@ std::ostream& operator<<(std::ostream& stream, const Expr& expr) {
std::ostream& operator<<(std::ostream& stream, const Stmt& stmt) {
IRPrinter::PrinterStream* printer_stream =
dynamic_cast<IRPrinter::PrinterStream*>(&stream);
Stmt& mutable_stmt = const_cast<Stmt&>(stmt);
if (printer_stream != nullptr) {
stmt.accept(printer_stream->printer());
mutable_stmt.accept(printer_stream->printer());
} else {
IRPrinter p(stream);
p.print(stmt);
p.print(mutable_stmt);
}
return stream;
}
@ -544,8 +547,9 @@ std::ostream& operator<<(std::ostream& stream, const Tensor& t) {

void print(const Expr* expr) {
if (expr) {
Expr* mutable_expr = const_cast<Expr*>(expr);
IRPrinter p(std::cout);
p.print(*expr);
p.print(*mutable_expr);
} else {
std::cout << "(null expr)";
}
@ -554,8 +558,9 @@ void print(const Expr* expr) {

void print(const Stmt* stmt) {
if (stmt) {
Stmt* mutable_stmt = const_cast<Stmt*>(stmt);
IRPrinter p(std::cout);
p.print(*stmt);
p.print(*mutable_stmt);
} else {
std::cout << "(null stmt)\n";
}
@ -589,7 +594,7 @@ std::string to_string(const Tensor* t) {
std::ostringstream oss;
// TODO: move this to Buf printer
oss << "Tensor " << t->buf()->name_hint() << "[";
for (const auto i : c10::irange(t->buf()->ndim())) {
for (auto i : c10::irange(t->buf()->ndim())) {
if (i != 0) {
oss << ", ";
}
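The precedence handling in visitBinaryOp above reduces to one rule: parenthesize a child expression only when its operator binds less tightly than the parent's. A standalone sketch with a toy two-level precedence table (the real getPrecedence lives elsewhere in this file):

#include <iostream>
#include <string>

// Toy precedence: '+'/'-' bind looser than '*'/'/'.
int prec(char op) { return (op == '+' || op == '-') ? 1 : 2; }

// Emit `child` (already printed to a string) as an operand of `parent_op`.
void emit(std::ostream& os, const std::string& child, char child_op, char parent_op) {
  if (prec(child_op) < prec(parent_op)) {
    os << "(" << child << ")";  // e.g. (a + b) * c keeps its parentheses
  } else {
    os << child;                // e.g. a * b + c needs none around a * b
  }
}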
@ -17,48 +17,48 @@ class TORCH_API IRPrinter : public IRVisitor {
explicit IRPrinter(std::ostream& os) : printer_os_(this, os) {}

void print(ExprHandle);
void print(const Expr&);
void print(const Stmt&);
void visit(const Add* v) override;
void visit(const Sub* v) override;
void visit(const Mul* v) override;
void visit(const Div* v) override;
void visit(const Mod* v) override;
void visit(const Max* v) override;
void visit(const Min* v) override;
void visit(const And* v) override;
void visit(const Or* v) override;
void visit(const Xor* v) override;
void visit(const Lshift* v) override;
void visit(const Rshift* v) override;
void visit(const CompareSelect* v) override;
void print(Expr&);
void print(Stmt&);
void visit(Add* v) override;
void visit(Sub* v) override;
void visit(Mul* v) override;
void visit(Div* v) override;
void visit(Mod* v) override;
void visit(Max* v) override;
void visit(Min* v) override;
void visit(And* v) override;
void visit(Or* v) override;
void visit(Xor* v) override;
void visit(Lshift* v) override;
void visit(Rshift* v) override;
void visit(CompareSelect* v) override;
#define IMM_PRINT_VISIT(Type, Name) void visit(const Name##Imm* v) override;
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_PRINT_VISIT);
#undef IMM_PRINT_VISIT
void visit(const Cast* v) override;
void visit(const Var* v) override;
void visit(const Ramp* v) override;
void visit(const Load* v) override;
void visit(const Broadcast* v) override;
void visit(const IfThenElse* v) override;
void visit(const Intrinsics* v) override;
void visit(const Term* v) override;
void visit(const Polynomial* v) override;
void visit(const RoundOff* v) override;
void visit(const MaxTerm* v) override;
void visit(const MinTerm* v) override;
void visit(const ReduceOp* v) override;
void visit(Cast* v) override;
void visit(Var* v) override;
void visit(Ramp* v) override;
void visit(Load* v) override;
void visit(Broadcast* v) override;
void visit(IfThenElse* v) override;
void visit(Intrinsics* v) override;
void visit(Term* v) override;
void visit(Polynomial* v) override;
void visit(RoundOff* v) override;
void visit(MaxTerm* v) override;
void visit(MinTerm* v) override;
void visit(ReduceOp* v) override;

void visit(const AtomicAdd* v) override;
void visit(const SyncThreads* v) override;
void visit(const ExternalCall* v) override;
void visit(const Store* v) override;
void visit(const For* v) override;
void visit(const Cond* v) override;
void visit(const Block* v) override;
void visit(const Allocate* v) override;
void visit(const Free* v) override;
void visit(const Let* v) override;
void visit(AtomicAdd* v) override;
void visit(SyncThreads* v) override;
void visit(ExternalCall* v) override;
void visit(Store* v) override;
void visit(For* v) override;
void visit(Cond* v) override;
void visit(Block* v) override;
void visit(Allocate* v) override;
void visit(Free* v) override;
void visit(Let* v) override;

// A child class may have a different rule for generating dtype
// string, e.g. CUDA needs int64_t to be generated as long long.

File diff suppressed because it is too large
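The operator<< and free print() overloads in the printer changes above keep a const-qualified public surface while the visitors now take mutable nodes; the const_cast is the bridge and is safe because printing never actually mutates. A reduced, self-contained sketch of that pattern with toy types (not the NNC classes):

#include <iostream>

struct Expr;

// Mirrors IRPrinter::print(Expr&): the visitor side takes a mutable reference.
struct Printer {
  std::ostream& os;
  void print(Expr& e);
};

struct Expr {
  int value = 42;
};

void Printer::print(Expr& e) {
  os << e.value;  // read-only in practice, despite the non-const signature
}

std::ostream& operator<<(std::ostream& os, const Expr& e) {
  Printer p{os};
  p.print(const_cast<Expr&>(e));  // bridge: API stays const, visitor is not
  return os;
}

int main() {
  Expr e;
  std::cout << e << "\n";  // prints 42
}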
@ -25,7 +25,7 @@ namespace tensorexpr {
// A bunch of helpers for determining the Dtype of the output of a multi argument
// Term or Polynomial.
template <class ExprType>
Dtype promoteTypesVec(const Expr* s, std::vector<const ExprType*>& v) {
Dtype promoteTypesVec(Expr* s, std::vector<ExprType*>& v) {
Dtype t = s->dtype();
bool first = true;

@ -40,7 +40,7 @@ Dtype promoteTypesVec(const Expr* s, std::vector<const ExprType*>& v) {
}

template <class ExprType>
Dtype promoteTypesVec(std::vector<const ExprType*>& v) {
Dtype promoteTypesVec(std::vector<ExprType*>& v) {
if (v.empty()) {
throw malformed_input("empty list of types");
}
@ -54,8 +54,8 @@ Dtype promoteTypesVec(std::vector<const ExprType*>& v) {

template <class ExprType>
Dtype promoteTypesMap(
const Expr* s,
std::unordered_map<SimplifierHashType, const ExprType*>& m) {
Expr* s,
std::unordered_map<SimplifierHashType, ExprType*>& m) {
Dtype t = s->dtype();
bool first = true;
for (auto& e : m) {
@ -69,12 +69,12 @@ Dtype promoteTypesMap(
}

template <class ExprType>
Dtype promoteTypesVar(const ExprType* e) {
Dtype promoteTypesVar(ExprType* e) {
return e->dtype();
}

template <class ExprType, class... Args>
Dtype promoteTypesVar(const ExprType* e, Args... es) {
Dtype promoteTypesVar(ExprType* e, Args... es) {
Dtype lhs = e->dtype();
Dtype rhs = promoteTypesVar(es...);
if (e->isConstant()) {
@ -85,10 +85,10 @@ Dtype promoteTypesVar(const ExprType* e, Args... es) {
}

// Creates a new Expr of the given type with the provided lhs and rhs.
inline const Expr* newBinaryOpOfType(
inline Expr* newBinaryOpOfType(
IRNodeType expr_type,
const Expr* lhs,
const Expr* rhs,
Expr* lhs,
Expr* rhs,
bool option) {
switch (expr_type) {
// NOLINTNEXTLINE(bugprone-branch-clone)
@ -123,7 +123,7 @@ inline const Expr* newBinaryOpOfType(
// Uses the evaluator to fold an Expression with constant terms.
// E.g. evaluateOp(Add(3, 4)) => 7.
// Expr v must not have any unbound Vars.
inline Expr* evaluateOp(const Expr* v) {
inline Expr* evaluateOp(Expr* v) {
ExprHandle handle(v);
ExprEval<SimpleIREvaluator> eval(handle);

@ -148,7 +148,7 @@ class Term : public ExprNode<Term> {
public:
template <class... Args>
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
Term(HashProvider& hasher, const Expr* s, Args... ts)
Term(HashProvider& hasher, Expr* s, Args... ts)
: ExprNodeBase(promoteTypesVar(s, ts...)), scalar_(s), hasher_(hasher) {
CHECK(s->isConstant());
addComponent(ts...);
@ -156,7 +156,7 @@ class Term : public ExprNode<Term> {
}

// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
Term(HashProvider& hasher, const Expr* s, std::vector<const Expr*> v)
Term(HashProvider& hasher, Expr* s, std::vector<Expr*> v)
: ExprNodeBase(promoteTypesVec(s, v)),
variables_(std::move(v)),
scalar_(s),
@ -168,8 +168,8 @@ class Term : public ExprNode<Term> {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
Term(
HashProvider& hasher,
const Expr* s,
std::unordered_map<SimplifierHashType, const Expr*> varmap)
Expr* s,
std::unordered_map<SimplifierHashType, Expr*> varmap)
: ExprNodeBase(promoteTypesMap(s, varmap)), scalar_(s), hasher_(hasher) {
for (auto& p : varmap) {
addComponent(p.second);
@ -177,10 +177,10 @@ class Term : public ExprNode<Term> {
sort();
}

const Expr* scalar() const {
Expr* scalar() const {
return scalar_;
}
const std::vector<const Expr*>& variables() const {
const std::vector<Expr*>& variables() const {
return variables_;
}
HashProvider& hasher() const {
@ -192,16 +192,16 @@ class Term : public ExprNode<Term> {
SimplifierHashType hashVars() const;

private:
std::vector<const Expr*> variables_;
const Expr* scalar_;
std::vector<Expr*> variables_;
Expr* scalar_;
HashProvider& hasher_;

void addComponent() {}
void addComponent(const Expr* e) {
void addComponent(Expr* e) {
variables_.push_back(e);
}
template <class... Es>
void addComponent(const Expr* e, Es... es) {
void addComponent(Expr* e, Es... es) {
addComponent(e);
addComponent(es...);
}
@ -217,7 +217,7 @@ class Polynomial : public ExprNode<Polynomial> {
public:
template <class... Args>
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
Polynomial(HashProvider& hasher, const Expr* s, Args... ts)
Polynomial(HashProvider& hasher, Expr* s, Args... ts)
: ExprNodeBase(promoteTypesVar(s, ts...)), scalar_(s), hasher_(hasher) {
CHECK(s->isConstant());
addTerm(ts...);
@ -225,7 +225,7 @@ class Polynomial : public ExprNode<Polynomial> {
}

// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
Polynomial(HashProvider& hasher, const Expr* s, std::vector<const Term*> v)
Polynomial(HashProvider& hasher, Expr* s, std::vector<Term*> v)
: ExprNodeBase(promoteTypesVec(s, v)),
variables_(std::move(v)),
scalar_(s),
@ -235,7 +235,7 @@ class Polynomial : public ExprNode<Polynomial> {

// Helper constructor for list of terms with no scalar component.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
Polynomial(HashProvider& hasher, std::vector<const Term*> terms)
Polynomial(HashProvider& hasher, std::vector<Term*> terms)
: ExprNodeBase(promoteTypesVec(terms)),
variables_(std::move(terms)),
scalar_(getImmediateByType(dtype(), 0)),
@ -248,8 +248,8 @@ class Polynomial : public ExprNode<Polynomial> {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
Polynomial(
HashProvider& hasher,
const Expr* s,
std::unordered_map<SimplifierHashType, const Term*> varmap)
Expr* s,
std::unordered_map<SimplifierHashType, Term*> varmap)
: ExprNodeBase(promoteTypesMap(s, varmap)), scalar_(s), hasher_(hasher) {
for (auto& p : varmap) {
addTerm(p.second);
@ -257,10 +257,10 @@ class Polynomial : public ExprNode<Polynomial> {
sort();
}

const Expr* scalar() const {
Expr* scalar() const {
return scalar_;
}
const std::vector<const Term*>& variables() const {
const std::vector<Term*>& variables() const {
return variables_;
}
HashProvider& hasher() const {
@ -270,15 +270,15 @@ class Polynomial : public ExprNode<Polynomial> {
SimplifierHashType hashVars() const;

private:
std::vector<const Term*> variables_;
const Expr* scalar_;
std::vector<Term*> variables_;
Expr* scalar_;
HashProvider& hasher_;

void addTerm(const Term* t) {
void addTerm(Term* t) {
variables_.push_back(t);
}
template <class... Ts>
void addTerm(const Term* t, Ts... ts) {
void addTerm(Term* t, Ts... ts) {
addTerm(t);
addTerm(ts...);
}
@ -289,15 +289,14 @@ class Polynomial : public ExprNode<Polynomial> {

class RoundOff : public BinaryOpNode<RoundOff> {
public:
RoundOff(const Expr* lhs, const Expr* rhs)
: BinaryOpNode(lhs, rhs, IRNodeType::kOther) {}
RoundOff(Expr* lhs, Expr* rhs) : BinaryOpNode(lhs, rhs, IRNodeType::kOther) {}
};

class MaxTerm : public ExprNode<MaxTerm> {
public:
template <class... Args>
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
MaxTerm(HashProvider& hasher, const Expr* s, bool p, Args... ts)
MaxTerm(HashProvider& hasher, Expr* s, bool p, Args... ts)
: ExprNodeBase(s ? promoteTypesVar(s, ts...) : promoteTypesVar(ts...)),
scalar_(s),
hasher_(hasher),
@ -307,11 +306,7 @@ class MaxTerm : public ExprNode<MaxTerm> {
}

// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
MaxTerm(
HashProvider& hasher,
const Expr* s,
bool p,
std::vector<const Expr*> v)
MaxTerm(HashProvider& hasher, Expr* s, bool p, std::vector<Expr*> v)
: ExprNodeBase(s ? promoteTypesVec(s, v) : promoteTypesVec(v)),
variables_(std::move(v)),
scalar_(s),
@ -324,10 +319,10 @@ class MaxTerm : public ExprNode<MaxTerm> {
return propagate_nans_;
}

const Expr* scalar() const {
Expr* scalar() const {
return scalar_;
}
const std::vector<const Expr*>& variables() const {
const std::vector<Expr*>& variables() const {
return variables_;
}
HashProvider& hasher() const {
@ -335,17 +330,17 @@ class MaxTerm : public ExprNode<MaxTerm> {
}

private:
std::vector<const Expr*> variables_;
const Expr* scalar_;
std::vector<Expr*> variables_;
Expr* scalar_;
HashProvider& hasher_;
bool propagate_nans_;

void addComponent() {}
void addComponent(const Expr* e) {
void addComponent(Expr* e) {
variables_.push_back(e);
}
template <class... Es>
void addComponent(const Expr* e, Es... es) {
void addComponent(Expr* e, Es... es) {
addComponent(e);
addComponent(es...);
}
@ -358,7 +353,7 @@ class MinTerm : public ExprNode<MinTerm> {
public:
template <class... Args>
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
MinTerm(HashProvider& hasher, const Expr* s, bool p, Args... ts)
MinTerm(HashProvider& hasher, Expr* s, bool p, Args... ts)
: ExprNodeBase(s ? promoteTypesVar(s, ts...) : promoteTypesVar(ts...)),
scalar_(s),
hasher_(hasher),
@ -368,11 +363,7 @@ class MinTerm : public ExprNode<MinTerm> {
}

// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
MinTerm(
HashProvider& hasher,
const Expr* s,
bool p,
std::vector<const Expr*> v)
MinTerm(HashProvider& hasher, Expr* s, bool p, std::vector<Expr*> v)
: ExprNodeBase(s ? promoteTypesVec(s, v) : promoteTypesVec(v)),
variables_(std::move(v)),
scalar_(s),
@ -385,10 +376,10 @@ class MinTerm : public ExprNode<MinTerm> {
return propagate_nans_;
}

const Expr* scalar() const {
Expr* scalar() const {
return scalar_;
}
const std::vector<const Expr*>& variables() const {
const std::vector<Expr*>& variables() const {
return variables_;
}
HashProvider& hasher() const {
@ -396,17 +387,17 @@ class MinTerm : public ExprNode<MinTerm> {
}

private:
std::vector<const Expr*> variables_;
const Expr* scalar_;
std::vector<Expr*> variables_;
Expr* scalar_;
HashProvider& hasher_;
bool propagate_nans_;

void addComponent() {}
void addComponent(const Expr* e) {
void addComponent(Expr* e) {
variables_.push_back(e);
}
template <class... Es>
void addComponent(const Expr* e, Es... es) {
void addComponent(Expr* e, Es... es) {
addComponent(e);
addComponent(es...);
}
@ -416,16 +407,15 @@ class MinTerm : public ExprNode<MinTerm> {
};

// Context-sensitive IR simplification
using VarBoundInfo =
std::unordered_map<const Var*, std::pair<const Expr*, const Expr*>>;
using VarBoundInfo = std::unordered_map<Var*, std::pair<Expr*, Expr*>>;
class TORCH_API SimplifierUnderContext : public IRMutator {
public:
~SimplifierUnderContext() override = default;
// Add boundary info for index variables in for-loops
Stmt* mutate(const For* v) override;
Stmt* mutate(For* v) override;

const Expr* mutate(const Div* v) override;
const Expr* mutate(const Mod* v) override;
Expr* mutate(Div* v) override;
Expr* mutate(Mod* v) override;

protected:
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
@ -438,14 +428,14 @@ class TORCH_API PolynomialBase : public IRMutator {
public:
~PolynomialBase() override = default;

Stmt* mutate(const Block* v) override;
Stmt* mutate(Block* v) override;

Stmt* mutate(const Cond* v) override;
Stmt* mutate(Cond* v) override;

Stmt* mutate(const For* v) override;
Stmt* mutate(For* v) override;

// Trivially factorize terms by GCD of scalar components.
const Term* factorizePolynomial(const Polynomial* poly);
Term* factorizePolynomial(Polynomial* poly);

HashProvider& hasher() {
return hasher_;
@ -463,89 +453,89 @@ class TORCH_API PolynomialTransformer : public PolynomialBase {
// Inserts term into the provided map, in the case of a hash collision
// combines the term with the existing and updates the map.
void addOrUpdateTerm(
std::unordered_map<SimplifierHashType, const Term*>& varmap,
const Term* term);
std::unordered_map<SimplifierHashType, Term*>& varmap,
Term* term);

// Add Polynomial expressions, combining Terms representing the same
// variables.
const Expr* addPolynomials(const Polynomial* lhs, const Polynomial* rhs);
Expr* addPolynomials(Polynomial* lhs, Polynomial* rhs);

// Insert a new Term into the provided polynomial. If the new term has common
// variables to an existing term it is combined.
const Expr* insertTerm(const Polynomial* poly, const Term* term);
Expr* insertTerm(Polynomial* poly, Term* term);

// Merge and simplify addition.
const Expr* mutate(const Add* v) override;
Expr* mutate(Add* v) override;

// Subtract one term from another, cancelling if necessary.
const Expr* subTerms(const Term* lhs, const Term* rhs, bool negated);
Expr* subTerms(Term* lhs, Term* rhs, bool negated);

// Subtract the RHS Polynomial from the LHS Polynomial, cancelling out where
// possible.
const Expr* subPolynomials(const Polynomial* lhs, const Polynomial* rhs);
Expr* subPolynomials(Polynomial* lhs, Polynomial* rhs);

// Merge and simplify subtraction.
const Expr* mutate(const Sub* v) override;
Expr* mutate(Sub* v) override;

// Multiply two terms together, usually creating a new term with the variable
// lists concatenated.
const Term* mulTerms(const Term* lhs, const Term* rhs);
Term* mulTerms(Term* lhs, Term* rhs);

// Multiply a Polynomial by a Term.
const Expr* polyByTerm(const Polynomial* poly, const Term* term);
Expr* polyByTerm(Polynomial* poly, Term* term);

// Match a rounding pattern and create a RoundOff if found.
const Expr* isRoundOff(const Expr* lhs, const Expr* rhs);
Expr* isRoundOff(Expr* lhs, Expr* rhs);

// Inserts a new component into a term, simplifying if possible.
const Expr* insertIntoTerm(const Term* term, const Expr* expr);
Expr* insertIntoTerm(Term* term, Expr* expr);

// Merge and simplify multiplication.
const Expr* mutate(const Mul* v) override;
Expr* mutate(Mul* v) override;

const Expr* mutate(const Div* v) override;
Expr* mutate(Div* v) override;

const Expr* mutate(const Mod* v) override;
Expr* mutate(Mod* v) override;

const Expr* mutate(const And* v) override {
Expr* mutate(And* v) override {
return mutateBinaryOp(v, this);
}

const Expr* mutate(const Xor* v) override {
Expr* mutate(Xor* v) override {
return mutateBinaryOp(v, this);
}

const Expr* mutate(const Lshift* v) override {
Expr* mutate(Lshift* v) override {
return mutateBinaryOp(v, this);
}

const Expr* mutate(const Rshift* v) override {
Expr* mutate(Rshift* v) override {
return mutateBinaryOp(v, this);
}

const Expr* mutate(const Max* v) override;
Expr* mutate(Max* v) override;

const Expr* mutate(const Min* v) override;
Expr* mutate(Min* v) override;

const Expr* mutate(const CompareSelect* v) override;
Expr* mutate(CompareSelect* v) override;

const Expr* mutate(const Intrinsics* v) override;
Expr* mutate(Intrinsics* v) override;

const Expr* mutate(const Cast* v) override;
Expr* mutate(Cast* v) override;

const Expr* mutate(const IfThenElse* v) override;
Expr* mutate(IfThenElse* v) override;

template <typename Op>
static const Expr* mutateBinaryOp(
const BinaryOpNode<Op>* v,
static Expr* mutateBinaryOp(
BinaryOpNode<Op>* v,
IRMutator* mutator,
bool option = false) {
const Expr* lhs = v->lhs();
const Expr* rhs = v->rhs();
const Expr* lhs_new = lhs->accept_mutator(mutator);
const Expr* rhs_new = rhs->accept_mutator(mutator);
Expr* lhs = v->lhs();
Expr* rhs = v->rhs();
Expr* lhs_new = lhs->accept_mutator(mutator);
Expr* rhs_new = rhs->accept_mutator(mutator);

const Expr* node = v;
Expr* node = v;

if (lhs != lhs_new || rhs != rhs_new) {
node = newBinaryOpOfType(v->expr_type(), lhs_new, rhs_new, option);
@ -559,7 +549,7 @@ class TORCH_API PolynomialTransformer : public PolynomialBase {
return evaluateOp(node);
}

static const Expr* simplify(const Expr* e);
static Expr* simplify(Expr* e);
static ExprHandle simplify(const ExprHandle& e);
static Stmt* simplify(Stmt* e);
};
@ -568,7 +558,7 @@ class TORCH_API PolynomialTransformer : public PolynomialBase {
// Does some simple factorization and reordering.
class TORCH_API TermExpander : public PolynomialBase {
PolynomialTransformer* simplifier_;
std::set<const Var*> eliminated_allocations_;
std::set<Var*> eliminated_allocations_;

public:
using PolynomialBase::mutate;
@ -579,33 +569,33 @@ class TORCH_API TermExpander : public PolynomialBase {
}

// Expand Terms out to a series of Muls.
const Expr* mutate(const Term* v) override;
Expr* mutate(Term* v) override;

// Expand Polynomials out to a series of Adds.
const Expr* mutate(const Polynomial* v) override;
Expr* mutate(Polynomial* v) override;

// Expand MaxTerms to a series of Max ops.
const Expr* mutate(const MaxTerm* v) override;
Expr* mutate(MaxTerm* v) override;

// Expand MinTerms to a series of Min ops.
const Expr* mutate(const MinTerm* v) override;
Expr* mutate(MinTerm* v) override;

// Expand RoundOff to its component: Mul(Div(lhs, rhs), rhs).
const Expr* mutate(const RoundOff* v) override;
Expr* mutate(RoundOff* v) override;

// Eliminate zero length allocations.
Stmt* mutate(const Allocate* v) override;
Stmt* mutate(const Free* v) override;
Stmt* mutate(Allocate* v) override;
Stmt* mutate(Free* v) override;

// Override to enable condition fusing.
Block* fuseConditions(Block* v);
Stmt* fuseSyncThreads(Block* block);
Stmt* mutate(const Block* v) override;
Stmt* mutate(Block* v) override;
};

class TORCH_API IRSimplifier {
public:
static const Expr* simplify(const Expr* e) {
static Expr* simplify(Expr* e) {
SimplifierUnderContext ctxsimplifier;
e = e->accept_mutator(&ctxsimplifier);

@ -649,9 +639,9 @@ class TORCH_API IRSimplifier {
};

// Flattens the buf and performs the simplifier on the flattened dims.
const Expr* buf_flat_size(const Buf* v);
Expr* buf_flat_size(Buf* v);
// Returns true if expressions A and B can be simplified to an equal expression.
TORCH_API bool exprEquals(const Expr* A, const Expr* B);
TORCH_API bool exprEquals(Expr* A, Expr* B);

} // namespace tensorexpr
} // namespace jit
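The Term/Polynomial machinery declared above is what lets the simplifier canonicalize sums of scaled variables, e.g. folding x * 2 + x * 3 into a single term with scalar 5. A hedged usage sketch against the IRSimplifier entry point; it assumes the usual tensorexpr headers from this repo, and the ExprHandle overload of simplify is an assumption (only the Expr* form appears in this hunk):

// Illustrative only; simplify_example and its expected output are not part of
// this commit.
void simplify_example() {
  KernelScope kernel_scope;  // IR nodes are arena-allocated at this point in time
  VarHandle x("x", kInt);
  ExprHandle e = x * 2 + x * 3;
  ExprHandle s = IRSimplifier::simplify(e);  // canonicalizes via Term/Polynomial
  // Expected canonical form: a single Term with scalar 5, printed as x * 5.
}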
@ -19,39 +19,39 @@ void verifyBitwiseOp(const BitwiseOpNode<Op>* v, IRVerifier* verifier) {
}
}

void IRVerifier::visit(const And* v) {
void IRVerifier::visit(And* v) {
verifyBitwiseOp(v, this);
IRVisitor::visit(v);
}

void IRVerifier::visit(const Or* v) {
void IRVerifier::visit(Or* v) {
verifyBitwiseOp(v, this);
IRVisitor::visit(v);
}

void IRVerifier::visit(const Xor* v) {
void IRVerifier::visit(Xor* v) {
verifyBitwiseOp(v, this);
IRVisitor::visit(v);
}

void IRVerifier::visit(const Lshift* v) {
void IRVerifier::visit(Lshift* v) {
verifyBitwiseOp(v, this);
IRVisitor::visit(v);
}

void IRVerifier::visit(const Rshift* v) {
void IRVerifier::visit(Rshift* v) {
verifyBitwiseOp(v, this);
IRVisitor::visit(v);
}

void IRVerifier::visit(const Mod* v) {
void IRVerifier::visit(Mod* v) {
if (!v->dtype().is_integral() && !v->dtype().is_floating_point()) {
throw std::runtime_error("invalid dtype: " + std::to_string(v->dtype()));
}
IRVisitor::visit(v);
}

void IRVerifier::visit(const CompareSelect* v) {
void IRVerifier::visit(CompareSelect* v) {
if (v->ret_val1()->dtype() != v->ret_val2()->dtype()) {
throw malformed_ir("bad dtype in CompareSelect");
}
@ -61,15 +61,15 @@ void IRVerifier::visit(const CompareSelect* v) {
IRVisitor::visit(v);
}

void IRVerifier::visit(const Ramp* v) {
void IRVerifier::visit(Ramp* v) {
if (v->stride()->dtype() != v->base()->dtype()) {
throw malformed_ir("Bad stride in Ramp");
}
IRVisitor::visit(v);
}

void IRVerifier::visit(const Load* v) {
const auto indices = v->indices();
void IRVerifier::visit(Load* v) {
auto indices = v->indices();
if (indices.size() > 0 && v->buf()->base_handle()->dtype() != kHandle) {
throw malformed_ir(
"Load base handle dtype must be Handle", v->buf()->base_handle());
@ -94,7 +94,7 @@ void IRVerifier::visit(const Load* v) {
IRVisitor::visit(v);
}

void IRVerifier::visit(const IfThenElse* v) {
void IRVerifier::visit(IfThenElse* v) {
if (!v->condition()->dtype().is_integral()) {
throw unsupported_dtype();
}
@ -107,13 +107,13 @@ void IRVerifier::visit(const IfThenElse* v) {
IRVisitor::visit(v);
}

void IRVerifier::visit(const Intrinsics* v) {
void IRVerifier::visit(Intrinsics* v) {
// TODO: add a check for OpArgCount and op_type
IRVisitor::visit(v);
}

void IRVerifier::visit(const Store* v) {
const auto indices = v->indices();
void IRVerifier::visit(Store* v) {
auto indices = v->indices();
if (indices.size() > 0 && v->buf()->base_handle()->dtype() != kHandle) {
throw malformed_ir(
"Store base handle dtype must be Handle", v->buf()->base_handle());
@ -141,7 +141,7 @@ void IRVerifier::visit(const Store* v) {
IRVisitor::visit(v);
}

void IRVerifier::visit(const For* v) {
void IRVerifier::visit(For* v) {
if (!v->var()) {
throw malformed_ir("nullptr Var in For loop");
} else if (!v->start()) {
@ -154,7 +154,7 @@ void IRVerifier::visit(const For* v) {
IRVisitor::visit(v);
}

void IRVerifier::visit(const Block* v) {
void IRVerifier::visit(Block* v) {
for (Stmt* s : v->stmts()) {
if (s->get_parent() != v) {
throw malformed_ir("Broken child-parent link inside a Block");
@ -163,7 +163,7 @@ void IRVerifier::visit(const Block* v) {
IRVisitor::visit(v);
}

void IRVerifier::visit(const ExternalCall* v) {
void IRVerifier::visit(ExternalCall* v) {
IRVisitor::visit(v);
}

@ -172,7 +172,7 @@ void verify(Stmt* s) {
s->accept(&verifier);
}

void verify(const Expr* e) {
void verify(Expr* e) {
IRVerifier verifier;
e->accept(&verifier);
}
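A hedged sketch of how the verify() entry points above are meant to be used: build a deliberately malformed node and let the verifier throw. The stride check it trips is the Ramp rule shown earlier in this file; the Ramp(base, stride, lanes) constructor and immediate types are assumed from the rest of this library:

// Illustrative only; verifier_example is not part of this commit.
void verifier_example() {
  KernelScope kernel_scope;
  Expr* base = new IntImm(0);
  Expr* stride = new FloatImm(1.0f);           // dtype differs from base on purpose
  Expr* bad_ramp = new Ramp(base, stride, 4);  // lanes = 4
  verify(bad_ramp);  // expected to throw malformed_ir("Bad stride in Ramp")
}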
@ -32,26 +32,26 @@ class TORCH_API IRVerifier : public IRVisitor {
public:
IRVerifier() = default;

void visit(const Mod* v) override;
void visit(const And* v) override;
void visit(const Or* v) override;
void visit(const Xor* v) override;
void visit(const Lshift* v) override;
void visit(const Rshift* v) override;
void visit(const CompareSelect* v) override;
void visit(const Ramp* v) override;
void visit(const Load* v) override;
void visit(const IfThenElse* v) override;
void visit(const Intrinsics* v) override;
void visit(Mod* v) override;
void visit(And* v) override;
void visit(Or* v) override;
void visit(Xor* v) override;
void visit(Lshift* v) override;
void visit(Rshift* v) override;
void visit(CompareSelect* v) override;
void visit(Ramp* v) override;
void visit(Load* v) override;
void visit(IfThenElse* v) override;
void visit(Intrinsics* v) override;

void visit(const ExternalCall* v) override;
void visit(const Store* v) override;
void visit(const For* v) override;
void visit(const Block* v) override;
void visit(ExternalCall* v) override;
void visit(Store* v) override;
void visit(For* v) override;
void visit(Block* v) override;
};

TORCH_API void verify(Stmt*);
TORCH_API void verify(const Expr*);
TORCH_API void verify(Expr*);
TORCH_API void verify(ExprHandle);

} // namespace tensorexpr
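Since IRVisitor only traverses the tree, dropping const from its signatures changes no behavior; read-only analyses subclass it exactly as before, just with mutable parameters. A small hypothetical pass sketched against the post-change header, counting Var references in a statement:

// Illustrative only; assumes the tensorexpr headers from this repo.
class VarCounter : public IRVisitor {
 public:
  int count = 0;
  void visit(Var* v) override {
    count++;  // Var is a leaf, so there is nothing further to recurse into
  }
};

// Usage sketch: VarCounter c; stmt->accept(&c); then read c.count.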
@ -12,60 +12,60 @@ namespace jit {
namespace tensorexpr {

template <typename Op>
static void visit_binary_op(const BinaryOpNode<Op>* v, IRVisitor* visitor) {
static void visit_binary_op(BinaryOpNode<Op>* v, IRVisitor* visitor) {
v->lhs()->accept(visitor);
v->rhs()->accept(visitor);
}

void IRVisitor::visit(const Add* v) {
void IRVisitor::visit(Add* v) {
visit_binary_op(v, this);
}

void IRVisitor::visit(const Sub* v) {
void IRVisitor::visit(Sub* v) {
visit_binary_op(v, this);
}

void IRVisitor::visit(const Mul* v) {
void IRVisitor::visit(Mul* v) {
visit_binary_op(v, this);
}

void IRVisitor::visit(const Div* v) {
void IRVisitor::visit(Div* v) {
visit_binary_op(v, this);
}

void IRVisitor::visit(const Mod* v) {
void IRVisitor::visit(Mod* v) {
visit_binary_op(v, this);
}

void IRVisitor::visit(const Max* v) {
void IRVisitor::visit(Max* v) {
visit_binary_op(v, this);
}

void IRVisitor::visit(const Min* v) {
void IRVisitor::visit(Min* v) {
visit_binary_op(v, this);
}

void IRVisitor::visit(const And* v) {
void IRVisitor::visit(And* v) {
visit_binary_op(v, this);
}

void IRVisitor::visit(const Or* v) {
void IRVisitor::visit(Or* v) {
visit_binary_op(v, this);
}

void IRVisitor::visit(const Xor* v) {
void IRVisitor::visit(Xor* v) {
visit_binary_op(v, this);
}

void IRVisitor::visit(const Lshift* v) {
void IRVisitor::visit(Lshift* v) {
visit_binary_op(v, this);
}

void IRVisitor::visit(const Rshift* v) {
void IRVisitor::visit(Rshift* v) {
visit_binary_op(v, this);
}

void IRVisitor::visit(const CompareSelect* v) {
void IRVisitor::visit(CompareSelect* v) {
v->lhs()->accept(this);
v->rhs()->accept(this);
v->ret_val1()->accept(this);
@ -78,65 +78,65 @@ void IRVisitor::visit(const CompareSelect* v) {
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_VISIT);
#undef IMM_VISIT

void IRVisitor::visit(const Cast* v) {
void IRVisitor::visit(Cast* v) {
v->src_value()->accept(this);
}
void IRVisitor::visit(const BitCast* v) {
void IRVisitor::visit(BitCast* v) {
v->src_value()->accept(this);
}
void IRVisitor::visit(const Var* v) {}
void IRVisitor::visit(Var* v) {}

void IRVisitor::visit(const Ramp* v) {
void IRVisitor::visit(Ramp* v) {
v->base()->accept(this);
v->stride()->accept(this);
}

void IRVisitor::visit(const Load* v) {
void IRVisitor::visit(Load* v) {
v->buf()->accept(this);
for (const Expr* ind : v->indices()) {
for (Expr* ind : v->indices()) {
ind->accept(this);
}
}

void IRVisitor::visit(const Buf* v) {
void IRVisitor::visit(Buf* v) {
v->base_handle()->accept(this);
}

void IRVisitor::visit(const Store* v) {
void IRVisitor::visit(Store* v) {
v->buf()->accept(this);
for (const Expr* ind : v->indices()) {
for (Expr* ind : v->indices()) {
ind->accept(this);
}
v->value()->accept(this);
}

void IRVisitor::visit(const AtomicAdd* v) {
void IRVisitor::visit(AtomicAdd* v) {
v->buf()->accept(this);
for (const Expr* ind : v->indices()) {
for (Expr* ind : v->indices()) {
ind->accept(this);
}
v->value()->accept(this);
}

void IRVisitor::visit(const SyncThreads* v) {}
void IRVisitor::visit(SyncThreads* v) {}

void IRVisitor::visit(const ExternalCall* v) {
void IRVisitor::visit(ExternalCall* v) {
v->buf()->accept(this);
for (const Buf* buf_arg : v->buf_args()) {
for (Buf* buf_arg : v->buf_args()) {
buf_arg->accept(this);
}
for (const Expr* arg : v->args()) {
for (Expr* arg : v->args()) {
arg->accept(this);
}
}

void IRVisitor::visit(const Block* v) {
void IRVisitor::visit(Block* v) {
for (Stmt* s : *v) {
s->accept(this);
}
}

void IRVisitor::visit(const For* v) {
void IRVisitor::visit(For* v) {
v->var()->accept(this);
v->start()->accept(this);
v->stop()->accept(this);
@ -145,41 +145,41 @@ void IRVisitor::visit(const For* v) {
}
}

void IRVisitor::visit(const Broadcast* v) {
void IRVisitor::visit(Broadcast* v) {
v->value()->accept(this);
}

void IRVisitor::visit(const IfThenElse* v) {
void IRVisitor::visit(IfThenElse* v) {
v->condition()->accept(this);
v->true_value()->accept(this);
v->false_value()->accept(this);
}

void IRVisitor::visit(const Intrinsics* v) {
for (const auto i : c10::irange(v->nparams())) {
void IRVisitor::visit(Intrinsics* v) {
for (auto i : c10::irange(v->nparams())) {
v->param(i)->accept(this);
}
}

void IRVisitor::visit(const Allocate* v) {
void IRVisitor::visit(Allocate* v) {
v->buffer_var()->accept(this);
std::vector<const Expr*> dims = v->dims();
for (const Expr* dim : dims) {
std::vector<Expr*> dims = v->dims();
for (Expr* dim : dims) {
dim->accept(this);
}
}

void IRVisitor::visit(const Free* v) {
void IRVisitor::visit(Free* v) {
v->buffer_var()->accept(this);
}

void IRVisitor::visit(const Let* v) {
void IRVisitor::visit(Let* v) {
v->var()->accept(this);
v->value()->accept(this);
}

void IRVisitor::visit(const Cond* v) {
const Expr* condition = v->condition();
void IRVisitor::visit(Cond* v) {
Expr* condition = v->condition();
Stmt* true_stmt = v->true_stmt();
Stmt* false_stmt = v->false_stmt();
condition->accept(this);
@ -191,26 +191,26 @@ void IRVisitor::visit(const Cond* v) {
}
}

void IRVisitor::visit(const Term* v) {
void IRVisitor::visit(Term* v) {
v->scalar()->accept(this);
for (auto* t : v->variables()) {
t->accept(this);
}
}

void IRVisitor::visit(const Polynomial* v) {
void IRVisitor::visit(Polynomial* v) {
v->scalar()->accept(this);
for (auto* t : v->variables()) {
t->accept(this);
}
}

void IRVisitor::visit(const RoundOff* v) {
void IRVisitor::visit(RoundOff* v) {
v->lhs()->accept(this);
v->rhs()->accept(this);
}

void IRVisitor::visit(const MaxTerm* v) {
|
||||
void IRVisitor::visit(MaxTerm* v) {
|
||||
if (v->scalar()) {
|
||||
v->scalar()->accept(this);
|
||||
}
|
||||
@ -219,7 +219,7 @@ void IRVisitor::visit(const MaxTerm* v) {
|
||||
}
|
||||
}
|
||||
|
||||
void IRVisitor::visit(const MinTerm* v) {
|
||||
void IRVisitor::visit(MinTerm* v) {
|
||||
if (v->scalar()) {
|
||||
v->scalar()->accept(this);
|
||||
}
|
||||
@ -228,7 +228,7 @@ void IRVisitor::visit(const MinTerm* v) {
|
||||
}
|
||||
}
|
||||
|
||||
void IRVisitor::visit(const ReduceOp* v) {
|
||||
void IRVisitor::visit(ReduceOp* v) {
|
||||
v->body()->accept(this);
|
||||
|
||||
for (auto* r : v->reduce_args()) {
|
||||
|
@ -54,50 +54,50 @@ class ExternalCall;
class TORCH_API IRVisitor {
public:
virtual ~IRVisitor() = default;
virtual void visit(const Add* v);
virtual void visit(const Sub* v);
virtual void visit(const Mul* v);
virtual void visit(const Div* v);
virtual void visit(const Mod* v);
virtual void visit(const Max* v);
virtual void visit(const Min* v);
virtual void visit(const And* v);
virtual void visit(const Or* v);
virtual void visit(const Xor* v);
virtual void visit(const Lshift* v);
virtual void visit(const Rshift* v);
virtual void visit(const CompareSelect* v);
virtual void visit(Add* v);
virtual void visit(Sub* v);
virtual void visit(Mul* v);
virtual void visit(Div* v);
virtual void visit(Mod* v);
virtual void visit(Max* v);
virtual void visit(Min* v);
virtual void visit(And* v);
virtual void visit(Or* v);
virtual void visit(Xor* v);
virtual void visit(Lshift* v);
virtual void visit(Rshift* v);
virtual void visit(CompareSelect* v);

#define IMM_PRINT_VISIT(Type, Name) virtual void visit(const Name##Imm* v);

AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_PRINT_VISIT)
#undef IMM_PRINT_VISIT

virtual void visit(const Cast* v);
virtual void visit(const BitCast* v);
virtual void visit(const Var* v);
virtual void visit(const Buf* v);
virtual void visit(const Ramp* v);
virtual void visit(const Load* v);
virtual void visit(const For* v);
virtual void visit(const Block* v);
virtual void visit(const Store* v);
virtual void visit(const Broadcast* v);
virtual void visit(const IfThenElse* v);
virtual void visit(const Intrinsics* v);
virtual void visit(const Allocate* v);
virtual void visit(const Free* v);
virtual void visit(const Let* v);
virtual void visit(const Cond* v);
virtual void visit(const Term* v);
virtual void visit(const Polynomial* v);
virtual void visit(const RoundOff* v);
virtual void visit(const MaxTerm* v);
virtual void visit(const MinTerm* v);
virtual void visit(const ReduceOp* v);
virtual void visit(const AtomicAdd* v);
virtual void visit(const SyncThreads* v);
virtual void visit(const ExternalCall* v);
virtual void visit(Cast* v);
virtual void visit(BitCast* v);
virtual void visit(Var* v);
virtual void visit(Buf* v);
virtual void visit(Ramp* v);
virtual void visit(Load* v);
virtual void visit(For* v);
virtual void visit(Block* v);
virtual void visit(Store* v);
virtual void visit(Broadcast* v);
virtual void visit(IfThenElse* v);
virtual void visit(Intrinsics* v);
virtual void visit(Allocate* v);
virtual void visit(Free* v);
virtual void visit(Let* v);
virtual void visit(Cond* v);
virtual void visit(Term* v);
virtual void visit(Polynomial* v);
virtual void visit(RoundOff* v);
virtual void visit(MaxTerm* v);
virtual void visit(MinTerm* v);
virtual void visit(ReduceOp* v);
virtual void visit(AtomicAdd* v);
virtual void visit(SyncThreads* v);
virtual void visit(ExternalCall* v);
};

} // namespace tensorexpr
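The practical effect of the header change above: any downstream visitor must drop const from its overrides, or those methods silently stop overriding the base class. A minimal sketch assuming the new signatures (the LoadCounter class and include path are illustrative, not part of this commit):

#include <torch/csrc/jit/tensorexpr/ir_visitor.h>

using namespace torch::jit::tensorexpr;

// Counts Load nodes in a kernel; note the non-const parameter.
class LoadCounter : public IRVisitor {
 public:
  int count = 0;
  // Before this change: void visit(const Load* v) override;
  void visit(Load* v) override {
    count++;
    IRVisitor::visit(v); // base impl still recurses into buf and indices
  }
};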
@ -201,7 +201,7 @@ c10::optional<TensorInfo> getTensorInfoJit(torch::jit::Value* v) {
c10::optional<TensorInfo> getTensorInfo(BufHandle b) {
std::vector<int64_t> dims;
for (auto dim : b.dims()) {
auto val = dynamic_cast<const IntImm*>(dim.node());
auto val = dynamic_cast<IntImm*>(dim.node());
if (!val) {
return c10::nullopt;
}
@ -460,7 +460,7 @@ void promoteInputs(std::vector<ExprHandle>& inputs, const int typeConstraints) {

// Find the highest type among the inputs.
ScalarType highType = inputs[0].dtype().scalar_type();
for (const auto input : inputs) {
for (auto input : inputs) {
highType = promoteTypes(highType, input.dtype().scalar_type());
}

@ -503,20 +503,20 @@ ExprHandle demoteOutput(
} // namespace jit
} // namespace torch

static at::ScalarType tensorType(const Buf* b) {
static at::ScalarType tensorType(Buf* b) {
return static_cast<at::ScalarType>(b->dtype().scalar_type());
}

std::vector<int64_t> bufferSizes(const Buf* b) {
std::vector<int64_t> bufferSizes(Buf* b) {
std::vector<int64_t> sizes;
for (size_t i = 0; i < b->ndim(); i++) {
sizes.push_back(dynamic_cast<const IntImm*>(b->dim(i))->value());
sizes.push_back(dynamic_cast<IntImm*>(b->dim(i))->value());
}
return sizes;
}

ExprHandle TensorExprKernel::chunk(
const Buf* b,
Buf* b,
size_t chunkIdx,
int64_t dim,
int64_t chunks,
@ -539,7 +539,7 @@ ExprHandle TensorExprKernel::chunk(

ExprHandle TensorExprKernel::constant(const torch::jit::Value* v) {
if (v->node()->kind() == prim::Constant) {
const auto val = toIValue(v).value();
auto val = toIValue(v).value();
if (val.isDouble()) {
return DoubleImm::make(val.toDouble());
} else if (val.isInt()) {
@ -598,7 +598,7 @@ ArgValue TensorExprKernel::toArg(const torch::jit::Value* v) const {
throw unsupported_dtype();
}
if (v->node()->kind() == prim::Constant) {
const auto val = toIValue(v).value();
auto val = toIValue(v).value();
if (val.isDouble()) {
return val.toDouble();
} else if (val.isInt()) {
@ -626,7 +626,7 @@ ArgValue TensorExprKernel::toArg(const torch::jit::Value* v) const {
std::vector<ExprHandle> TensorExprKernel::sizesFromVaryingShape(
const c10::VaryingShape<int64_t>& shape) {
std::vector<ExprHandle> dims;
for (const auto i : c10::irange(*shape.size())) {
for (auto i : c10::irange(*shape.size())) {
dims.push_back(IntImm::make(*shape[i]));
}
return dims;
@ -728,7 +728,7 @@ std::vector<ExprHandle> TensorExprKernel::inferSizesForValue(
case aten::remainder:
case aten::atan2: {
std::vector<std::vector<ExprHandle>> shapes;
for (const auto idx : c10::irange(2)) {
for (auto idx : c10::irange(2)) {
torch::jit::Value* inp = v->node()->input(idx);
shapes.push_back(sizesForValue(inp));
}
@ -739,7 +739,7 @@ std::vector<ExprHandle> TensorExprKernel::inferSizesForValue(
case aten::threshold:
case aten::where: {
std::vector<std::vector<ExprHandle>> shapes;
for (const auto idx : c10::irange(3)) {
for (auto idx : c10::irange(3)) {
torch::jit::Value* inp = v->node()->input(idx);
shapes.push_back(sizesForValue(inp));
}
@ -748,7 +748,7 @@ std::vector<ExprHandle> TensorExprKernel::inferSizesForValue(

case aten::addcmul: {
std::vector<std::vector<ExprHandle>> shapes;
for (const auto idx : c10::irange(4)) {
for (auto idx : c10::irange(4)) {
torch::jit::Value* inp = v->node()->input(idx);
shapes.push_back(sizesForValue(inp));
}
@ -1129,7 +1129,7 @@ std::pair<ScalarType, std::vector<BufHandle>> processCatList(
nonEmptyInputs.push_back(buf);
}
ScalarType highType = bufInputs[0].dtype().scalar_type();
for (const auto input : bufInputs) {
for (auto input : bufInputs) {
auto maybe_dtype = input.dtype().scalar_type();
highType = promoteTypes(highType, maybe_dtype);
}
@ -1172,11 +1172,11 @@ Tensor* computeCatWoConditionals(

auto gen_code_for_input = [&](const BufHandle& inp,
size_t inp_pos,
const Expr* concat_dim_size,
Expr* concat_dim_size,
const std::vector<ExprHandle>& dims) {
std::vector<Var*> for_vars(dims.size());
std::vector<const Expr*> load_indices(dims.size());
std::vector<const Expr*> store_indices(dims.size());
std::vector<Expr*> load_indices(dims.size());
std::vector<Expr*> store_indices(dims.size());
for (size_t i = 0; i < dims.size(); ++i) {
for_vars[i] = new Var(
"i" + c10::to_string(inp_pos) + "_" + c10::to_string(i), kInt);
@ -1256,8 +1256,7 @@ Tensor* computeCat(
ExprHandle load = promoteToDtype(
tensorOrConstant(nonEmptyInputs[0], newAxes), highType);
size_t offset =
dynamic_cast<const IntImm*>(nonEmptyInputs[0].node()->dim(dim))
->value();
dynamic_cast<IntImm*>(nonEmptyInputs[0].node()->dim(dim))->value();
newAxes[dim] = newAxes[dim] - IntImm::make(offset);

for (size_t ii = 1; ii < nonEmptyInputs.size(); ++ii) {
@ -1267,8 +1266,7 @@ Tensor* computeCat(
load,
promoteToDtype(tensorOrConstant(input, newAxes), highType));

offset +=
dynamic_cast<const IntImm*>(input.node()->dim(dim))->value();
offset += dynamic_cast<IntImm*>(input.node()->dim(dim))->value();
newAxes[dim] = axes[dim] - IntImm::make(offset);
}

@ -2317,7 +2315,7 @@ Tensor* tensorexpr::computeOperandValue(
*/
// NOLINTNEXTLINE(clang-diagnostic-unused-variable)
ExprHandle cur_stride = 1;
std::vector<const Expr*> dims, indices;
std::vector<Expr*> dims, indices;
for (size_t idx = 0; idx < view_dims.size(); idx++) {
dims.push_back(new IntImm(view_dims[idx]));
indices.push_back(axes[idx].node());
@ -2431,7 +2429,7 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) {
}

// Return the (lower, upper) loop bounds if they are constants, else nullopt.
c10::optional<std::pair<int64_t, int64_t>> loopBounds(const For* loop) {
c10::optional<std::pair<int64_t, int64_t>> loopBounds(For* loop) {
auto start = IRSimplifier::simplify(loop->start());
auto stop = IRSimplifier::simplify(loop->stop());
if (!start->isConstant() || !stop->isConstant()) {
@ -2803,7 +2801,7 @@ bool denseAndNonOverlapping(
Tensor* TensorExprKernel::convertOutputToCorrectStrides(torch::jit::Value* v) {
const TensorTypePtr& tt = v->type()->expect<TensorType>();
TORCH_INTERNAL_ASSERT(bufs_.count(v));
const Buf* buf = bufs_.at(v);
Buf* buf = bufs_.at(v);

// No shape info is present in the graph
if (!tt->sizes().concrete_sizes()) {
@ -2813,7 +2811,7 @@ Tensor* TensorExprKernel::convertOutputToCorrectStrides(torch::jit::Value* v) {
}

TORCH_INTERNAL_ASSERT(tt->sizes().concrete_sizes());
const auto sizes = *tt->sizes().concrete_sizes();
auto sizes = *tt->sizes().concrete_sizes();
std::vector<int64_t> default_strides = TensorType::contiguousStridesOf(sizes);
if (!tt->strides().concrete_sizes()) {
return new Tensor(buf, nullptr);
@ -2887,14 +2885,14 @@ void TensorExprKernel::bindConstant(const torch::jit::Value* v) {
auto const_tensor = toIValue(v)->toTensor();

const auto& tt = v->type()->expect<TensorType>();
const auto sizes = *tt->sizes().concrete_sizes();
auto sizes = *tt->sizes().concrete_sizes();
std::vector<ExprHandle> te_sizes;
te_sizes.reserve(sizes.size());
for (auto s : sizes) {
te_sizes.push_back(IntImm::make(s));
}

const Buf* buf = new Buf(
Buf* buf = new Buf(
"const_" + v->debugName(),
ExprHandleVectorToExprVector(te_sizes),
ToDtype(static_cast<ScalarType>(*tt->scalarType())));
@ -2951,7 +2949,7 @@ void TensorExprKernel::compile() {
}

// Move output operands from `bufs_` to `bufOutputs_`
for (const auto& output : graph_->outputs()) {
for (auto& output : graph_->outputs()) {
if (!bufs_.count(output)) {
throw malformed_input("cannot find output Tensor");
}
@ -3046,7 +3044,7 @@ std::vector<CodeGen::CallArg> TensorExprKernel::prepareRunArgs(
std::vector<CodeGen::CallArg> runArgs;
runArgs.reserve(inputs.size() + bufOutputs_.size());

for (const auto& input : inputs) {
for (auto& input : inputs) {
if (input.isInt()) {
runArgs.emplace_back(input.toInt());
} else if (input.isDouble()) {
@ -19,7 +19,7 @@ template <typename T>
inline std::vector<int64_t> bufferSizes(const T& t) {
std::vector<int64_t> sizes;
for (size_t i = 0; i < t->ndim(); i++) {
sizes.push_back(dynamic_cast<const IntImm*>(t->dim(i))->value());
sizes.push_back(dynamic_cast<IntImm*>(t->dim(i))->value());
}
return sizes;
}
@ -105,7 +105,7 @@ inline std::string getArgValueName(const ArgValue& a) {
template <class T>
std::vector<T> convertVecArgValue(const std::vector<ArgValue>& v) {
std::vector<T> res;
for (const auto& x : v) {
for (auto& x : v) {
auto val = c10::get_if<T>(&x);
if (val) {
res.push_back(*val);
@ -132,7 +132,7 @@ TORCH_API Tensor* computeOperandValue(

class TORCH_API TensorExprKernel {
struct ConstantDescr {
const Buf* buf;
Buf* buf;
void* ptr;
};

@ -196,7 +196,7 @@ class TORCH_API TensorExprKernel {
std::vector<std::vector<ExprHandle>> shapes);

ExprHandle chunk(
const Buf* b,
Buf* b,
size_t chunkIdx,
int64_t dim,
int64_t chunks,
@ -260,8 +260,8 @@ class TORCH_API TensorExprKernel {
std::vector<std::vector<int64_t>> tensorOutputSizes_;
std::vector<std::vector<int64_t>> tensorOutputStrides_;
std::vector<UnpackedTensorOptions> tensorOutputTensorOptions_;
std::unordered_set<const Buf*> bufOutputs_;
std::unordered_map<const torch::jit::Value*, const Buf*> bufs_;
std::unordered_set<Buf*> bufOutputs_;
std::unordered_map<const torch::jit::Value*, Buf*> bufs_;
std::unordered_map<const torch::jit::Value*, VarHandle> scalars_;
std::unordered_map<const torch::jit::Value*, std::string> input_name_map_;
std::unique_ptr<CodeGen> codegen_;
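For the repeated dynamic_cast changes in the two files above, a minimal usage sketch of the new non-const accessors (constantSizes is a hypothetical helper mirroring bufferSizes, shown under the assumption of a live KernelScope as in the tests):

#include <torch/csrc/jit/tensorexpr/ir.h>

using namespace torch::jit::tensorexpr;

// Collect constant dimension values; Buf::dim(i) now yields Expr*,
// so the cast no longer needs a const qualifier.
std::vector<int64_t> constantSizes(Buf* b) {
  std::vector<int64_t> sizes;
  for (size_t i = 0; i < b->ndim(); i++) {
    if (auto* imm = dynamic_cast<IntImm*>(b->dim(i))) {
      sizes.push_back(imm->value());
    }
  }
  return sizes;
}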
@ -169,8 +169,8 @@ class LLVMCodeGenImpl : public IRVisitor {

std::unordered_map<const Var*, int> varToArg_;
std::unordered_map<const Var*, llvm::Value*> varToVal_;
std::unordered_map<const Block*, std::vector<const Var*>> scopeToVar_;
const Block* scope_;
std::unordered_map<Block*, std::vector<Var*>> scopeToVar_;
Block* scope_;

std::string llvmCode_;
std::string asmCode_;
@ -195,13 +195,13 @@ class LLVMCodeGenImpl : public IRVisitor {
Arity arity,
int lanes);

llvm::Value* varToValue(const Var* var);
llvm::Value* varToValue(Var* var);
void replaceVarMapping(
const std::vector<const Var*>& vars,
const std::vector<Var*>& vars,
const std::vector<llvm::Value*>& vals);
llvm::Value* packFuncArgs(const std::vector<llvm::Value*>& func_args);
std::vector<llvm::Value*> unpackFuncArgs(llvm::Value* packed, int arg_count);
void processParallelFor(const For* v);
void processParallelFor(For* v);

public:
LLVMCodeGenImpl(
@ -216,42 +216,42 @@ class LLVMCodeGenImpl : public IRVisitor {

llvm::JITTargetAddress getKernelAddress() const;

void visit(const Add* v) override;
void visit(const Sub* v) override;
void visit(const Mul* v) override;
void visit(const Div* v) override;
void visit(const Mod* v) override;
void visit(const Max* v) override;
void visit(const Min* v) override;
void visit(const And* v) override;
void visit(const Or* v) override;
void visit(const Xor* v) override;
void visit(const Lshift* v) override;
void visit(const Rshift* v) override;
void visit(const CompareSelect* v) override;
void visit(Add* v) override;
void visit(Sub* v) override;
void visit(Mul* v) override;
void visit(Div* v) override;
void visit(Mod* v) override;
void visit(Max* v) override;
void visit(Min* v) override;
void visit(And* v) override;
void visit(Or* v) override;
void visit(Xor* v) override;
void visit(Lshift* v) override;
void visit(Rshift* v) override;
void visit(CompareSelect* v) override;

#define IMM_VISIT_DECLARE(_1, Name) void visit(const Name##Imm* v) override;
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_VISIT_DECLARE);
#undef IMM_VISIT_DECLARE

void visit(const Cast* v) override;
void visit(const BitCast* v) override;
void visit(const Var* v) override;
void visit(const Ramp* v) override;
void visit(const Load* v) override;
void visit(const For* v) override;
void visit(const Block* v) override;
void visit(const Store* v) override;
void visit(const Broadcast* v) override;
void visit(const IfThenElse* v) override;
void visit(const Intrinsics* v) override;
void visit(const Allocate* v) override;
void visit(const Free* v) override;
void visit(const Let* v) override;
void visit(const Cond* v) override;
void visit(const ExternalCall* v) override;
void visit(Cast* v) override;
void visit(BitCast* v) override;
void visit(Var* v) override;
void visit(Ramp* v) override;
void visit(Load* v) override;
void visit(For* v) override;
void visit(Block* v) override;
void visit(Store* v) override;
void visit(Broadcast* v) override;
void visit(IfThenElse* v) override;
void visit(Intrinsics* v) override;
void visit(Allocate* v) override;
void visit(Free* v) override;
void visit(Let* v) override;
void visit(Cond* v) override;
void visit(ExternalCall* v) override;

void emitIsNan(const Intrinsics* v);
void emitIsNan(Intrinsics* v);

llvm::Value* emitUnmaskedLoad(llvm::Value* addr, llvm::Value* idx);
llvm::Value* emitMaskedLoad(
@ -312,7 +312,7 @@ void LLVMCodeGen::call_raw(const std::vector<void*>& args) {
}

void LLVMCodeGen::call(const std::vector<CallArg>& args) {
const auto& buf_args = buffer_args();
auto& buf_args = buffer_args();
if (args.size() != buf_args.size()) {
throw malformed_input("wrong number of args in call");
}
@ -403,7 +403,7 @@ LLVMCodeGenImpl::LLVMCodeGenImpl(
// Emit prototype and bind argument Vars to parameter indices.
llvm::Type* retTy = dtypeToLLVM(dtype);
std::vector<llvm::Type*> params;
for (const auto i : c10::irange(args.size())) {
for (auto i : c10::irange(args.size())) {
auto const& arg = args[i];
if (arg.isVar()) {
params.push_back(dtypeToLLVM(arg.dtype()));
@ -418,7 +418,7 @@ LLVMCodeGenImpl::LLVMCodeGenImpl(
fn_->addAttribute(
llvm::AttributeList::AttrIndex::FunctionIndex,
llvm::Attribute::AlwaysInline);
for (const auto i : c10::irange(args.size())) {
for (auto i : c10::irange(args.size())) {
if (!args[i].isVar()) {
fn_->addParamAttr(i, llvm::Attribute::NoAlias);
}
@ -465,7 +465,7 @@ void LLVMCodeGenImpl::emitWrapper(const std::vector<llvm::Type*>& params) {
auto wrapBB = llvm::BasicBlock::Create(getContext(), "wrapBB", wrapper);
irb_.SetInsertPoint(wrapBB);
llvm::SmallVector<llvm::Value*, 6> wrappedArgs;
for (const auto i : c10::irange(params.size())) {
for (auto i : c10::irange(params.size())) {
auto argp = irb_.CreateGEP(
wrapper->arg_begin(), llvm::ConstantInt::getSigned(IntTy_, i));
if (params[i]->isPointerTy()) {
@ -484,7 +484,7 @@ void LLVMCodeGenImpl::emitWrapper(const std::vector<llvm::Type*>& params) {

class LLVMIntrinsicsExpander : public GenericIntrinsicsExpander {
private:
const Expr* mutate(const Intrinsics* v) {
Expr* mutate(Intrinsics* v) {
if (v->op_type() == kTanh) {
ScalarType stype = v->dtype().scalar_type();
if (stype == ScalarType::Float) {
@ -570,7 +570,7 @@ void LLVMCodeGenImpl::emitKernel(

// TODO: The binary ops are copypasta.

void LLVMCodeGenImpl::visit(const Add* v) {
void LLVMCodeGenImpl::visit(Add* v) {
v->lhs()->accept(this);
auto lhs = this->value_;
bool lfp = lhs->getType()->isFPOrFPVectorTy();
@ -588,7 +588,7 @@ void LLVMCodeGenImpl::visit(const Add* v) {
}
}

void LLVMCodeGenImpl::visit(const Sub* v) {
void LLVMCodeGenImpl::visit(Sub* v) {
v->lhs()->accept(this);
auto lhs = this->value_;
bool lfp = lhs->getType()->isFPOrFPVectorTy();
@ -606,7 +606,7 @@ void LLVMCodeGenImpl::visit(const Sub* v) {
}
}

void LLVMCodeGenImpl::visit(const Mul* v) {
void LLVMCodeGenImpl::visit(Mul* v) {
v->lhs()->accept(this);
auto lhs = this->value_;
bool lfp = lhs->getType()->isFPOrFPVectorTy();
@ -624,7 +624,7 @@ void LLVMCodeGenImpl::visit(const Mul* v) {
}
}

void LLVMCodeGenImpl::visit(const Div* v) {
void LLVMCodeGenImpl::visit(Div* v) {
v->lhs()->accept(this);
auto lhs = this->value_;
bool lfp = lhs->getType()->isFPOrFPVectorTy();
@ -642,7 +642,7 @@ void LLVMCodeGenImpl::visit(const Div* v) {
}
}

void LLVMCodeGenImpl::visit(const And* v) {
void LLVMCodeGenImpl::visit(And* v) {
v->lhs()->accept(this);
auto lhs = this->value_;
bool lfp = lhs->getType()->isFPOrFPVectorTy();
@ -657,7 +657,7 @@ void LLVMCodeGenImpl::visit(const And* v) {
}
}

void LLVMCodeGenImpl::visit(const Or* v) {
void LLVMCodeGenImpl::visit(Or* v) {
v->lhs()->accept(this);
auto lhs = this->value_;
bool lfp = lhs->getType()->isFPOrFPVectorTy();
@ -672,7 +672,7 @@ void LLVMCodeGenImpl::visit(const Or* v) {
}
}

void LLVMCodeGenImpl::visit(const Xor* v) {
void LLVMCodeGenImpl::visit(Xor* v) {
v->lhs()->accept(this);
auto lhs = this->value_;
bool lfp = lhs->getType()->isFPOrFPVectorTy();
@ -687,7 +687,7 @@ void LLVMCodeGenImpl::visit(const Xor* v) {
}
}

void LLVMCodeGenImpl::visit(const Lshift* v) {
void LLVMCodeGenImpl::visit(Lshift* v) {
v->lhs()->accept(this);
auto lhs = this->value_;
bool lfp = lhs->getType()->isFPOrFPVectorTy();
@ -702,7 +702,7 @@ void LLVMCodeGenImpl::visit(const Lshift* v) {
}
}

void LLVMCodeGenImpl::visit(const Rshift* v) {
void LLVMCodeGenImpl::visit(Rshift* v) {
v->lhs()->accept(this);
auto lhs = this->value_;
bool lfp = lhs->getType()->isFPOrFPVectorTy();
@ -721,7 +721,7 @@ void LLVMCodeGenImpl::visit(const Rshift* v) {
}
}

void LLVMCodeGenImpl::visit(const Mod* v) {
void LLVMCodeGenImpl::visit(Mod* v) {
v->lhs()->accept(this);
auto lhs = this->value_;
bool lfp = lhs->getType()->isFPOrFPVectorTy();
@ -736,7 +736,7 @@ void LLVMCodeGenImpl::visit(const Mod* v) {
}
}

void LLVMCodeGenImpl::visit(const Max* v) {
void LLVMCodeGenImpl::visit(Max* v) {
v->lhs()->accept(this);
auto lhs = this->value_;
v->rhs()->accept(this);
@ -759,7 +759,7 @@ void LLVMCodeGenImpl::visit(const Max* v) {
irb_.CreateFCmp(llvm::FCmpInst::FCMP_OGT, lhs, rhs), lhs, rhs));
}

void LLVMCodeGenImpl::visit(const Min* v) {
void LLVMCodeGenImpl::visit(Min* v) {
v->lhs()->accept(this);
auto lhs = this->value_;
v->rhs()->accept(this);
@ -781,7 +781,7 @@ void LLVMCodeGenImpl::visit(const Min* v) {
irb_.CreateFCmp(llvm::FCmpInst::FCMP_OLT, lhs, rhs), lhs, rhs));
}

void LLVMCodeGenImpl::visit(const CompareSelect* v) {
void LLVMCodeGenImpl::visit(CompareSelect* v) {
auto genUnbiased = [this, v]() -> llvm::Value* {
v->lhs()->accept(this);
auto lhs = this->value_;
@ -906,7 +906,7 @@ llvm::Type* llvmTypeToVec(llvm::Type* type, int lanes) {
}
}

void LLVMCodeGenImpl::visit(const Cast* v) {
void LLVMCodeGenImpl::visit(Cast* v) {
v->src_value()->accept(this);

llvm::Type* dstType =
@ -978,7 +978,7 @@ void LLVMCodeGenImpl::visit(const Cast* v) {
}
}

void LLVMCodeGenImpl::visit(const BitCast* v) {
void LLVMCodeGenImpl::visit(BitCast* v) {
v->src_value()->accept(this);

llvm::Type* dstType = dtypeToLLVM(v->dtype());
@ -997,11 +997,11 @@ void LLVMCodeGenImpl::visit(const BitCast* v) {
value_ = irb_.CreateBitOrPointerCast(value_, dstType);
}

void LLVMCodeGenImpl::visit(const Var* v) {
void LLVMCodeGenImpl::visit(Var* v) {
value_ = varToValue(v);
}

llvm::Value* LLVMCodeGenImpl::varToValue(const Var* v) {
llvm::Value* LLVMCodeGenImpl::varToValue(Var* v) {
// It is possible for v to be in both varToVal_ and varToArgs.
// In that case, varToVal_ takes precedence.
if (varToVal_.count(v)) {
@ -1015,11 +1015,11 @@ llvm::Value* LLVMCodeGenImpl::varToValue(const Var* v) {
}

void LLVMCodeGenImpl::replaceVarMapping(
const std::vector<const Var*>& vars,
const std::vector<Var*>& vars,
const std::vector<llvm::Value*>& vals) {
TORCH_CHECK(vars.size() == vals.size());
for (const auto i : c10::irange(vars.size())) {
const Var* var = vars[i];
for (auto i : c10::irange(vars.size())) {
Var* var = vars[i];
llvm::Value* val = vals[i];
if (val) {
varToVal_[var] = val;
@ -1029,7 +1029,7 @@ void LLVMCodeGenImpl::replaceVarMapping(
}
}

void LLVMCodeGenImpl::visit(const Ramp* v) {
void LLVMCodeGenImpl::visit(Ramp* v) {
v->base()->accept(this);
auto base = this->value_;
v->stride()->accept(this);
@ -1105,7 +1105,7 @@ llvm::Value* LLVMCodeGenImpl::emitMaskedLoad(
return phi;
}

void LLVMCodeGenImpl::visit(const Load* v) {
void LLVMCodeGenImpl::visit(Load* v) {
if (v->dtype().lanes() == 1) {
v->base_handle()->accept(this);
auto base = this->value_;
@ -1134,9 +1134,9 @@ void LLVMCodeGenImpl::visit(const Load* v) {
bool unmasked_load = true;

// Handle the case where the load is contiguous and unmasked efficiently
auto* idx_ramp = dynamic_cast<const Ramp*>(v->flat_index());
auto* idx_ramp = dynamic_cast<Ramp*>(v->flat_index());
if (idx_ramp) {
auto* stride_imm = dynamic_cast<const IntImm*>(idx_ramp->stride());
auto* stride_imm = dynamic_cast<IntImm*>(idx_ramp->stride());
if (stride_imm && stride_imm->value() == 1) {
v->base_handle()->accept(this);
auto base = this->value_;
@ -1181,14 +1181,14 @@ llvm::Value* LLVMCodeGenImpl::packFuncArgs(
return NullPtr;
}
std::vector<llvm::Type*> arg_types(func_args.size());
for (const auto i : c10::irange(func_args.size())) {
for (auto i : c10::irange(func_args.size())) {
arg_types[i] = func_args[i]->getType();
}
llvm::StructType* packed_type = llvm::StructType::create(arg_types);
llvm::Value* zero = llvm::ConstantInt::get(IntTy_, 0);
llvm::Value* one = llvm::ConstantInt::get(IntTy_, 1);
llvm::Value* packed = irb_.CreateAlloca(packed_type, one);
for (const auto i : c10::irange(func_args.size())) {
for (auto i : c10::irange(func_args.size())) {
llvm::Value* dst_ptr = irb_.CreateInBoundsGEP(
packed, {zero, llvm::ConstantInt::get(IntTy_, i)});
irb_.CreateStore(func_args[i], dst_ptr);
@ -1203,7 +1203,7 @@ std::vector<llvm::Value*> LLVMCodeGenImpl::unpackFuncArgs(
// TODO: extract arg_count from packed.
std::vector<llvm::Value*> func_args(arg_count);
llvm::Value* zero = llvm::ConstantInt::get(IntTy_, 0);
for (const auto i : c10::irange(arg_count)) {
for (auto i : c10::irange(arg_count)) {
llvm::Value* dst_ptr = irb_.CreateInBoundsGEP(
packed, {zero, llvm::ConstantInt::get(IntTy_, i)});
func_args[i] = irb_.CreateLoad(dst_ptr);
@ -1215,7 +1215,7 @@ std::vector<llvm::Value*> LLVMCodeGenImpl::unpackFuncArgs(
// * Move the body into its own closure.
// * Identify var across the boundary into arguments and forward them.
// * Send the closure and range to the dispatcher for execution.
void LLVMCodeGenImpl::processParallelFor(const For* v) {
void LLVMCodeGenImpl::processParallelFor(For* v) {
// Create "start" and "stop" values.
v->start()->accept(this);
auto start = this->value_;
@ -1223,7 +1223,7 @@ void LLVMCodeGenImpl::processParallelFor(const For* v) {
auto stop = this->value_;

// The Vars that need to be forward in the body closure.
std::vector<const Var*> body_arg_vars;
std::vector<Var*> body_arg_vars;
// Corresponding Value* that was used in the old body for the caller.
std::vector<llvm::Value*> body_caller_vals;
// Corresponding Value* that will be used in the new body closure.
@ -1232,7 +1232,7 @@ void LLVMCodeGenImpl::processParallelFor(const For* v) {
// Identify the Var* used in the body, and generated outside.
VarFinder var_finder;
v->body()->accept(&var_finder);
const auto& vars = var_finder.vars();
auto& vars = var_finder.vars();
for (auto& var : vars) {
if (llvm::Value* value = varToValue(var)) {
body_arg_vars.push_back(var);
@ -1292,7 +1292,7 @@ void LLVMCodeGenImpl::processParallelFor(const For* v) {
value_ = llvm::ConstantInt::get(IntTy_, 0);
}

void LLVMCodeGenImpl::visit(const For* v) {
void LLVMCodeGenImpl::visit(For* v) {
if (v->is_parallel()) {
processParallelFor(v);
return;
@ -1347,8 +1347,8 @@ void LLVMCodeGenImpl::visit(const For* v) {
value_ = llvm::ConstantInt::get(IntTy_, 0);
}

void LLVMCodeGenImpl::visit(const Block* v) {
const Block* last = scope_;
void LLVMCodeGenImpl::visit(Block* v) {
Block* last = scope_;
scope_ = v;

for (Stmt* s : *v) {
@ -1359,7 +1359,7 @@ void LLVMCodeGenImpl::visit(const Block* v) {

auto it = scopeToVar_.find(v);
if (it != scopeToVar_.end()) {
for (const Var* e : it->second) {
for (Var* e : it->second) {
if (varToVal_.erase(e) != 1) {
throw std::runtime_error("erasing var that doesn't exist");
}
@ -1398,7 +1398,7 @@ void LLVMCodeGenImpl::emitMaskedStore(
irb_.SetInsertPoint(tailblock);
}

void LLVMCodeGenImpl::visit(const Store* v) {
void LLVMCodeGenImpl::visit(Store* v) {
if (v->value()->dtype().lanes() == 1) {
v->base_handle()->accept(this);
auto base = this->value_;
@ -1419,9 +1419,9 @@ void LLVMCodeGenImpl::visit(const Store* v) {
auto val = this->value_;

// Handle the case where the store is contiguous and unmasked efficiently
auto* idx_ramp = dynamic_cast<const Ramp*>(v->flat_index());
auto* idx_ramp = dynamic_cast<Ramp*>(v->flat_index());
if (idx_ramp) {
auto* stride_imm = dynamic_cast<const IntImm*>(idx_ramp->stride());
auto* stride_imm = dynamic_cast<IntImm*>(idx_ramp->stride());
if (stride_imm && stride_imm->value() == 1) {
idx_ramp->base()->accept(this);
auto first_idx = value_;
@ -1453,13 +1453,13 @@ void LLVMCodeGenImpl::visit(const Store* v) {
value_ = llvm::ConstantInt::get(IntTy_, 0);
}

void LLVMCodeGenImpl::visit(const Broadcast* v) {
void LLVMCodeGenImpl::visit(Broadcast* v) {
v->value()->accept(this);
int lanes = v->lanes();
value_ = irb_.CreateVectorSplat(lanes, value_);
}

void LLVMCodeGenImpl::visit(const IfThenElse* v) {
void LLVMCodeGenImpl::visit(IfThenElse* v) {
v->condition()->accept(this);
llvm::Value* condition = value_;
llvm::Value* c = irb_.CreateICmpNE(
@ -1509,7 +1509,7 @@ llvm::Value* LLVMCodeGenImpl::toVec(llvm::Value* v, int lanes) {
}
}

void LLVMCodeGenImpl::emitIsNan(const Intrinsics* v) {
void LLVMCodeGenImpl::emitIsNan(Intrinsics* v) {
v->param(0)->accept(this);
llvm::Type* dstType = dtypeToLLVM(v->dtype());
if (!v->param(0)->dtype().is_floating_point()) {
@ -1583,7 +1583,7 @@ LLVMCodeGenImpl::SimdCallee LLVMCodeGenImpl::getSimdFunction(
return SimdCallee{callee.getFunctionType(), callee.getCallee(), useSimd};
}

void LLVMCodeGenImpl::visit(const Intrinsics* v) {
void LLVMCodeGenImpl::visit(Intrinsics* v) {
llvm::FunctionType* call_ty = nullptr;
llvm::Value* call_fn = nullptr;
bool call_simd_sleef = false;
@ -1772,7 +1772,7 @@ void LLVMCodeGenImpl::visit(const Intrinsics* v) {
}
}

void LLVMCodeGenImpl::visit(const ExternalCall* v) {
void LLVMCodeGenImpl::visit(ExternalCall* v) {
constexpr int max_buffers = 10;
constexpr int max_dimensions = 40;

@ -1783,7 +1783,7 @@ void LLVMCodeGenImpl::visit(const ExternalCall* v) {

// Prepare a vector of bufs that we need to pass to the external function.
// This vector is the output buf followed by the buf_args.
std::vector<const Buf*> bufs(v->buf_args());
std::vector<Buf*> bufs(v->buf_args());
bufs.insert(bufs.begin(), v->buf());

int64_t bufs_num = bufs.size();
@ -1792,7 +1792,7 @@ void LLVMCodeGenImpl::visit(const ExternalCall* v) {
// Count the size of dims array - it consists of dimension of all bufs
// concatenated together.
int64_t dims_num = 0;
for (const Buf* b : bufs) {
for (Buf* b : bufs) {
dims_num += b->dims().size();
}

@ -1809,7 +1809,7 @@ void LLVMCodeGenImpl::visit(const ExternalCall* v) {

int i = 0;
int dim_idx = 0;
for (const Buf* b : bufs) {
for (Buf* b : bufs) {
// Store value for buf pointer
auto gep = irb_.CreateInBoundsGEP(
buf_ptrs, {llvm::ConstantInt::getSigned(IntTy_, i)});
@ -1832,7 +1832,7 @@ void LLVMCodeGenImpl::visit(const ExternalCall* v) {
llvm::ConstantInt::getSigned(LongTy_, b->dims().size()), gep);

// Store dims of the buf
for (const auto dim : c10::irange(b->dims().size())) {
for (auto dim : c10::irange(b->dims().size())) {
gep = irb_.CreateInBoundsGEP(
buf_dims, {llvm::ConstantInt::getSigned(IntTy_, dim_idx)});
b->dims()[dim]->accept(this);
@ -1845,7 +1845,7 @@ void LLVMCodeGenImpl::visit(const ExternalCall* v) {
}

i = 0;
for (const Expr* arg : v->args()) {
for (Expr* arg : v->args()) {
auto gep = irb_.CreateInBoundsGEP(
extra_args, {llvm::ConstantInt::getSigned(IntTy_, i)});
arg->accept(this);
@ -1886,10 +1886,10 @@ void LLVMCodeGenImpl::visit(const ExternalCall* v) {
value_ = llvm::ConstantInt::get(IntTy_, 0);
}

void LLVMCodeGenImpl::visit(const Allocate* v) {
void LLVMCodeGenImpl::visit(Allocate* v) {
llvm::Value* size =
llvm::ConstantInt::getSigned(LongTy_, v->dtype().byte_size());
for (const Expr* e : v->dims()) {
for (Expr* e : v->dims()) {
e->accept(this);
size = irb_.CreateMul(size, irb_.CreateZExt(value_, LongTy_));
}
@ -1918,7 +1918,7 @@ void LLVMCodeGenImpl::visit(const Allocate* v) {
varToVal_[v->buffer_var()] = malloc;
}

void LLVMCodeGenImpl::visit(const Free* v) {
void LLVMCodeGenImpl::visit(Free* v) {
value_ = llvm::ConstantInt::get(IntTy_, 0);
llvm::Value* ptr = varToVal_.at(v->buffer_var());
if (!llvm::isa<llvm::AllocaInst>(ptr)) {
@ -1926,7 +1926,7 @@ void LLVMCodeGenImpl::visit(const Free* v) {
}
}

void LLVMCodeGenImpl::visit(const Let* v) {
void LLVMCodeGenImpl::visit(Let* v) {
v->value()->accept(this);
if (!varToVal_.count(v->var())) {
varToVal_.emplace(v->var(), value_);
@ -1936,7 +1936,7 @@ void LLVMCodeGenImpl::visit(const Let* v) {
}
}

void LLVMCodeGenImpl::visit(const Cond* v) {
void LLVMCodeGenImpl::visit(Cond* v) {
// Even if true_stmt and false_stmt are nullptr,
// in case condition is a function call with side effect,
// we still evaluate it.
@ -113,7 +113,7 @@ static void registerIntrinsics(
}
assertSuccess(JD.define(absoluteSymbols(symbols)));

for (const auto& kv : getNNCFunctionRegistry()) {
for (auto& kv : getNNCFunctionRegistry()) {
assertSuccess(
JD.define(absoluteSymbols({entry(kv.first.c_str(), kv.second)})));
}
File diff suppressed because it is too large
@ -34,7 +34,7 @@ class TORCH_API LoopNest {

// A constructor for building a LoopNest from an Stmt and a list of output
// buffers.
LoopNest(Stmt* stmt, std::unordered_set<const Buf*> output_bufs);
LoopNest(Stmt* stmt, std::unordered_set<Buf*> output_bufs);

// A constructor for building a LoopNest from another loopnest. It clones the
// other loopnest's stmt.
@ -45,10 +45,10 @@ class TORCH_API LoopNest {
}

std::vector<For*> getLoopStmtsFor(Tensor*) const;
std::vector<For*> getLoopStmtsFor(const Buf*) const;
std::vector<For*> getLoopStmtsFor(Buf*) const;
std::vector<For*> getLoopStmtsFor(Stmt*) const;
Stmt* getLoopBodyFor(Tensor*) const;
Stmt* getLoopBodyFor(const Buf*) const;
Stmt* getLoopBodyFor(Buf*) const;

// Returns the For stmt indexed by 'indices' in the 'root' For stmt.
//'indices' indicates the path to the returned loop from 'root' in AST, e.g.,
@ -71,14 +71,14 @@ class TORCH_API LoopNest {
For* getLoopAt(For* root, const std::vector<int>& indices) const;

// Returns the For stmt that is immediately enclosing the given stmt.
static For* getParentLoop(const Stmt* st);
static For* getParentLoop(Stmt* st);

// Returns the list of For stmts corresponding to the loopnest that is
// enclosing the given stmt.
static std::vector<For*> getEnclosingLoopNest(const Stmt* st);
static std::vector<For*> getEnclosingLoopNest(Stmt* st);

// Returns a list of all Stmts that write to the given buf.
std::vector<const Stmt*> getAllWritesToBuf(const Buf*) const;
std::vector<Stmt*> getAllWritesToBuf(Buf*) const;

// The following methods return the For loops that contain writes to
// the given buf.
@ -98,18 +98,18 @@ class TORCH_API LoopNest {
// to buf.
// For the above example:
// getAllInnermostLoopsWritingToBuf(a) => {j1, k2, j3}
std::vector<For*> getAllInnermostLoopsWritingToBuf(const Buf*) const;
std::vector<For*> getAllInnermostLoopsWritingToBuf(Buf*) const;

// Returns a list of For loopnests which contain a Stmt that writes to
// the given buf. Each loopnest here is a vector For loops.
// For the above example:
// getAllLoopNestsWritingToBuf(a) => {{i1,j1}, {i2,j2,k2}, {i2,j3}}
std::vector<std::vector<For*>> getAllLoopNestsWritingToBuf(const Buf*) const;
std::vector<std::vector<For*>> getAllLoopNestsWritingToBuf(Buf*) const;

Stmt* simplify();

bool computeInline(Stmt* s);
bool computeInline(const Buf* b);
bool computeInline(Buf* b);
void inlineIntermediateBufs(bool allow_duplicated_work);

// Optimizes conditionals.
@ -463,12 +463,12 @@ class TORCH_API LoopNest {
static void sliceTail(For* f, int factor, For** head, For** tail);
static void sliceTail(For* f, int factor);

using AccessResult = std::pair<const Buf*, Stmt*>;
using AccessResult = std::pair<Buf*, Stmt*>;
// Insert a cache for the consumer's usages of the buffer produced in
// consumer, and redirect reads and writes in the consumer to that cache.
// Returns a pair of the new cache buffer, and the new rewritten consumer.
static AccessResult cacheAccesses(
const Buf* producer,
Buf* producer,
const std::string& name,
Stmt* consumer);

@ -535,8 +535,8 @@ class TORCH_API LoopNest {
void eliminateDeadStores();
void prepareForCodegen();

const std::unordered_set<const Buf*> getInputBufs() const;
const std::unordered_set<const Buf*> getOutputBufs() const {
const std::unordered_set<Buf*> getInputBufs() const;
const std::unordered_set<Buf*> getOutputBufs() const {
return output_bufs_;
}

@ -545,11 +545,11 @@ class TORCH_API LoopNest {
const std::vector<Tensor*>& output_tensors,
const std::vector<Tensor*>& tensors_to_compute);
Stmt* insertAllocFree(Stmt* stmt);
const std::unordered_set<const Buf*> getIntermediateBufs() const;
const std::unordered_set<Buf*> getIntermediateBufs() const;

Stmt* root_stmt_;

std::unordered_set<const Buf*> output_bufs_;
std::unordered_set<Buf*> output_bufs_;
};

TORCH_API Stmt* FlattenIndexes(Stmt* s);
@ -568,8 +568,8 @@ struct BufLoadOrStoreUse {
* in the vectors reflects the order in which the uses appear in the given
* statement.
*/
std::unordered_map<const Buf*, std::vector<BufLoadOrStoreUse>>
findLoadOrStoreUses(Stmt* s);
std::unordered_map<Buf*, std::vector<BufLoadOrStoreUse>> findLoadOrStoreUses(
Stmt* s);

} // namespace tensorexpr
} // namespace jit
|
||||
|
||||
// AccessInfo
|
||||
|
||||
std::vector<const Expr*> AccessInfo::getIndices() const {
|
||||
std::vector<const Expr*> indices;
|
||||
std::vector<Expr*> AccessInfo::getIndices() const {
|
||||
std::vector<Expr*> indices;
|
||||
|
||||
if (expr_) {
|
||||
if (auto* load = dynamic_cast<const Load*>(expr_)) {
|
||||
if (auto* load = dynamic_cast<Load*>(expr_)) {
|
||||
indices = load->indices();
|
||||
}
|
||||
} else {
|
||||
if (auto* store = dynamic_cast<const Store*>(stmt_)) {
|
||||
if (auto* store = dynamic_cast<Store*>(stmt_)) {
|
||||
indices = store->indices();
|
||||
}
|
||||
}
|
||||
@ -255,8 +255,8 @@ MemDependencyChecker::MemDependencyChecker() {
|
||||
}
|
||||
|
||||
MemDependencyChecker::MemDependencyChecker(
|
||||
const std::unordered_set<const Buf*>& inputs,
|
||||
const std::unordered_set<const Buf*>& outputs) {
|
||||
const std::unordered_set<Buf*>& inputs,
|
||||
const std::unordered_set<Buf*>& outputs) {
|
||||
for (auto* s : inputs) {
|
||||
inputs_[s] = nullptr;
|
||||
}
|
||||
@ -320,15 +320,15 @@ DependencySet MemDependencyChecker::getAllWriteDependencies(
|
||||
return writes;
|
||||
}
|
||||
|
||||
bool MemDependencyChecker::dependsDirectly(const Expr* A, const Stmt* B) {
|
||||
bool MemDependencyChecker::dependsDirectly(Expr* A, Stmt* B) {
|
||||
return dependsDirectlyHelper(A, B);
|
||||
}
|
||||
|
||||
bool MemDependencyChecker::dependsDirectly(const Stmt* A, const Stmt* B) {
|
||||
bool MemDependencyChecker::dependsDirectly(Stmt* A, Stmt* B) {
|
||||
return dependsDirectlyHelper(A, B);
|
||||
}
|
||||
|
||||
bool MemDependencyChecker::dependsDirectly(const Buf* O, const Stmt* B) {
|
||||
bool MemDependencyChecker::dependsDirectly(Buf* O, Stmt* B) {
|
||||
auto outputAccess = output(O);
|
||||
auto bWrites = getAllWritesWithin(B);
|
||||
|
||||
@ -341,7 +341,7 @@ bool MemDependencyChecker::dependsDirectly(const Buf* O, const Stmt* B) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool MemDependencyChecker::dependsDirectly(const Stmt* A, const Buf* I) {
|
||||
bool MemDependencyChecker::dependsDirectly(Stmt* A, Buf* I) {
|
||||
auto aReads = getAllReadsWithin(A);
|
||||
auto inputAccess = input(I);
|
||||
|
||||
@ -354,7 +354,7 @@ bool MemDependencyChecker::dependsDirectly(const Stmt* A, const Buf* I) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool MemDependencyChecker::dependsDirectly(const Expr* A, const Buf* I) {
|
||||
bool MemDependencyChecker::dependsDirectly(Expr* A, Buf* I) {
|
||||
auto aReads = getAllReadsWithin(A);
|
||||
auto inputAccess = input(I);
|
||||
|
||||
@ -373,15 +373,15 @@ bool MemDependencyChecker::dependsDirectly(
|
||||
return A->hasDependency(B) && B->isWrite();
|
||||
}
|
||||
|
||||
bool MemDependencyChecker::dependsIndirectly(const Expr* A, const Stmt* B) {
|
||||
bool MemDependencyChecker::dependsIndirectly(Expr* A, Stmt* B) {
|
||||
return dependsIndirectlyHelper(A, B);
|
||||
}
|
||||
|
||||
bool MemDependencyChecker::dependsIndirectly(const Stmt* A, const Stmt* B) {
|
||||
bool MemDependencyChecker::dependsIndirectly(Stmt* A, Stmt* B) {
|
||||
return dependsIndirectlyHelper(A, B);
|
||||
}
|
||||
|
||||
bool MemDependencyChecker::dependsIndirectly(const Buf* O, const Stmt* B) {
|
||||
bool MemDependencyChecker::dependsIndirectly(Buf* O, Stmt* B) {
|
||||
auto outputAccess = output(O);
|
||||
|
||||
DependencySet dependencies;
|
||||
@ -397,7 +397,7 @@ bool MemDependencyChecker::dependsIndirectly(const Buf* O, const Stmt* B) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool MemDependencyChecker::dependsIndirectly(const Stmt* A, const Buf* I) {
|
||||
bool MemDependencyChecker::dependsIndirectly(Stmt* A, Buf* I) {
|
||||
auto aReads = getAllReadsWithin(A);
|
||||
auto inputAccess = input(I);
|
||||
|
||||
@ -406,7 +406,7 @@ bool MemDependencyChecker::dependsIndirectly(const Stmt* A, const Buf* I) {
|
||||
return aDeps.count(inputAccess) != 0;
|
||||
}
|
||||
|
||||
bool MemDependencyChecker::dependsIndirectly(const Expr* A, const Buf* I) {
|
||||
bool MemDependencyChecker::dependsIndirectly(Expr* A, Buf* I) {
|
||||
auto aReads = getAllReadsWithin(A);
|
||||
auto inputAccess = input(I);
|
||||
|
||||
@ -415,7 +415,7 @@ bool MemDependencyChecker::dependsIndirectly(const Expr* A, const Buf* I) {
|
||||
return aDeps.count(inputAccess) != 0;
|
||||
}
|
||||
|
||||
bool MemDependencyChecker::dependsIndirectly(const Buf* O, const Buf* I) {
|
||||
bool MemDependencyChecker::dependsIndirectly(Buf* O, Buf* I) {
|
||||
auto outputAccess = output(O);
|
||||
auto inputAccess = input(I);
|
||||
|
||||
@ -438,8 +438,7 @@ bool MemDependencyChecker::dependsIndirectly(
|
||||
return true;
|
||||
}
|
||||
|
||||
std::shared_ptr<AccessInfo> MemDependencyChecker::accessFor(
|
||||
const Stmt* A) const {
|
||||
std::shared_ptr<AccessInfo> MemDependencyChecker::accessFor(Stmt* A) const {
|
||||
auto bound = stmtToAccess_.equal_range(A);
|
||||
for (auto it = bound.first; it != bound.second; ++it) {
|
||||
if (it->second->expr() == nullptr) {
|
||||
@ -449,8 +448,7 @@ std::shared_ptr<AccessInfo> MemDependencyChecker::accessFor(
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::shared_ptr<AccessInfo> MemDependencyChecker::accessFor(
|
||||
const Expr* A) const {
|
||||
std::shared_ptr<AccessInfo> MemDependencyChecker::accessFor(Expr* A) const {
|
||||
// TODO exprs can have multiple accesses... we're returning the first but that
|
||||
// isn't great. Can't do much here.
|
||||
auto bound = exprToAccess_.equal_range(A);
|
||||
@ -462,7 +460,7 @@ std::shared_ptr<AccessInfo> MemDependencyChecker::accessFor(
|
||||
}
|
||||
|
||||
std::unordered_set<std::shared_ptr<AccessInfo>> MemDependencyChecker::
|
||||
accessesWithin(const Stmt* A) const {
|
||||
accessesWithin(Stmt* A) const {
|
||||
auto it = scopeToAccesses_.find(A);
|
||||
if (it != scopeToAccesses_.end()) {
|
||||
return std::unordered_set<std::shared_ptr<AccessInfo>>(
|
||||
@ -478,11 +476,11 @@ std::unordered_set<std::shared_ptr<AccessInfo>> MemDependencyChecker::
|
||||
}
|
||||
|
||||
std::unordered_set<std::shared_ptr<AccessInfo>> MemDependencyChecker::
|
||||
accessesWithin(const Expr* A) const {
|
||||
accessesWithin(Expr* A) const {
|
||||
return {accessFor(A)};
|
||||
}
|
||||
|
||||
std::shared_ptr<AccessInfo> MemDependencyChecker::input(const Buf* b) const {
|
||||
std::shared_ptr<AccessInfo> MemDependencyChecker::input(Buf* b) const {
|
||||
auto it = inputs_.find(b);
|
||||
if (it == inputs_.end()) {
|
||||
return nullptr;
|
||||
@ -490,7 +488,7 @@ std::shared_ptr<AccessInfo> MemDependencyChecker::input(const Buf* b) const {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
std::shared_ptr<AccessInfo> MemDependencyChecker::output(const Buf* b) const {
|
||||
std::shared_ptr<AccessInfo> MemDependencyChecker::output(Buf* b) const {
|
||||
auto it = outputs_.find(b);
|
||||
if (it == outputs_.end()) {
|
||||
return nullptr;
|
||||
@ -500,18 +498,18 @@ std::shared_ptr<AccessInfo> MemDependencyChecker::output(const Buf* b) const {

// Node visitors:

void MemDependencyChecker::visit(const Store* v) {
const Stmt* last = lastStmt_;
void MemDependencyChecker::visit(Store* v) {
Stmt* last = lastStmt_;
lastStmt_ = v;
v->value()->accept(this);

for (const Expr* ind : v->indices()) {
for (Expr* ind : v->indices()) {
ind->accept(this);
}
lastStmt_ = last;

// Create a new AccessInfo for the store.
const Var* var = v->buf()->base_handle();
Var* var = v->buf()->base_handle();
auto info = std::make_shared<AccessInfo>(
nextAccess_++, AccessType::Store, v, var, getIndicesBounds(v->indices()));

@ -532,19 +530,19 @@ void MemDependencyChecker::visit(const Store* v) {
currentScope_->accesses_.push_back(info);
}

void MemDependencyChecker::visit(const Load* v) {
void MemDependencyChecker::visit(Load* v) {
// Create a temporary scope to hold any loads that occur within the indices of
// this load.
auto indicesScope =
std::make_shared<Scope>(currentScope_->block, currentScope_);
currentScope_ = indicesScope;

for (const Expr* ind : v->indices()) {
for (Expr* ind : v->indices()) {
ind->accept(this);
}

// Create a new AccessInfo for the load.
const Var* var = v->buf()->base_handle();
Var* var = v->buf()->base_handle();
auto load = std::make_shared<AccessInfo>(
nextAccess_++,
AccessType::Load,
@ -584,24 +582,24 @@ void MemDependencyChecker::visit(const Load* v) {
bool executionSafetyCheck(
const std::shared_ptr<AccessInfo>& info,
const std::shared_ptr<AccessInfo>& other,
const std::vector<const Expr*>& aStrides,
const std::vector<const Expr*>& oStrides,
const std::vector<Expr*>& aStrides,
const std::vector<Expr*>& oStrides,
bool parallelized) {
if (aStrides.empty() || oStrides.empty()) {
return false;
}
TORCH_INTERNAL_ASSERT(info->bounds().size() == other->bounds().size());
for (size_t b = 0; b < info->bounds().size(); ++b) {
const Expr* aIndexStride = aStrides[b];
const Expr* oIndexStride = oStrides[b];
Expr* aIndexStride = aStrides[b];
Expr* oIndexStride = oStrides[b];
// can't be safe on this index if we can't determine stride.
if (!aIndexStride->isConstant() || !oIndexStride->isConstant()) {
continue;
}

const Expr* minStride =
Expr* minStride =
IRSimplifier::simplify(new Min(aIndexStride, oIndexStride, true));
const Expr* maxStride =
Expr* maxStride =
IRSimplifier::simplify(new Max(aIndexStride, oIndexStride, true));

// If the first access has no stride don't apply safety).
@ -609,8 +607,7 @@ bool executionSafetyCheck(
continue;
}

const Expr* modCheck =
IRSimplifier::simplify(new Mod(maxStride, minStride));
Expr* modCheck = IRSimplifier::simplify(new Mod(maxStride, minStride));

// if the strides can't have easily inferable distinct offsets, they're not
// safe.
@ -624,7 +621,7 @@ bool executionSafetyCheck(
// axis is the same sign as the common stride, then they will not
// overlap.

const Expr* startDiff = IRSimplifier::simplify(
Expr* startDiff = IRSimplifier::simplify(
new Sub(info->bounds()[b].start, other->bounds()[b].start));

bool diffNegative = immediateIsNegative(startDiff);
@ -638,7 +635,7 @@ bool executionSafetyCheck(
// If both accesses have the same stride, and the difference in start
// element is smaller than this stride then the entire range is distinct.
if (exprEquals(minStride, maxStride)) {
const Expr* check1 =
Expr* check1 =
IRSimplifier::simplify(new CompareSelect(startDiff, minStride, kLT));
if (check1->isConstant() && immediateEquals(check1, 1)) {
return true;
@ -649,7 +646,7 @@ bool executionSafetyCheck(

CompareSelectOperation op = strideNegative ? kLT : kGT;

const Expr* check =
Expr* check =
IRSimplifier::simplify(new CompareSelect(startDiff, new IntImm(0), op));

// If the start difference modulo the minimum stride is offset from that
@ -670,10 +667,10 @@ bool executionSafetyCheck(
return false;
}
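
The arithmetic in executionSafetyCheck is easiest to follow with concrete numbers. Take two accesses in the same loop with index expressions 2*i and 2*i+1 (a made-up example, not taken from this commit):

    // aIndexStride = 2, oIndexStride = 2  =>  minStride = maxStride = 2
    // modCheck  = simplify(maxStride % minStride) = 0, so offsets are comparable
    // startDiff = simplify(0 - 1) = -1, magnitude 1, and 1 < minStride
    // => the accesses interleave as {0, 2, 4, ...} vs {1, 3, 5, ...} and never
    //    touch the same element, so the check reports them as safe.
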
void MemDependencyChecker::visit(const For* v) {
const Var* var = v->var();
void MemDependencyChecker::visit(For* v) {
Var* var = v->var();

const Stmt* last = lastStmt_;
Stmt* last = lastStmt_;
lastStmt_ = v;

v->var()->accept(this);
@ -716,19 +713,19 @@ void MemDependencyChecker::visit(const For* v) {
// access, which we do via substituting the loop var with (var+1) into the
// indices expr.

std::vector<std::vector<const Expr*>> loopStrides;
std::vector<std::vector<Expr*>> loopStrides;
loopStrides.resize(currentScope_->accesses_.size());

for (size_t a = 0; a < currentScope_->accesses_.size(); ++a) {
auto& info = currentScope_->accesses_[a];

std::vector<const Expr*> indices = info->getIndices();
std::vector<Expr*> indices = info->getIndices();

std::vector<const Expr*>& loopIndicesStride = loopStrides[a];
std::vector<Expr*>& loopIndicesStride = loopStrides[a];
loopIndicesStride.resize(indices.size());

// index expr must depend on the loop var in some way to have a stride.
for (const auto i : c10::irange(indices.size())) {
for (auto i : c10::irange(indices.size())) {
VarFinder vf;
if (vf.find(indices[i]).count(var) == 0) {
loopIndicesStride[i] = new IntImm(0);
@ -750,14 +747,14 @@ void MemDependencyChecker::visit(const For* v) {
{{var, new Sub(v->stop(), new IntImm(1))}}));
}

const Expr* zeroStep = indices[i];
const Expr* oneStep =
Expr* zeroStep = indices[i];
Expr* oneStep =
Substitute(indices[i], {{var, new Add(var, new IntImm(1))}});
loopIndicesStride[i] =
IRSimplifier::simplify(new Sub(oneStep, zeroStep));

// If the start < end then swap the order of the bound.
const Expr* diff = IRSimplifier::simplify(
Expr* diff = IRSimplifier::simplify(
new Sub(info->bounds()[i].end, info->bounds()[i].start));
if (diff->isConstant() && immediateIsNegative(diff)) {
info->bounds()[i].swap();
@ -788,8 +785,7 @@ void MemDependencyChecker::visit(const For* v) {
Substitute(bound.end, {{var, new Sub(v->stop(), new IntImm(1))}}));

// If the start < end then swap the order of the bound.
const Expr* diff =
IRSimplifier::simplify(new Sub(bound.end, bound.start));
Expr* diff = IRSimplifier::simplify(new Sub(bound.end, bound.start));
if (diff->isConstant() && immediateIsNegative(diff)) {
bound.swap();
}
@ -802,7 +798,7 @@ void MemDependencyChecker::visit(const For* v) {
v->loop_options().is_gpu_thread_index();

// Store buffers allocated at this scope.
std::unordered_set<const Var*> local_intermediates;
std::unordered_set<Var*> local_intermediates;

// Scanning from the top of the loop, we look for accesses which may depend
// on a previous or parallel loop iteration.
@ -905,8 +901,8 @@ void MemDependencyChecker::visit(const For* v) {
currentScope_ = currentScope_->parent;
}
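
The per-iteration stride computed above is just the index at iteration i+1 minus the index at iteration i. For a hypothetical index expression 2*i + 1 the substitution works out as:

    // zeroStep = 2*i + 1
    // oneStep  = Substitute(2*i + 1, {{i, i + 1}}) = 2*(i + 1) + 1
    // stride   = IRSimplifier::simplify(oneStep - zeroStep) = 2
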
void MemDependencyChecker::visit(const Cond* v) {
const Stmt* last = lastStmt_;
void MemDependencyChecker::visit(Cond* v) {
Stmt* last = lastStmt_;
lastStmt_ = v;

auto enclosingScope =
@ -954,12 +950,12 @@ void MemDependencyChecker::visit(const Cond* v) {
lastStmt_ = last;
}

void MemDependencyChecker::visit(const IfThenElse* v) {
void MemDependencyChecker::visit(IfThenElse* v) {
// condition is in enclosing scope.
v->condition()->accept(this);

const Expr* true_value = v->true_value();
const Expr* false_value = v->false_value();
Expr* true_value = v->true_value();
Expr* false_value = v->false_value();

auto enclosingScope = currentScope_;

@ -990,13 +986,13 @@ void MemDependencyChecker::visit(const IfThenElse* v) {
currentScope_ = enclosingScope;
}

void MemDependencyChecker::visit(const CompareSelect* v) {
void MemDependencyChecker::visit(CompareSelect* v) {
// condition is in enclosing scope.
v->lhs()->accept(this);
v->rhs()->accept(this);

const Expr* true_value = v->ret_val1();
const Expr* false_value = v->ret_val2();
Expr* true_value = v->ret_val1();
Expr* false_value = v->ret_val2();

auto enclosingScope = currentScope_;

@ -1029,11 +1025,11 @@ void MemDependencyChecker::visit(const CompareSelect* v) {

// Inserts accesses for a map of buffers (ie. for inputs and outputs).
void MemDependencyChecker::insertBuffers(
std::unordered_map<const Buf*, std::shared_ptr<AccessInfo>>& bufs,
std::unordered_map<Buf*, std::shared_ptr<AccessInfo>>& bufs,
AccessType type) {
for (auto& pair : bufs) {
const Buf* b = pair.first;
const Var* var = b->base_handle();
Buf* b = pair.first;
Var* var = b->base_handle();
IndexBounds bounds;
for (auto* d : b->dims()) {
bounds.push_back(
@ -1050,7 +1046,7 @@ void MemDependencyChecker::insertBuffers(
}
}

void MemDependencyChecker::visit(const Block* v) {
void MemDependencyChecker::visit(Block* v) {
auto prev_scope = currentScope_;

// handle kernel inputs.
@ -1086,15 +1082,15 @@ void MemDependencyChecker::visit(const Block* v) {
}
}

void MemDependencyChecker::visit(const Let* v) {
const Stmt* last = lastStmt_;
void MemDependencyChecker::visit(Let* v) {
Stmt* last = lastStmt_;
lastStmt_ = v;

IRVisitor::visit(v);

lastStmt_ = last;

const Var* var = v->var();
Var* var = v->var();
if (knownVarBounds_.count(var) != 0) {
currentScope_->shadowedVarBounds[var] = knownVarBounds_[var];
}
@ -1105,17 +1101,17 @@ void MemDependencyChecker::visit(const Let* v) {

// Don't support AtomicAdd yet, it's a bit more complex since it's both a read
// and a write. It's only inserted during Cuda codegen so this should be okay.
void MemDependencyChecker::visit(const AtomicAdd* v) {
void MemDependencyChecker::visit(AtomicAdd* v) {
throw std::runtime_error("MemDependencyChecker AtomicAdd unimplemented");
}

void MemDependencyChecker::visit(const Allocate* v) {
const Stmt* last = lastStmt_;
void MemDependencyChecker::visit(Allocate* v) {
Stmt* last = lastStmt_;
lastStmt_ = v;

IRVisitor::visit(v);

const Var* var = v->buffer_var();
Var* var = v->buffer_var();
IndexBounds bounds;
// TODO: remove the "buf_flat_size" process below and extend the buf bound
// check to support N-d indices access and 1-d index access.
@ -1124,7 +1120,7 @@ void MemDependencyChecker::visit(const Allocate* v) {
// identify 1-d index access for N-d bufs. Thus we flatten N-d bufs here to
// avoid failing the bound check. But this is not the correct approach and
// should be fixed.
const Expr* flat_size = buf_flat_size(v->buf());
Expr* flat_size = buf_flat_size(v->buf());
flat_size = IRSimplifier::simplify(new Sub(flat_size, new IntImm(1)));
bounds.push_back({new IntImm(0), flat_size});

@ -1140,13 +1136,13 @@ void MemDependencyChecker::visit(const Allocate* v) {
lastStmt_ = last;
}

void MemDependencyChecker::visit(const Free* v) {
const Stmt* last = lastStmt_;
void MemDependencyChecker::visit(Free* v) {
Stmt* last = lastStmt_;
lastStmt_ = v;

IRVisitor::visit(v);

const Var* var = v->buffer_var();
Var* var = v->buffer_var();
auto it = intermediates_.find(var);
TORCH_INTERNAL_ASSERT(it != intermediates_.end());

@ -1247,7 +1243,7 @@ void MemDependencyChecker::mergeScope(

// Copy open writes up.
for (auto& pair : child->openWrites_) {
const Var* var = pair.first;
Var* var = pair.first;

// Intentionally using operator[], we want it to be created if it does not
// exist.
@ -1270,7 +1266,7 @@ class VarBoundBinder : public IRVisitor {
public:
VarBoundBinder(const VarBoundMap& vars) : vars_(vars) {}

Bound getBounds(const Expr* e) {
Bound getBounds(Expr* e) {
min_ = e;
max_ = e;
e->accept(this);
@ -1280,7 +1276,7 @@ class VarBoundBinder : public IRVisitor {
}

private:
void visit(const Var* v) override {
void visit(Var* v) override {
auto it = vars_.find(v);
if (it == vars_.end()) {
return;
@ -1290,13 +1286,13 @@ class VarBoundBinder : public IRVisitor {
max_ = Substitute(max_, {{v, it->second.end}});
}

const Expr* min_{nullptr};
const Expr* max_{nullptr};
Expr* min_{nullptr};
Expr* max_{nullptr};
const VarBoundMap& vars_;
};

std::vector<Bound> MemDependencyChecker::getIndicesBounds(
const std::vector<const Expr*>& indices) {
const std::vector<Expr*>& indices) {
std::vector<Bound> bounds;
bounds.reserve(indices.size());
VarBoundBinder binder(knownVarBounds_);
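
VarBoundBinder turns a symbolic index into a concrete Bound by substituting each var's known range endpoints into the expression. A hedged sketch, assuming a Var* `x` known to range over [0, 9]:

    // Sketch: `x` and its bound are assumed for illustration.
    VarBoundMap known;
    known[x] = {new IntImm(0), new IntImm(9)};
    VarBoundBinder binder(known);
    Bound b = binder.getBounds(new Add(new Mul(new IntImm(2), x), new IntImm(1)));
    // After substitution, b.start is 2*0 + 1 and b.end is 2*9 + 1,
    // i.e. 1 and 19 once simplified.
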
@ -40,8 +40,8 @@ class TORCH_API AccessInfo {
AccessInfo(
size_t id,
AccessType type,
const Stmt* stmt,
const Var* var,
Stmt* stmt,
Var* var,
IndexBounds bounds)
: id_(id),
type_(type),
@ -53,9 +53,9 @@ class TORCH_API AccessInfo {
AccessInfo(
size_t id,
AccessType type,
const Expr* expr,
const Stmt* stmt,
const Var* var,
Expr* expr,
Stmt* stmt,
Var* var,
IndexBounds bounds)
: id_(id),
type_(type),
@ -77,18 +77,18 @@ class TORCH_API AccessInfo {
// The enclosing Stmt this access represents. E.g. if this is a Store then
// Stmt is the Store itself, while if the access is caused by an Expr, this is
// the most immediate parent Stmt.
const Stmt* stmt() const {
Stmt* stmt() const {
return stmt_;
}

// If the access is represented by an Expr (such as Load or Call) then this is
// it, otherwise it's nullptr.
const Expr* expr() const {
Expr* expr() const {
return expr_;
}

// The Var representing the underlying Buffer.
const Var* var() const {
Var* var() const {
return var_;
}

@ -114,7 +114,7 @@ class TORCH_API AccessInfo {
}

// Returns the symbolic expression of the indices of this access.
std::vector<const Expr*> getIndices() const;
std::vector<Expr*> getIndices() const;

// Establishes a dependency or dependent relationship with another access.
void addDependency(const std::shared_ptr<AccessInfo>& write);
@ -149,9 +149,9 @@ class TORCH_API AccessInfo {
private:
size_t id_;
AccessType type_;
const Stmt* stmt_;
const Expr* expr_;
const Var* var_;
Stmt* stmt_;
Expr* expr_;
Var* var_;
IndexBounds bounds_;

// Yes these should be sorted.
@ -159,7 +159,7 @@ class TORCH_API AccessInfo {
std::map<size_t, std::shared_ptr<AccessInfo>> dependents_;
};

using VarBoundMap = std::unordered_map<const Var*, Bound>;
using VarBoundMap = std::unordered_map<Var*, Bound>;

/* MemDepedencyChecker analyses a IR fragment and builds a dependency graph of
* accesses contained within.
@ -176,8 +176,8 @@ class TORCH_API MemDependencyChecker : public IRVisitor {
public:
MemDependencyChecker();
MemDependencyChecker(
const std::unordered_set<const Buf*>& inputs,
const std::unordered_set<const Buf*>& outputs);
const std::unordered_set<Buf*>& inputs,
const std::unordered_set<Buf*>& outputs);
MemDependencyChecker(
const std::vector<BufHandle>& inputs,
const std::vector<BufHandle>& outputs);
@ -193,15 +193,15 @@ class TORCH_API MemDependencyChecker : public IRVisitor {
// about it.

// Returns true if any read in A has a direct dependence on a write in B.
bool dependsDirectly(const Stmt* A, const Stmt* B);
bool dependsDirectly(const Expr* A, const Stmt* B);
bool dependsDirectly(Stmt* A, Stmt* B);
bool dependsDirectly(Expr* A, Stmt* B);

// Returns true of the output depends directly on a write contained in B.
bool dependsDirectly(const Buf* output, const Stmt* B);
bool dependsDirectly(Buf* output, Stmt* B);

// Returns true if a read in A depends directly on the provided input.
bool dependsDirectly(const Stmt* A, const Buf* input);
bool dependsDirectly(const Expr* A, const Buf* input);
bool dependsDirectly(Stmt* A, Buf* input);
bool dependsDirectly(Expr* A, Buf* input);

// Outputs/inputs cannot depend directly.

@ -211,18 +211,18 @@ class TORCH_API MemDependencyChecker : public IRVisitor {
const std::shared_ptr<AccessInfo>& B);

// Returns true if any read in A has an ancestor write contained in B.
bool dependsIndirectly(const Stmt* A, const Stmt* B);
bool dependsIndirectly(const Expr* A, const Stmt* B);
bool dependsIndirectly(Stmt* A, Stmt* B);
bool dependsIndirectly(Expr* A, Stmt* B);

// Returns true of the output depends indirectly on a write contained in B.
bool dependsIndirectly(const Buf* output, const Stmt* B);
bool dependsIndirectly(Buf* output, Stmt* B);

// Returns true if a read in A depends indirectly on the provided input.
bool dependsIndirectly(const Stmt* A, const Buf* input);
bool dependsIndirectly(const Expr* A, const Buf* input);
bool dependsIndirectly(Stmt* A, Buf* input);
bool dependsIndirectly(Expr* A, Buf* input);

// returns true if the output uses any load of the input.
bool dependsIndirectly(const Buf* output, const Buf* input);
bool dependsIndirectly(Buf* output, Buf* input);

// Returns true if the access A has a dependency chain to access B.
bool dependsIndirectly(
@ -230,21 +230,19 @@ class TORCH_API MemDependencyChecker : public IRVisitor {
const std::shared_ptr<AccessInfo>& B);

// Returns the AccessInfo
std::shared_ptr<AccessInfo> accessFor(const Stmt* A) const;
std::shared_ptr<AccessInfo> accessFor(const Expr* A) const;
std::shared_ptr<AccessInfo> accessFor(Stmt* A) const;
std::shared_ptr<AccessInfo> accessFor(Expr* A) const;

// Returns all AccessInfos.
std::unordered_set<std::shared_ptr<AccessInfo>> accessesWithin(
const Stmt* A) const;
std::unordered_set<std::shared_ptr<AccessInfo>> accessesWithin(Stmt* A) const;
// TODO: this will return only the AccessInfo for A. It's included for
// completeness but be aware it wont return accesses used in the computation
// of A.
std::unordered_set<std::shared_ptr<AccessInfo>> accessesWithin(
const Expr* A) const;
std::unordered_set<std::shared_ptr<AccessInfo>> accessesWithin(Expr* A) const;

// Accesses relating to input and output buffers.
std::shared_ptr<AccessInfo> input(const Buf* B) const;
std::shared_ptr<AccessInfo> output(const Buf* B) const;
std::shared_ptr<AccessInfo> input(Buf* B) const;
std::shared_ptr<AccessInfo> output(Buf* B) const;

// Returns the full history of reads and writes.
const std::vector<std::shared_ptr<AccessInfo>>& getHistory() const;
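
All of these queries now take mutable pointers. A hedged sketch of the buffer-to-buffer form, with `checker` built as before and `out_buf`/`in_buf` standing in for real Buf* values:

    if (checker.dependsIndirectly(out_buf, in_buf)) {
      // Some load of in_buf feeds, possibly through intermediate buffers,
      // a store that produces out_buf.
    }
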
@ -254,17 +252,17 @@ class TORCH_API MemDependencyChecker : public IRVisitor {

private:
// Node visitors.
void visit(const Store* v) override;
void visit(const Load* v) override;
void visit(const For* v) override;
void visit(const Cond* v) override;
void visit(const IfThenElse* v) override;
void visit(const CompareSelect* v) override;
void visit(const Block* v) override;
void visit(const Let* v) override;
void visit(const AtomicAdd* v) override;
void visit(const Allocate* v) override;
void visit(const Free* v) override;
void visit(Store* v) override;
void visit(Load* v) override;
void visit(For* v) override;
void visit(Cond* v) override;
void visit(IfThenElse* v) override;
void visit(CompareSelect* v) override;
void visit(Block* v) override;
void visit(Let* v) override;
void visit(AtomicAdd* v) override;
void visit(Allocate* v) override;
void visit(Free* v) override;

using BoundRelationship = std::pair<IndexBounds, std::shared_ptr<AccessInfo>>;

@ -276,29 +274,27 @@ class TORCH_API MemDependencyChecker : public IRVisitor {
Block* block;
std::shared_ptr<Scope> parent;

std::unordered_map<const Var*, Bound> shadowedVarBounds;
std::unordered_set<const Var*> localVars;
std::unordered_map<Var*, Bound> shadowedVarBounds;
std::unordered_set<Var*> localVars;

std::vector<std::shared_ptr<AccessInfo>> accesses_;

std::unordered_map<const Var*, std::list<BoundRelationship>> openWrites_;
std::unordered_map<Var*, std::list<BoundRelationship>> openWrites_;
};
std::shared_ptr<Scope> currentScope_;

bool allowExecutionOrderAnalysis_{false};

std::unordered_multimap<const Stmt*, std::shared_ptr<AccessInfo>>
stmtToAccess_;
std::unordered_multimap<const Expr*, std::shared_ptr<AccessInfo>>
exprToAccess_;
std::unordered_map<const Stmt*, std::vector<std::shared_ptr<AccessInfo>>>
std::unordered_multimap<Stmt*, std::shared_ptr<AccessInfo>> stmtToAccess_;
std::unordered_multimap<Expr*, std::shared_ptr<AccessInfo>> exprToAccess_;
std::unordered_map<Stmt*, std::vector<std::shared_ptr<AccessInfo>>>
scopeToAccesses_;

VarBoundMap knownVarBounds_;

// Finds all accesses that are reads within the scope of v.
template <typename StmtOrExpr>
DependencySet getAllReadsWithin(const StmtOrExpr* v) {
DependencySet getAllReadsWithin(StmtOrExpr* v) {
DependencySet reads;
auto insertAllReads = [&](const auto& nodes) {
for (auto* l : nodes) {
@ -321,7 +317,7 @@ class TORCH_API MemDependencyChecker : public IRVisitor {

// Finds all accesses that are writes within the scope of v.
// Writes cannot occur in Exprs, so this is a little simpler.
DependencySet getAllWritesWithin(const Stmt* v) {
DependencySet getAllWritesWithin(Stmt* v) {
DependencySet writes;

// writes just Store currently.
@ -339,7 +335,7 @@ class TORCH_API MemDependencyChecker : public IRVisitor {

// Templated helpers to work on either Exprs or Stmts.
template <typename StmtOrExpr>
bool dependsDirectlyHelper(const StmtOrExpr* A, const Stmt* B) {
bool dependsDirectlyHelper(StmtOrExpr* A, Stmt* B) {
auto aReads = getAllReadsWithin(A);
auto bWrites = getAllWritesWithin(B);

@ -355,7 +351,7 @@ class TORCH_API MemDependencyChecker : public IRVisitor {
}

template <typename StmtOrExpr>
bool dependsIndirectlyHelper(const StmtOrExpr* A, const Stmt* B) {
bool dependsIndirectlyHelper(StmtOrExpr* A, Stmt* B) {
auto aReads = getAllReadsWithin(A);
auto bWrites = getAllWritesWithin(B);

@ -373,13 +369,13 @@ class TORCH_API MemDependencyChecker : public IRVisitor {
DependencySet getAllWriteDependencies(const DependencySet& products);

// Maps for inputs and outputs, since they aren't present directly in the IR.
std::unordered_map<const Buf*, std::shared_ptr<AccessInfo>> inputs_;
std::unordered_map<const Buf*, std::shared_ptr<AccessInfo>> outputs_;
std::unordered_map<const Var*, std::shared_ptr<AccessInfo>> intermediates_;
std::unordered_map<Buf*, std::shared_ptr<AccessInfo>> inputs_;
std::unordered_map<Buf*, std::shared_ptr<AccessInfo>> outputs_;
std::unordered_map<Var*, std::shared_ptr<AccessInfo>> intermediates_;

// Inserts accesses for Buf's: specifically for inputs and outputs.
void insertBuffers(
std::unordered_map<const Buf*, std::shared_ptr<AccessInfo>>& bufs,
std::unordered_map<Buf*, std::shared_ptr<AccessInfo>>& bufs,
AccessType type);

// Update the write history with a new write, adding dependencies and closing
@ -399,10 +395,10 @@ class TORCH_API MemDependencyChecker : public IRVisitor {
bool closeOverlapped = true);

// Binds symbolic vars in indices with the low and high bound for those vars.
std::vector<Bound> getIndicesBounds(const std::vector<const Expr*>& indices);
std::vector<Bound> getIndicesBounds(const std::vector<Expr*>& indices);

size_t nextAccess_{0};
const Stmt* lastStmt_{nullptr};
Stmt* lastStmt_{nullptr};
};

} // namespace analysis

@ -7,19 +7,19 @@ namespace jit {
namespace tensorexpr {

ReduceOp* Reducer::operator()(
const Buf* result_buf,
Buf* result_buf,
ExprHandle body,
const std::vector<const Expr*>& output,
const std::vector<const Var*>& inner) const {
const std::vector<Expr*>& output,
const std::vector<Var*>& inner) const {
return new ReduceOp(
complete(result_buf, interaction_, body, output, inner), inner, *this);
}

ReduceOp* Reducer::operator()(
const Buf* result_buf,
const Expr* body,
const std::vector<const Expr*>& output,
const std::vector<const Var*>& inner) const {
Buf* result_buf,
Expr* body,
const std::vector<Expr*>& output,
const std::vector<Var*>& inner) const {
return new ReduceOp(
complete(result_buf, interaction_, ExprHandle(body), output, inner),
inner,
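
Both overloads funnel into complete(), which wraps the accumulator in a Load and applies the interaction function to it and the body. A sketch of invoking a Reducer directly with the signature above (every operand name here is assumed):

    // Sketch: `sum` is a Reducer; result_buf, body, output_idx and inner_vars
    // are assumed to be in scope with the types the signature requires.
    ReduceOp* op = sum(result_buf, body, output_idx, inner_vars);
    // op->body() is now interaction(Load(result_buf, output_idx), body).
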
@ -35,21 +35,21 @@ class TORCH_API Reducer {
}
virtual ~Reducer() = default;

const Expr* initializer() const {
Expr* initializer() const {
return init_;
}

ReduceOp* operator()(
const Buf* result_buf,
Buf* result_buf,
ExprHandle body,
const std::vector<const Expr*>& output,
const std::vector<const Var*>& inner) const;
const std::vector<Expr*>& output,
const std::vector<Var*>& inner) const;

ReduceOp* operator()(
const Buf* result_buf,
const Expr* body,
const std::vector<const Expr*>& output,
const std::vector<const Var*>& inner) const;
Buf* result_buf,
Expr* body,
const std::vector<Expr*>& output,
const std::vector<Var*>& inner) const;

// Polymorphic handling of Body functions with a variety of parameters.
static ExprHandle getReduceBody(
@ -104,11 +104,11 @@ class TORCH_API Reducer {
// Completes the reduction operator by applying the interaction function to
// the accumulation and the body expression.
static Expr* complete(
const Buf* accumulator,
Buf* accumulator,
ReduceInteraction interaction,
ExprHandle body,
const std::vector<const Expr*>& output_args,
const std::vector<const Var*>& reduce_args) {
const std::vector<Expr*>& output_args,
const std::vector<Var*>& reduce_args) {
ExprHandle accum =
ExprHandle(new Load(body.dtype(), accumulator, output_args));
auto e = interaction(accum, body);
@ -116,7 +116,7 @@ class TORCH_API Reducer {
}

private:
const Expr* init_;
Expr* init_;
ReduceInteraction interaction_;
};

@ -128,17 +128,14 @@ class TORCH_API Reducer {
class TORCH_API ReduceOp : public ExprNode<ReduceOp> {
public:
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
ReduceOp(
const Expr* body,
std::vector<const Var*> reduce_args,
const Reducer& reducer)
ReduceOp(Expr* body, std::vector<Var*> reduce_args, const Reducer& reducer)
: ExprNodeBase(body->dtype()),
body_(body),
reduce_args_(std::move(reduce_args)),
reducer_(reducer) {}

// return the body expression which obtains the value to be reduced.
const Expr* body() const {
Expr* body() const {
return body_;
}

@ -148,13 +145,13 @@ class TORCH_API ReduceOp : public ExprNode<ReduceOp> {
}

// returns variables associated with the axes of reduction.
const std::vector<const Var*>& reduce_args() const {
const std::vector<Var*>& reduce_args() const {
return reduce_args_;
}

private:
const Expr* body_;
std::vector<const Var*> reduce_args_;
Expr* body_;
std::vector<Var*> reduce_args_;
const Reducer reducer_;
};

@ -223,7 +220,7 @@ class ReductionExpander : public IRMutator {
return s->accept_mutator(this);
}

const Expr* mutate(const ReduceOp* v) override {
Expr* mutate(ReduceOp* v) override {
return v->body();
}
};
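
ReductionExpander::mutate simply unwraps a ReduceOp into its body expression; by the time it runs, the accumulation loop structure already exists, so only the scalar computation needs to survive. A hedged sketch of applying it (`root` is an assumed Stmt*):

    ReductionExpander expander;
    Stmt* expanded = root->accept_mutator(&expander);
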
@ -7,9 +7,7 @@ namespace registerizer {

// AccessInfo

void AccessInfo::addStore(
const Store* store,
const std::shared_ptr<Scope>& scope) {
void AccessInfo::addStore(Store* store, const std::shared_ptr<Scope>& scope) {
block_ =
block_ ? Block::getSharedParent(block_, scope->block()) : scope->block();

@ -27,9 +25,9 @@ void AccessInfo::addStore(
}

void AccessInfo::addLoad(
const Load* load,
Load* load,
const std::shared_ptr<Scope>& scope,
const Stmt* usage) {
Stmt* usage) {
block_ =
block_ ? Block::getSharedParent(block_, scope->block()) : scope->block();
first_usage_ = first_usage_ ? block_->getEnclosedRoot(first_usage_) : usage;
@ -69,13 +67,13 @@ bool AccessInfo::overlaps(const std::shared_ptr<AccessInfo>& other) {
// All accesses to a buf must have the same dimensionality.
TORCH_INTERNAL_ASSERT(indices_.size() == other->indices().size());

const auto& other_indices = other->indices();
auto& other_indices = other->indices();

// They don't overlap if there is a guaranteed difference in any
// dimension.
bool overlap = true;
for (size_t i = 0; i < indices_.size(); ++i) {
const Expr* diff = new Sub(indices_[i], other_indices[i]);
Expr* diff = new Sub(indices_[i], other_indices[i]);
diff = IRSimplifier::simplify(diff);

if (diff->isConstant() && !immediateEquals(diff, 0)) {
@ -87,7 +85,7 @@ bool AccessInfo::overlaps(const std::shared_ptr<AccessInfo>& other) {
return overlap;
}

bool AccessInfo::dependsOnVar(const Var* v) {
bool AccessInfo::dependsOnVar(Var* v) {
VarFinder vf;
for (auto* i : indices_) {
i->accept(&vf);
@ -139,7 +137,7 @@ void Scope::closeAccess(const std::shared_ptr<AccessInfo>& info) {
closedAccesses_.push_back(info);
}

AccessHashMap& Scope::getAccessMapByBuf(const Buf* b) {
AccessHashMap& Scope::getAccessMapByBuf(Buf* b) {
auto it = openAccesses_.find(b);
if (it == openAccesses_.end()) {
// create and return
@ -179,7 +177,7 @@ void RegisterizerAnalysis::closeAccessIntoScope(
scope->closeAccess(info);
}

void RegisterizerAnalysis::visit(const For* v) {
void RegisterizerAnalysis::visit(For* v) {
if (v->loop_options().is_gpu_block_index() ||
v->loop_options().is_gpu_thread_index()) {
throw malformed_input(
@ -195,8 +193,7 @@ void RegisterizerAnalysis::visit(const For* v) {
v->body()->accept(this);
stmtStack_.pop_front();

const Expr* loopExtent =
IRSimplifier::simplify(new Sub(v->stop(), v->start()));
Expr* loopExtent = IRSimplifier::simplify(new Sub(v->stop(), v->start()));

// now we need to see which accesses we can hoist out of the for loop, their
// costs should be multiplied by the loop extent.
@ -263,8 +260,8 @@ void RegisterizerAnalysis::visit(const For* v) {
mergeCurrentScopeIntoParent();
};

void RegisterizerAnalysis::visit(const Cond* v) {
const Expr* condition = v->condition();
void RegisterizerAnalysis::visit(Cond* v) {
Expr* condition = v->condition();
Block* true_stmt = v->true_stmt();
Block* false_stmt = v->false_stmt();

@ -303,10 +300,10 @@ void RegisterizerAnalysis::visit(const Cond* v) {
// IfThenElses are just like Conds except they are not Stmts, which means no
// registerization can occur internally. However, the first reference to an
// access can occur within one if its visible outside the condition.
void RegisterizerAnalysis::visit(const IfThenElse* v) {
const Expr* condition = v->condition();
const Expr* true_value = v->true_value();
const Expr* false_value = v->false_value();
void RegisterizerAnalysis::visit(IfThenElse* v) {
Expr* condition = v->condition();
Expr* true_value = v->true_value();
Expr* false_value = v->false_value();

// condition is in enclosing scope.
condition->accept(this);
@ -338,7 +335,7 @@ void RegisterizerAnalysis::visit(const IfThenElse* v) {
}
}

void RegisterizerAnalysis::visit(const Let* v) {
void RegisterizerAnalysis::visit(Let* v) {
currentScope_->addLocalVar(v->var());

stmtStack_.push_front(v);
@ -346,7 +343,7 @@ void RegisterizerAnalysis::visit(const Let* v) {
stmtStack_.pop_front();
}

void RegisterizerAnalysis::visit(const Block* v) {
void RegisterizerAnalysis::visit(Block* v) {
auto prev_scope = currentScope_;
if (currentScope_->block() != v) {
currentScope_ = std::make_shared<Scope>(v, prev_scope);
@ -374,7 +371,7 @@ void RegisterizerAnalysis::visit(const Block* v) {
}
}

void RegisterizerAnalysis::visit(const Store* v) {
void RegisterizerAnalysis::visit(Store* v) {
stmtStack_.push_front(v);
v->value()->accept(this);
stmtStack_.pop_front();
@ -428,7 +425,7 @@ void RegisterizerAnalysis::visit(const Store* v) {
}
}

void RegisterizerAnalysis::visit(const Load* v) {
void RegisterizerAnalysis::visit(Load* v) {
if (v->indices().empty()) {
// already a scalar.
return;
@ -563,7 +560,7 @@ void RegisterizerAnalysis::mergeCurrentScopeIntoParent() {
// copy across current open accesses, merging as necessary.
// for each Buf with an open access:
for (auto& pair : currentScope_->openAccesses()) {
const Buf* buf = pair.first;
Buf* buf = pair.first;
if (pair.second.empty()) {
continue;
}
@ -640,7 +637,7 @@ std::vector<std::shared_ptr<AccessInfo>> RegisterizerAnalysis::getCandidates() {

// RegisterizerReplacer

const Expr* RegisterizerReplacer::mutate(const Load* v) {
Expr* RegisterizerReplacer::mutate(Load* v) {
auto it = loadToAccess_.find(v);
if (it == loadToAccess_.end()) {
// This access cannot be registerized.
@ -652,7 +649,7 @@ const Expr* RegisterizerReplacer::mutate(const Load* v) {
return info->replacement().var;
}

Stmt* RegisterizerReplacer::mutate(const Store* v) {
Stmt* RegisterizerReplacer::mutate(Store* v) {
if (eliminatedIntializers_.count(v) != 0) {
// This store is the intializer for a scalar var that is already inserted.
return nullptr;
@ -666,12 +663,12 @@ Stmt* RegisterizerReplacer::mutate(const Store* v) {

auto& info = it->second;

const Expr* new_val = v->value()->accept_mutator(this);
Expr* new_val = v->value()->accept_mutator(this);

return new Store(info->replacement().var_wrapper, {}, new_val);
}

Stmt* RegisterizerReplacer::mutate(const Block* v) {
Stmt* RegisterizerReplacer::mutate(Block* v) {
auto& scope = parentToAccesses_[v];

std::vector<Stmt*> stmts;

@ -23,7 +23,7 @@ For example it can replace:

{
A[0] = 0;
for(const auto x : c10::irange(10)) {
for(auto x : c10::irange(10)) {
A[0] = (A[0]) + x;
}
}
@ -32,7 +32,7 @@ with:

{
int A_ = 0;
for(const auto x : c10::irange(10)) {
for(auto x : c10::irange(10)) {
A_ = x + A_;
}
A[0] = A_;
@ -54,8 +54,8 @@ class AccessInfo {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
AccessInfo(
SimplifierHashType h,
const Buf* b,
std::vector<const Expr*> i,
Buf* b,
std::vector<Expr*> i,
size_t accessOrder)
: hash_(h),
buf_(b),
@ -65,14 +65,11 @@ class AccessInfo {
accessOrder_(accessOrder) {}

// Adds a Store to this access, which is in the provided scope.
void addStore(const Store* store, const std::shared_ptr<Scope>& scope);
void addStore(Store* store, const std::shared_ptr<Scope>& scope);

// Adds a Load to this access, which occurs in the usage Stmt in the provided
// scope.
void addLoad(
const Load* load,
const std::shared_ptr<Scope>& scope,
const Stmt* usage);
void addLoad(Load* load, const std::shared_ptr<Scope>& scope, Stmt* usage);

// Merge another AccessInfo into this one.
void merge(const std::shared_ptr<AccessInfo>& other);
@ -81,7 +78,7 @@ class AccessInfo {
bool overlaps(const std::shared_ptr<AccessInfo>& other);

// Returns true if the indices of this access depend on the provided Var.
bool dependsOnVar(const Var* v);
bool dependsOnVar(Var* v);

// Clone this AccessInfo, and set this as the new accesses' hiddenAccess.
static std::shared_ptr<AccessInfo> cloneWithHiddenInfo(
@ -94,30 +91,30 @@ class AccessInfo {
return hash_;
}

const Buf* buf() const {
Buf* buf() const {
return buf_;
}

const std::vector<const Expr*>& indices() const {
const std::vector<Expr*>& indices() const {
return indices_;
}

const Block* block() const {
Block* block() const {
return block_;
}

void setEnclosingBlock(const Block* b) {
void setEnclosingBlock(Block* b) {
block_ = b;
}

const Stmt* first_usage() const {
Stmt* first_usage() const {
return first_usage_;
}
const Stmt* last_usage() const {
Stmt* last_usage() const {
return last_usage_;
}

void setUsageMarks(const Stmt* first, const Stmt* last) {
void setUsageMarks(Stmt* first, Stmt* last) {
first_usage_ = first;
last_usage_ = last;
}
@ -126,23 +123,23 @@ class AccessInfo {
return firstUsageOverlapped_;
}

const Expr* store_cost() const {
Expr* store_cost() const {
return store_cost_;
}

const Expr* load_cost() const {
Expr* load_cost() const {
return load_cost_;
}

const std::vector<const Store*>& stores() const {
const std::vector<Store*>& stores() const {
return stores_;
}

const std::vector<const Load*>& loads() const {
const std::vector<Load*>& loads() const {
return loads_;
}

void hoistCosts(const Expr* extent) {
void hoistCosts(Expr* extent) {
store_cost_ = IRSimplifier::simplify(new Mul(store_cost_, extent));
load_cost_ = IRSimplifier::simplify(new Mul(load_cost_, extent));
}
@ -177,12 +174,12 @@ class AccessInfo {

private:
SimplifierHashType hash_;
const Buf* buf_;
std::vector<const Expr*> indices_;
const Block* block_{nullptr};
Buf* buf_;
std::vector<Expr*> indices_;
Block* block_{nullptr};

const Stmt* first_usage_{nullptr};
const Stmt* last_usage_{nullptr};
Stmt* first_usage_{nullptr};
Stmt* last_usage_{nullptr};

// Whether or not this access is overlapped in the first Stmt it appears. This
// means we cannot use it's first Store as the initializer.
@ -190,13 +187,13 @@ class AccessInfo {

// The cost in real ops that this access represents, to enable
// filtering accesses that wont save any loads or stores.
const Expr* store_cost_;
const Expr* load_cost_;
Expr* store_cost_;
Expr* load_cost_;

// The actual Stores and Loads which represent this access.
// Be careful with these, any mutator will invalidate these pointers.
std::vector<const Store*> stores_;
std::vector<const Load*> loads_;
std::vector<Store*> stores_;
std::vector<Load*> loads_;

// An identifier representing the conditional block, if any, this access
// depends on.
@ -222,12 +219,12 @@ using AccessHashMap =
class Scope {
public:
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
Scope(const Block* b, std::shared_ptr<Scope> parent, size_t conditionId = 0)
Scope(Block* b, std::shared_ptr<Scope> parent, size_t conditionId = 0)
: block_(b), parent_(std::move(parent)), conditionId_(conditionId) {}

AccessHashMap& getAccessMapByBuf(const Buf* b);
AccessHashMap& getAccessMapByBuf(Buf* b);

std::unordered_map<const Buf*, AccessHashMap>& openAccesses() {
std::unordered_map<Buf*, AccessHashMap>& openAccesses() {
return openAccesses_;
}

@ -235,7 +232,7 @@ class Scope {
return closedAccesses_;
}

const Block* block() const {
Block* block() const {
return block_;
}

@ -247,10 +244,10 @@ class Scope {
return conditionId_;
}

const std::unordered_set<const Var*>& localVars() const {
const std::unordered_set<Var*>& localVars() const {
return localVars_;
}
void addLocalVar(const Var* v) {
void addLocalVar(Var* v) {
localVars_.insert(v);
}

@ -264,11 +261,11 @@ class Scope {
// overlap with other accesses to the same buf. Buf ->
// Hash ->
// Access
std::unordered_map<const Buf*, AccessHashMap> openAccesses_;
std::unordered_map<Buf*, AccessHashMap> openAccesses_;
std::vector<std::shared_ptr<AccessInfo>> closedAccesses_;

// The Block object this scope represents.
const Block* block_;
Block* block_;

// The enclosing scope object.
std::shared_ptr<Scope> parent_;
@ -277,7 +274,7 @@ class Scope {
size_t conditionId_;

// A set of variables local to this scope (e.g. loop vars).
std::unordered_set<const Var*> localVars_;
std::unordered_set<Var*> localVars_;
};

/* Analyzes the graph and collects accesses to the same symbolic tensor element
@ -323,25 +320,25 @@ class TORCH_API RegisterizerAnalysis : public IRVisitor {
: currentScope_(std::make_shared<Scope>(nullptr, nullptr, 0)) {}
~RegisterizerAnalysis() override = default;

void visit(const For* v) override;
void visit(For* v) override;

void visit(const Cond* v) override;
void visit(Cond* v) override;

void visit(const Block* v) override;
void visit(Block* v) override;

void visit(const Store* v) override;
void visit(Store* v) override;

void visit(const Load* v) override;
void visit(Load* v) override;

void visit(const IfThenElse* v) override;
void visit(IfThenElse* v) override;

void visit(const Let* v) override;
void visit(Let* v) override;

#define STMT_ON_STACK(Op) \
void visit(const Op* v) override { \
stmtStack_.push_front(v); \
IRVisitor::visit(v); \
stmtStack_.pop_front(); \
#define STMT_ON_STACK(Op) \
void visit(Op* v) override { \
stmtStack_.push_front(v); \
IRVisitor::visit(v); \
stmtStack_.pop_front(); \
}

STMT_ON_STACK(AtomicAdd);
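
Expanded by hand for AtomicAdd, the STMT_ON_STACK macro above generates roughly the following override (my expansion, not part of the diff):

    void visit(AtomicAdd* v) override {
      stmtStack_.push_front(v);  // record the enclosing Stmt for Loads below
      IRVisitor::visit(v);
      stmtStack_.pop_front();
    }
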
@ -362,7 +359,7 @@ class TORCH_API RegisterizerAnalysis : public IRVisitor {
std::unordered_set<size_t> exprConditionals_;

// A stack of enclosing Stmts for tracking the usage Stmt of Loads.
std::deque<const Stmt*> stmtStack_;
std::deque<Stmt*> stmtStack_;

// The current scope being analyzed.
std::shared_ptr<Scope> currentScope_;
@ -384,17 +381,17 @@ class TORCH_API RegisterizerReplacer : public IRMutator {
buildReplacements();
}

const Expr* mutate(const Load* v) override;
Expr* mutate(Load* v) override;

Stmt* mutate(const Store* v) override;
Stmt* mutate(Store* v) override;

Stmt* mutate(const Block* v) override;
Stmt* mutate(Block* v) override;

private:
struct ReplacerScope {
std::unordered_map<const Stmt*, std::deque<std::shared_ptr<AccessInfo>>>
std::unordered_map<Stmt*, std::deque<std::shared_ptr<AccessInfo>>>
initializerPoints_;
std::unordered_map<const Stmt*, std::deque<std::shared_ptr<AccessInfo>>>
std::unordered_map<Stmt*, std::deque<std::shared_ptr<AccessInfo>>>
finalizePoints_;
};

@ -403,18 +400,18 @@ class TORCH_API RegisterizerReplacer : public IRMutator {

// State relating to the accesses yet to be replaced.
std::vector<std::shared_ptr<AccessInfo>>& infoSet_;
std::unordered_map<const Store*, std::shared_ptr<AccessInfo>> storeToAccess_;
std::unordered_map<const Load*, std::shared_ptr<AccessInfo>> loadToAccess_;
std::unordered_map<const Block*, ReplacerScope> parentToAccesses_;
std::unordered_map<Store*, std::shared_ptr<AccessInfo>> storeToAccess_;
std::unordered_map<Load*, std::shared_ptr<AccessInfo>> loadToAccess_;
std::unordered_map<Block*, ReplacerScope> parentToAccesses_;

// Holds the set of Stores that should be pulled into an initializer, so they
// can be eliminated.
std::set<const Store*> eliminatedIntializers_;
std::set<Store*> eliminatedIntializers_;

// Tracks the number of times we've seen each buffer, so we can name the
// scalar Vars appropriately.
std::unordered_map<const Buf*, unsigned int> bufferAccessCounts_;
unsigned int getBufferAccessCount(const Buf* b) {
std::unordered_map<Buf*, unsigned int> bufferAccessCounts_;
unsigned int getBufferAccessCount(Buf* b) {
return ++bufferAccessCounts_[b];
}
};

@ -17,7 +17,7 @@ class Placeholder;
class TORCH_API Stmt : public KernelScopedObject {
public:
Stmt() = default;
virtual void accept(IRVisitor* visitor) const = 0;
virtual void accept(IRVisitor* visitor) = 0;
virtual Stmt* accept_mutator(IRMutator* mutator) = 0;

Stmt* get_parent() const {
@ -46,8 +46,8 @@ template <class Op>
class StmtNode : public Stmt {
public:
using StmtNodeBase = StmtNode<Op>;
void accept(IRVisitor* visitor) const override {
visitor->visit(static_cast<const Op*>(this));
void accept(IRVisitor* visitor) override {
visitor->visit(static_cast<Op*>(this));
}
Stmt* accept_mutator(IRMutator* mutator) override;
StmtNode() = default;
@ -102,7 +102,7 @@ class TORCH_API Block : public StmtNode<Block> {
set_parent(s, this);
}

void insert_stmt_before(Stmt* s, const Stmt* before) {
void insert_stmt_before(Stmt* s, Stmt* before) {
if (s->get_parent()) {
throw malformed_input("Block append Stmt with existing parent", s);
}
@ -117,7 +117,7 @@ class TORCH_API Block : public StmtNode<Block> {
set_parent(s, this);
}

void insert_stmt_after(Stmt* s, const Stmt* after) {
void insert_stmt_after(Stmt* s, Stmt* after) {
if (s->get_parent()) {
throw malformed_input("Block append Stmt with existing parent", s);
}
@ -240,7 +240,7 @@ class TORCH_API Block : public StmtNode<Block> {
return stmts_.front();
}

const Stmt* front() const {
Stmt* front() const {
return stmts_.front();
}

@ -248,7 +248,7 @@ class TORCH_API Block : public StmtNode<Block> {
return stmts_.back();
}

const Stmt* back() const {
Stmt* back() const {
return stmts_.back();
}

@ -260,13 +260,13 @@ class TORCH_API Block : public StmtNode<Block> {
stmts_.splice(it, other->stmts_);
}

static const Block* getSharedParent(const Stmt* p1, const Stmt* p2) {
static Block* getSharedParent(Stmt* p1, Stmt* p2) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::unordered_set<const Block*> enclosing;
std::unordered_set<Block*> enclosing;

const Stmt* p1_p = p1;
Stmt* p1_p = p1;
while (p1_p) {
if (const Block* b = dynamic_cast<const Block*>(p1_p)) {
if (Block* b = dynamic_cast<Block*>(p1_p)) {
if (b) {
enclosing.insert(b);
}
@ -274,9 +274,9 @@ class TORCH_API Block : public StmtNode<Block> {
p1_p = p1_p->get_parent();
}

const Stmt* p2_p = p2;
Stmt* p2_p = p2;
while (p2_p) {
if (const Block* b = dynamic_cast<const Block*>(p2_p)) {
if (Block* b = dynamic_cast<Block*>(p2_p)) {
if (enclosing.count(b) != 0) {
return b;
}
@ -288,7 +288,7 @@ class TORCH_API Block : public StmtNode<Block> {
}

// returns the immediate child containing statement s.
const Stmt* getEnclosedRoot(const Stmt* s) const {
Stmt* getEnclosedRoot(Stmt* s) const {
while (s && s->get_parent() != this) {
s = s->get_parent();
}
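
getSharedParent walks both parent chains and returns the innermost Block enclosing both statements; getEnclosedRoot then maps a deeply nested statement back to the direct child of that Block. A sketch with assumed statement names:

    // Sketch: s1 and s2 are hypothetical Stmt* nodes from the same kernel.
    Block* common = Block::getSharedParent(s1, s2);
    if (common) {
      Stmt* child = common->getEnclosedRoot(s1);  // immediate child holding s1
    }
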
@ -301,20 +301,20 @@ class TORCH_API Block : public StmtNode<Block> {
|
||||
|
||||
class TORCH_API Store : public StmtNode<Store> {
|
||||
public:
|
||||
const Var* base_handle() const {
|
||||
Var* base_handle() const {
|
||||
return buf_->base_handle();
|
||||
}
|
||||
std::vector<const Expr*> indices() const {
|
||||
std::vector<Expr*> indices() const {
|
||||
return indices_;
|
||||
}
|
||||
const Expr* flat_index() const {
|
||||
Expr* flat_index() const {
|
||||
TORCH_CHECK(indices_.size() == 1, "Indices haven't been flattened.");
|
||||
return indices_[0];
|
||||
}
|
||||
const Expr* value() const {
|
||||
Expr* value() const {
|
||||
return value_;
|
||||
}
|
||||
const Buf* buf() const {
|
||||
Buf* buf() const {
|
||||
return buf_;
|
||||
}
|
||||
|
||||
@ -323,16 +323,16 @@ class TORCH_API Store : public StmtNode<Store> {
|
||||
const std::vector<ExprHandle>& indices,
|
||||
const ExprHandle& value);
|
||||
|
||||
Store(const Buf* buf, std::vector<const Expr*> indices, const Expr* value);
|
||||
Store(Buf* buf, std::vector<Expr*> indices, Expr* value);
|
||||
|
||||
void set_indices(std::vector<const Expr*> indices) {
|
||||
void set_indices(std::vector<Expr*> indices) {
|
||||
indices_ = indices;
|
||||
};
|
||||
|
||||
private:
|
||||
const Buf* buf_;
|
||||
std::vector<const Expr*> indices_;
|
||||
const Expr* value_;
|
||||
Buf* buf_;
|
||||
std::vector<Expr*> indices_;
|
||||
Expr* value_;
|
||||
};
|
||||
|
||||
// Allocate a buffer of given shapes and dtypes and bind it with the given
|
||||
@ -344,7 +344,7 @@ class TORCH_API Allocate : public StmtNode<Allocate> {
|
||||
return new Allocate(buf_handle.node());
|
||||
}
|
||||
|
||||
const Var* buffer_var() const {
|
||||
Var* buffer_var() const {
|
||||
return buf_->base_handle();
|
||||
}
|
||||
|
||||
@ -352,18 +352,18 @@ class TORCH_API Allocate : public StmtNode<Allocate> {
|
||||
return buf_->dtype();
|
||||
}
|
||||
|
||||
const std::vector<const Expr*> dims() const {
|
||||
const std::vector<Expr*> dims() const {
|
||||
return buf_->dims();
|
||||
}
|
||||
|
||||
const Buf* buf() const {
|
||||
Buf* buf() const {
|
||||
return buf_;
|
||||
}
|
||||
|
||||
explicit Allocate(const Buf* buf) : buf_(buf) {}
|
||||
explicit Allocate(Buf* buf) : buf_(buf) {}
|
||||
|
||||
private:
|
||||
const Buf* buf_;
|
||||
Buf* buf_;
|
||||
// TODO: add memory types.
|
||||
};
|
||||
|
||||
@ -374,18 +374,18 @@ class TORCH_API Free : public StmtNode<Free> {
|
||||
return new Free(buf_handle.node());
|
||||
}
|
||||
|
||||
const Var* buffer_var() const {
|
||||
Var* buffer_var() const {
|
||||
return buf_->base_handle();
|
||||
}
|
||||
|
||||
const Buf* buf() const {
|
||||
Buf* buf() const {
|
||||
return buf_;
|
||||
}
|
||||
|
||||
explicit Free(const Buf* buf) : buf_(buf) {}
|
||||
explicit Free(Buf* buf) : buf_(buf) {}
|
||||
|
||||
private:
|
||||
const Buf* buf_;
|
||||
Buf* buf_;
|
||||
};
|
||||
|
||||
class TORCH_API Let : public StmtNode<Let> {
|
||||
@ -394,25 +394,24 @@ class TORCH_API Let : public StmtNode<Let> {
|
||||
return new Let(var.node(), val.node());
|
||||
}
|
||||
|
||||
Let(const Var* var, const Expr* val)
|
||||
: dtype_(var->dtype()), var_(var), val_(val) {}
|
||||
Let(Var* var, Expr* val) : dtype_(var->dtype()), var_(var), val_(val) {}
|
||||
|
||||
Dtype dtype() const {
|
||||
return dtype_;
|
||||
}
|
||||
|
||||
const Var* var() const {
|
||||
Var* var() const {
|
||||
return var_;
|
||||
}
|
||||
|
||||
const Expr* value() const {
|
||||
Expr* value() const {
|
||||
return val_;
|
||||
}
|
||||
|
||||
private:
|
||||
Dtype dtype_;
const Var* var_;
const Expr* val_;
Var* var_;
Expr* val_;
};

class TORCH_API Cond : public StmtNode<Cond> {
@ -424,7 +423,7 @@ class TORCH_API Cond : public StmtNode<Cond> {
return new Cond(condition.node(), true_stmt, false_stmt);
}

const Expr* condition() const {
Expr* condition() const {
return condition_;
}

@ -436,7 +435,7 @@ class TORCH_API Cond : public StmtNode<Cond> {
return false_stmt_;
}

Cond(const Expr* condition, Stmt* true_stmt, Stmt* false_stmt)
Cond(Expr* condition, Stmt* true_stmt, Stmt* false_stmt)
: condition_(condition) {
if (true_stmt) {
Block* b = dynamic_cast<Block*>(true_stmt);
@ -465,7 +464,7 @@ class TORCH_API Cond : public StmtNode<Cond> {
}

private:
const Expr* condition_;
Expr* condition_;
Block* true_stmt_ = nullptr;
Block* false_stmt_ = nullptr;
};
@ -586,12 +585,11 @@ class TORCH_API LoopOptions {
!is_parallel_;
}

void set_buffer_mapping(
const std::unordered_map<std::string, const Buf*>& map) {
void set_buffer_mapping(const std::unordered_map<std::string, Buf*>& map) {
map_input_to_tensor_bufs_ = map;
}

std::unordered_map<std::string, const Buf*> get_buffer_mapping() const {
std::unordered_map<std::string, Buf*> get_buffer_mapping() const {
return map_input_to_tensor_bufs_;
}

@ -599,18 +597,18 @@ class TORCH_API LoopOptions {
int gpu_block_index_{IDX_UNSET};
int gpu_thread_index_{IDX_UNSET};
bool is_parallel_{false};
std::unordered_map<std::string, const Buf*> map_input_to_tensor_bufs_;
std::unordered_map<std::string, Buf*> map_input_to_tensor_bufs_;
};

class TORCH_API For : public StmtNode<For> {
public:
const Var* var() const {
Var* var() const {
return var_;
}
const Expr* start() const {
Expr* start() const {
return start_;
}
const Expr* stop() const {
Expr* stop() const {
return stop_;
}
Block* body() const {
@ -641,7 +639,7 @@ class TORCH_API For : public StmtNode<For> {
return loop_options_;
}

For(const Var* var, const Expr* start, const Expr* stop, Stmt* body)
For(Var* var, Expr* start, Expr* stop, Stmt* body)
: var_(var), start_(start), stop_(stop) {
Block* b = dynamic_cast<Block*>(body);
if (!b) {
@ -651,11 +649,7 @@ class TORCH_API For : public StmtNode<For> {
set_parent(body_, this);
}

For(const Var* var,
const Expr* start,
const Expr* stop,
Stmt* body,
LoopOptions loop_options)
For(Var* var, Expr* start, Expr* stop, Stmt* body, LoopOptions loop_options)
: var_(var),
start_(start),
stop_(stop),
@ -694,7 +688,7 @@ class TORCH_API For : public StmtNode<For> {
return loop_options_.is_parallel();
}

void set_buffer_map(const std::unordered_map<std::string, const Buf*>& map) {
void set_buffer_map(const std::unordered_map<std::string, Buf*>& map) {
loop_options_.set_buffer_mapping(map);
}

@ -719,25 +713,25 @@ class TORCH_API For : public StmtNode<For> {
return body_;
}

const Expr* setStart(const Expr* start) {
Expr* setStart(Expr* start) {
start_ = start;
return start_;
}

const Expr* setStop(const Expr* stop) {
Expr* setStop(Expr* stop) {
stop_ = stop;
return stop_;
}

const Var* setVar(const Var* var) {
Var* setVar(Var* var) {
var_ = var;
return var_;
}

private:
const Var* var_;
const Expr* start_;
const Expr* stop_;
Var* var_;
Expr* start_;
Expr* stop_;
Block* body_;
LoopOptions loop_options_;
};
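For orientation, a minimal sketch (hypothetical names, not part of this commit) of how the de-constified For and Cond constructors above are exercised; it assumes the usual tensorexpr headers and an active KernelScope:

KernelScope kernel_scope;
Var* i = new Var("i", kInt);                  // mutable Var*, no more const
std::vector<Expr*> dims = {new IntImm(10)};
Buf* x = new Buf("x", dims, kInt);
Stmt* store = new Store(x, {i}, new IntImm(0));
ExprHandle pred = CompareSelect::make(VarHandle(i), IntImm::make(5), kLT);
Cond* guard = new Cond(pred.node(), store, nullptr);  // wraps store in a Block
For* loop = new For(i, new IntImm(0), new IntImm(10), guard);
loop->setStop(new IntImm(5));                 // setStop now returns mutable Expr*

Because every accessor now hands back a mutable pointer, later passes can rewrite nodes like this in place instead of cloning them.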
@ -749,34 +743,34 @@ class TORCH_API For : public StmtNode<For> {
class TORCH_API AtomicAdd : public StmtNode<AtomicAdd> {
public:
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
AtomicAdd(const Buf* buf, std::vector<const Expr*> indices, const Expr* value)
AtomicAdd(Buf* buf, std::vector<Expr*> indices, Expr* value)
: buf_(buf), indices_(std::move(indices)), value_(value) {}

const Var* base_handle() const {
Var* base_handle() const {
return buf_->base_handle();
}

const Buf* buf() const {
Buf* buf() const {
return buf_;
}

const Expr* flat_index() const {
Expr* flat_index() const {
TORCH_CHECK(indices_.size() == 1, "Indices haven't been flattened.");
return indices_[0];
}

const Expr* value() const {
Expr* value() const {
return value_;
}

const std::vector<const Expr*>& indices() const {
const std::vector<Expr*>& indices() const {
return indices_;
}

private:
const Buf* buf_;
std::vector<const Expr*> indices_;
const Expr* value_;
Buf* buf_;
std::vector<Expr*> indices_;
Expr* value_;
};

class TORCH_API SyncThreads : public StmtNode<SyncThreads> {
@ -811,7 +805,7 @@ class TORCH_API ExternalCall : public StmtNode<ExternalCall> {
const std::vector<BufHandle>& buf_args,
const std::vector<ExprHandle>& args);

const Buf* buf() const {
Buf* buf() const {
return buf_;
}

@ -819,30 +813,30 @@ class TORCH_API ExternalCall : public StmtNode<ExternalCall> {
return func_name_;
}

std::vector<const Buf*> buf_args() const {
std::vector<Buf*> buf_args() const {
return buf_args_;
}

std::vector<const Expr*> args() const {
std::vector<Expr*> args() const {
return args_;
}

// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
ExternalCall(
const Buf* buf,
Buf* buf,
std::string func_name,
std::vector<const Buf*> buf_args,
std::vector<const Expr*> args)
std::vector<Buf*> buf_args,
std::vector<Expr*> args)
: buf_(buf),
func_name_(std::move(func_name)),
buf_args_(std::move(buf_args)),
args_(std::move(args)) {}

private:
const Buf* buf_;
Buf* buf_;
std::string func_name_;
std::vector<const Buf*> buf_args_;
std::vector<const Expr*> args_;
std::vector<Buf*> buf_args_;
std::vector<Expr*> args_;
};

} // namespace tensorexpr
|
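A similarly hedged sketch for AtomicAdd (hypothetical buffer and index names): the constructor now takes mutable Buf*/Expr* operands, matching the header above.

KernelScope kernel_scope;
Var* idx = new Var("idx", kInt);
std::vector<Expr*> dims = {new IntImm(8)};
Buf* acc = new Buf("acc", dims, kFloat);
Stmt* add = new AtomicAdd(acc, {idx}, new FloatImm(1.0f));  // acc[idx] += 1.0f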
@ -10,11 +10,11 @@ namespace jit {
namespace tensorexpr {

Stmt* Tensor::constructStmt(
const std::vector<const Var*>& args,
const Expr* body,
const std::vector<const Expr*>& reduce_dims,
const std::vector<const Var*>& reduce_args) const {
std::vector<const Expr*> indices(args.begin(), args.end());
const std::vector<Var*>& args,
Expr* body,
const std::vector<Expr*>& reduce_dims,
const std::vector<Var*>& reduce_args) const {
std::vector<Expr*> indices(args.begin(), args.end());

Stmt* s = new Store(buf_, indices, body);

@ -25,10 +25,10 @@ Stmt* Tensor::constructStmt(
return s;
}

const Expr* init_expr = buf()->initializer();
Expr* init_expr = buf()->initializer();

if (reduce_ndim > 0) {
for (const auto i : c10::irange(reduce_ndim)) {
for (auto i : c10::irange(reduce_ndim)) {
// Going in reverse order: from innermost loop to the outermost
size_t dim_index = reduce_ndim - i - 1;
s = new For(
@ -40,7 +40,7 @@ Stmt* Tensor::constructStmt(
}
}

for (const auto i : c10::irange(ndim)) {
for (auto i : c10::irange(ndim)) {
// Going in reverse order: from innermost loop to the outermost
size_t dim_index = ndim - i - 1;
s = new For(args[dim_index], new IntImm(0), buf()->dim(dim_index), s);
@ -52,11 +52,11 @@ Tensor* Compute(
const std::string& name,
const std::vector<DimArg>& dim_args,
const std::function<ExprHandle(const std::vector<VarHandle>&)>& body_func) {
std::vector<const Expr*> dims;
std::vector<const Var*> args;
std::vector<Expr*> dims;
std::vector<Var*> args;
unpack_dim_args(dim_args, &dims, &args);
const Expr* body = body_func(VarVectorToVarHandleVector(args)).node();
const Buf* buf = new Buf(name, dims, body->dtype());
Expr* body = body_func(VarVectorToVarHandleVector(args)).node();
Buf* buf = new Buf(name, dims, body->dtype());
return new Tensor(buf, args, body);
}

@ -68,11 +68,11 @@ Tensor* Compute(
throw malformed_input("mismatch between body and arg size (1)");
}

std::vector<const Expr*> dims;
std::vector<const Var*> args;
std::vector<Expr*> dims;
std::vector<Var*> args;
unpack_dim_args(dim_args, &dims, &args);
const Expr* body = body_func(VarHandle(args[0])).node();
const Buf* buf = new Buf(name, dims, body->dtype());
Expr* body = body_func(VarHandle(args[0])).node();
Buf* buf = new Buf(name, dims, body->dtype());
return new Tensor(buf, args, body);
}

@ -84,11 +84,11 @@ Tensor* Compute(
if (dim_args.size() != 2) {
throw malformed_input("mismatch between body and arg size (2)");
}
std::vector<const Expr*> dims;
std::vector<const Var*> args;
std::vector<Expr*> dims;
std::vector<Var*> args;
unpack_dim_args(dim_args, &dims, &args);
const Expr* body = body_func(VarHandle(args[0]), VarHandle(args[1])).node();
const Buf* buf = new Buf(name, dims, body->dtype());
Expr* body = body_func(VarHandle(args[0]), VarHandle(args[1])).node();
Buf* buf = new Buf(name, dims, body->dtype());
return new Tensor(buf, args, body);
}

@ -101,13 +101,13 @@ Tensor* Compute(
if (dim_args.size() != 3) {
throw malformed_input("mismatch between body and arg size (3)");
}
std::vector<const Expr*> dims;
std::vector<const Var*> args;
std::vector<Expr*> dims;
std::vector<Var*> args;
unpack_dim_args(dim_args, &dims, &args);
const Expr* body =
Expr* body =
body_func(VarHandle(args[0]), VarHandle(args[1]), VarHandle(args[2]))
.node();
const Buf* buf = new Buf(name, dims, body->dtype());
Buf* buf = new Buf(name, dims, body->dtype());
return new Tensor(buf, args, body);
}

@ -122,16 +122,16 @@ Tensor* Compute(
if (dim_args.size() != 4) {
throw malformed_input("mismatch between body and arg size (4)");
}
std::vector<const Expr*> dims;
std::vector<const Var*> args;
std::vector<Expr*> dims;
std::vector<Var*> args;
unpack_dim_args(dim_args, &dims, &args);
const Expr* body = body_func(
VarHandle(args[0]),
VarHandle(args[1]),
VarHandle(args[2]),
VarHandle(args[3]))
.node();
const Buf* buf = new Buf(name, dims, body->dtype());
Expr* body = body_func(
VarHandle(args[0]),
VarHandle(args[1]),
VarHandle(args[2]),
VarHandle(args[3]))
.node();
Buf* buf = new Buf(name, dims, body->dtype());
return new Tensor(buf, args, body);
}

|
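Note that the public Compute overloads above still take handles (DimArg, VarHandle, ExprHandle), so caller code is untouched by the const removal; only the raw Var*/Expr*/Buf* plumbing underneath changed. A minimal sketch (hypothetical names):

KernelScope kernel_scope;
Tensor* square = Compute(
    "square", {{64, "n"}}, [](const VarHandle& n) { return n * n; });
Buf* out = square->buf();  // buf() now yields a mutable Buf* (see tensor.h below)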
@ -15,25 +15,24 @@ namespace tensorexpr {
class TORCH_API Tensor : KernelScopedObject {
public:
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
Tensor(const Buf* buf, const std::vector<const Var*>& args, const Expr* body)
: buf_(buf) {
Tensor(Buf* buf, const std::vector<Var*>& args, Expr* body) : buf_(buf) {
stmt_ = constructStmt(args, body, {}, {});
}

// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
Tensor(
const Buf* buf,
const std::vector<const Var*>& args,
const std::vector<const Expr*>& reduce_dims,
const std::vector<const Var*>& reduce_args,
const Expr* body)
Buf* buf,
const std::vector<Var*>& args,
const std::vector<Expr*>& reduce_dims,
const std::vector<Var*>& reduce_args,
Expr* body)
: buf_(buf) {
stmt_ = constructStmt(args, body, reduce_dims, reduce_args);
}

Tensor(const Buf* buf, Stmt* stmt) : buf_(buf), stmt_(stmt) {}
Tensor(Buf* buf, Stmt* stmt) : buf_(buf), stmt_(stmt) {}

const Buf* buf() const {
Buf* buf() const {
return buf_;
}

@ -48,12 +47,12 @@ class TORCH_API Tensor : KernelScopedObject {

private:
Stmt* constructStmt(
const std::vector<const Var*>& args,
const Expr* body,
const std::vector<const Expr*>& reduce_dims,
const std::vector<const Var*>& reduce_args) const;
const std::vector<Var*>& args,
Expr* body,
const std::vector<Expr*>& reduce_dims,
const std::vector<Var*>& reduce_args) const;

const Buf* buf_;
Buf* buf_;
Stmt* stmt_;
};

@ -96,7 +95,7 @@ class Placeholder {
explicit Placeholder(const std::vector<ExprHandle>& dims)
: Placeholder(BufHandle("_", dims, kFloat)) {}

const Buf* data() const {
Buf* data() const {
return data_;
}
BufHandle handle() const {
@ -108,10 +107,10 @@ class Placeholder {
int ndim() const {
return data_->ndim();
}
const Expr* dim(int index) const {
Expr* dim(int index) const {
return data_->dim(index);
}
std::vector<const Expr*> dims() const {
std::vector<Expr*> dims() const {
return data_->dims();
}

@ -130,8 +129,8 @@ class Placeholder {
}

private:
const Buf* data_;
std::vector<const Expr*> strides_;
Buf* data_;
std::vector<Expr*> strides_;
};

TORCH_API Tensor* Compute(
@ -164,12 +163,12 @@ TORCH_API Tensor* Compute(

inline void unpack_dim_args(
const std::vector<DimArg>& dim_args,
std::vector<const Expr*>* dims,
std::vector<const Var*>* vars) {
std::vector<Expr*>* dims,
std::vector<Var*>* vars) {
dims->clear();
vars->clear();
for (const DimArg& dim_arg : dim_args) {
const Expr* expr = dim_arg.dim().node();
Expr* expr = dim_arg.dim().node();
dims->push_back(expr);
vars->push_back(new Var(
dim_arg.name_hint(),
@ -187,22 +186,22 @@ Tensor* Reduce(
const BodyFunc& body_func,
const std::vector<DimArg>& reduce_args) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::vector<const Expr*> dims;
std::vector<Expr*> dims;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::vector<const Var*> vars;
std::vector<Var*> vars;
unpack_dim_args(dim_args, &dims, &vars);

// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::vector<const Expr*> reduce_dims;
std::vector<Expr*> reduce_dims;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::vector<const Var*> reduce_vars;
std::vector<Var*> reduce_vars;
unpack_dim_args(reduce_args, &reduce_dims, &reduce_vars);

// If reduce_vars is empty, then it's not a reduction, but rather a simple
// copy
if (reduce_vars.empty()) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
const Expr* body =
Expr* body =
Reducer::getReduceBody(body_func, VarVectorToVarHandleVector(vars))
.node();
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
@ -211,22 +210,21 @@ Tensor* Reduce(
}

// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::vector<const Var*> all_vars;
std::vector<Var*> all_vars;
all_vars.insert(all_vars.end(), vars.begin(), vars.end());
all_vars.insert(all_vars.end(), reduce_vars.begin(), reduce_vars.end());

ExprHandle body =
Reducer::getReduceBody(body_func, VarVectorToVarHandleVector(all_vars));
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::vector<const Expr*> output_args(vars.begin(), vars.end());
std::vector<Expr*> output_args(vars.begin(), vars.end());
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
const Expr* init_expr = new Cast(
Expr* init_expr = new Cast(
body.dtype(), init_func(VarVectorToVarHandleVector(vars)).node());
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
Buf* func_result = new Buf(func_name, dims, body.dtype(), init_expr);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
const ReduceOp* reduce_op =
reducer(func_result, body, output_args, reduce_vars);
ReduceOp* reduce_op = reducer(func_result, body, output_args, reduce_vars);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
Tensor* t =
new Tensor(func_result, vars, reduce_dims, reduce_vars, reduce_op);
|
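The Reduce template above follows the same pattern. A hedged usage sketch (hypothetical names; Sum() is the stock reducer from reduction.h, and this uses the overload without an explicit init function):

KernelScope kernel_scope;
Placeholder a("a", kFloat, {16, 32});
Tensor* rowsum = Reduce(
    "rowsum",
    {{16, "m"}},
    Sum(),
    [&](const VarHandle& m, const VarHandle& n) { return a.load(m, n); },
    {{32, "n"}});

Internally this now builds std::vector<Var*>/std::vector<Expr*> and a mutable ReduceOp*, as shown in the hunks above.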
@ -325,7 +325,7 @@ void initTensorExprBindings(PyObject* module) {
.def(py::init([](const std::vector<Stmt*>& stmts) {
return tensorexpr::Block::make(stmts);
}))
.def("__str__", [](const Stmt& self) {
.def("__str__", [](Stmt& self) {
std::stringstream ss;
ss << self;
return ss.str();
@ -343,7 +343,7 @@ void initTensorExprBindings(PyObject* module) {
py::class_<For, Stmt, std::unique_ptr<For, py::nodelete>>(te, "For")
.def(
"index_var",
[](const For& self) { return VarHandle(self.var()); },
[](For& self) { return VarHandle(self.var()); },
py::return_value_policy::reference)
.def("body", &For::body, py::return_value_policy::reference)
.def("set_parallel", &For::set_parallel)
@ -393,8 +393,8 @@ void initTensorExprBindings(PyObject* module) {
py::class_<LoopNest>(te, "LoopNest")
.def(py::init<const std::vector<Tensor*>&>())
.def(py::init([](Stmt* s, const std::vector<BufHandle>& bufs) {
std::unordered_set<const Buf*> buf_nodes;
for (const auto& buf : bufs) {
std::unordered_set<Buf*> buf_nodes;
for (auto& buf : bufs) {
buf_nodes.insert(buf.node());
}
return std::make_unique<LoopNest>(s, buf_nodes);
@ -427,7 +427,7 @@ void initTensorExprBindings(PyObject* module) {
py::return_value_policy::reference)
.def(
"get_enclosing_loopnest",
[](const LoopNest& self, const Stmt* s) {
[](const LoopNest& self, Stmt* s) {
return self.getEnclosingLoopNest(s);
},
py::return_value_policy::reference)
@ -451,9 +451,7 @@ void initTensorExprBindings(PyObject* module) {
py::return_value_policy::reference)
.def(
"get_parent_loop",
[](const LoopNest& self, const Stmt* s) {
return self.getParentLoop(s);
},
[](const LoopNest& self, Stmt* s) { return self.getParentLoop(s); },
py::return_value_policy::reference)
.def_static(
"get_loop_stmts_in_loopnest",
@ -566,7 +564,7 @@ void initTensorExprBindings(PyObject* module) {
[](const BufHandle& producer,
const std::string& name,
Stmt* consumer) {
std::pair<const Buf*, Stmt*> ret =
std::pair<Buf*, Stmt*> ret =
LoopNest::cacheAccesses(producer.node(), name, consumer);
return std::make_pair(BufHandle(ret.first), ret.second);
},
@ -679,7 +677,7 @@ void initTensorExprBindings(PyObject* module) {
std::unordered_map<std::string, NNCLoweringFunction>
custom_lowerings_str) {
std::unordered_map<c10::Symbol, NNCLoweringFunction> custom_lowerings;
for (const auto& kv : custom_lowerings_str) {
for (auto& kv : custom_lowerings_str) {
custom_lowerings[c10::Symbol::fromQualString(kv.first)] = kv.second;
}
return std::make_unique<TensorExprKernel>(g, custom_lowerings);
|
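For reference, the C++ equivalent of the LoopNest rebinding above (hypothetical names: block is a previously built Stmt*, out_buf a previously built Buf*): a LoopNest is now constructed from a statement plus a set of mutable output buffers.

std::unordered_set<Buf*> outputs = {out_buf};
LoopNest nest(block, outputs);
Stmt* root = nest.root_stmt();  // root statement of the nest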
@ -8,7 +8,7 @@ namespace torch {
namespace jit {
namespace tensorexpr {

const std::string& UniqueNameManager::get_unique_name(const Var* v) {
const std::string& UniqueNameManager::get_unique_name(Var* v) {
// Find if we have already encountered this variable.
auto iter = unique_name_mapping_.find(v);
if (iter != unique_name_mapping_.end()) {
|
@ -13,7 +13,7 @@ namespace tensorexpr {
class VarHandle;
class Var;

using VarNameMap = std::unordered_map<const Var*, std::string>;
using VarNameMap = std::unordered_map<Var*, std::string>;

// A manager to get unique names from vars.
// It starts with the name hints of the var and append "_" + $counter until it
@ -23,7 +23,7 @@ class TORCH_API UniqueNameManager {
public:
const std::string& get_unique_name(const VarHandle& v);

const std::string& get_unique_name(const Var* v);
const std::string& get_unique_name(Var* v);

private:
friend class ScopedVarName;
|
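A small sketch of the manager in action (hypothetical names; assumes UniqueNameManager is default-constructed, as when it is embedded in the IR printer): two Vars sharing the name hint "i" get distinct names on first lookup, per the comment above.

UniqueNameManager names;
Var* a = new Var("i", kInt);
Var* b = new Var("i", kInt);
names.get_unique_name(a);  // e.g. "i"
names.get_unique_name(b);  // e.g. "i_1" ("_" + counter appended)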
@ -13,15 +13,15 @@ namespace torch {
namespace jit {
namespace tensorexpr {

using VarMapping = std::vector<std::pair<const Var*, const Expr*>>;
using VarMapping = std::vector<std::pair<Var*, Expr*>>;

class VarSubMutator : public IRMutator {
public:
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
VarSubMutator(const VarMapping& var_mapping) {
for (const auto& entry : var_mapping) {
const Var* key_var = entry.first;
const Expr* value = entry.second;
for (auto& entry : var_mapping) {
Var* key_var = entry.first;
Expr* value = entry.second;
if (!key_var) {
throw malformed_input("missing key in VarSubMutator");
}
@ -29,7 +29,7 @@ class VarSubMutator : public IRMutator {
}
}

const Expr* mutate(const Var* var) override {
Expr* mutate(Var* var) override {
auto iter = var_mapping_.find(var);
if (iter == var_mapping_.end()) {
return var;
@ -37,14 +37,14 @@ class VarSubMutator : public IRMutator {
return iter->second;
}

const Expr* mutate(const ReduceOp* var) override {
Expr* mutate(ReduceOp* var) override {
auto body = var->body()->accept_mutator(this);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::vector<const Var*> new_inner;
std::vector<Var*> new_inner;

for (auto* v : var->reduce_args()) {
const Expr* e = v->accept_mutator(this);
if (const Var* new_var = dynamic_cast<const Var*>(e)) {
Expr* e = v->accept_mutator(this);
if (Var* new_var = dynamic_cast<Var*>(e)) {
new_inner.push_back(new_var);
} else {
VarFinder varFinder;
@ -58,7 +58,7 @@ class VarSubMutator : public IRMutator {
}

private:
std::unordered_map<const Var*, const Expr*> var_mapping_;
std::unordered_map<Var*, Expr*> var_mapping_;
};

} // namespace tensorexpr
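Finally, a hedged sketch of VarMapping in use (hypothetical names): Substitute is the helper that drives VarSubMutator, and after this change both the mapping and the returned expression are mutable.

KernelScope kernel_scope;
Var* x = new Var("x", kInt);
Var* y = new Var("y", kInt);
Expr* e = new Add(x, y);
Expr* rewritten = Substitute(e, {{y, new Add(x, x)}});  // x + (x + x)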