[TensorExpr] Add core classes for representing expressions and statements. (#33218)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/33218

Test Plan: Imported from OSS

Differential Revision: D19848378

Pulled By: ZolotukhinM

fbshipit-source-id: 48399f8651324d5ad0607e08573d5d7b2026bb23
Mikhail Zolotukhin
2020-02-21 13:06:13 -08:00
committed by Facebook Github Bot
parent 1a4f997178
commit 49af9425a7
13 changed files with 1300 additions and 2 deletions

View File

@@ -456,7 +456,10 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
${TORCH_SRC_DIR}/csrc/jit/vararg_functions.cpp
${TORCH_SRC_DIR}/csrc/jit/tensorexpr/mem_arena.cpp
${TORCH_SRC_DIR}/csrc/jit/tensorexpr/expr.cpp
${TORCH_SRC_DIR}/csrc/jit/tensorexpr/ir.cpp
${TORCH_SRC_DIR}/csrc/jit/tensorexpr/types.cpp
${TORCH_SRC_DIR}/csrc/jit/tensorexpr/unique_name_manager.cpp
)
if (NOT INTERN_DISABLE_MOBILE_INTERP)

View File

@@ -0,0 +1,58 @@
#include "test/cpp/tensorexpr/test_base.h"
#include "test/cpp/tensorexpr/test_utils.h"
#include "torch/csrc/jit/tensorexpr/buffer.h"
#include "torch/csrc/jit/tensorexpr/ir.h"
#include <cmath>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
namespace torch {
namespace jit {
using namespace torch::jit::tensorexpr;
void testExprVectorAdd01() {
KernelScope kernel_scope;
const int kVectorSize = 8;
const int kVectorCount = 128;
const int kTotalSize = kVectorSize * kVectorCount;
Buffer a_buf(Var("A", kHandle), kFloat32, {Expr(kTotalSize)});
Buffer b_buf(Var("B", kHandle), kFloat32, {Expr(kTotalSize)});
Buffer c_buf(Var("C", kHandle), kFloat32, {Expr(kTotalSize)});
/*
Build the following:
for (int index = 0; index < kVectorCount; index++) {
store(c_buf, ramp(index * 8, 1, 8),
load(a_buf, ramp(index * 8, 1, 8)) +
load(b_buf, ramp(index * 8, 1, 8)))
}
*/
Var index = Var("index", kInt32);
Expr load_a = Load::make(
a_buf,
Ramp::make(index * kVectorSize, 1, kVectorSize),
Broadcast::make(1, kVectorSize));
Expr load_b = Load::make(
b_buf,
Ramp::make(index * kVectorSize, 1, kVectorSize),
Broadcast::make(1, kVectorSize));
Expr value = load_a + load_b;
Stmt store_c = Store::make(
c_buf,
Ramp::make(index * kVectorSize, 1, kVectorSize),
value,
Broadcast::make(1, kVectorSize));
Stmt stmt = For::make(index, 0, kVectorCount, store_c);
EXPECT_EQ(load_a.dtype(), Dtype(kFloat32, kVectorSize));
EXPECT_EQ(load_b.dtype(), Dtype(kFloat32, kVectorSize));
EXPECT_EQ(value.dtype(), Dtype(kFloat32, kVectorSize));
}
} // namespace jit
} // namespace torch
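
For comparison, a minimal unvectorized sketch of the same kernel, written against the API added in this diff (illustrative only; the function name is hypothetical and it assumes the same namespace and test headers as the test above; a mask of Expr(1) means the access is unconditional):

void sketchExprScalarAdd() {
  KernelScope kernel_scope;
  const int kTotalSize = 1024;
  Buffer a_buf(Var("A", kHandle), kFloat32, {Expr(kTotalSize)});
  Buffer b_buf(Var("B", kHandle), kFloat32, {Expr(kTotalSize)});
  Buffer c_buf(Var("C", kHandle), kFloat32, {Expr(kTotalSize)});
  Var i("i", kInt32);
  // c[i] = a[i] + b[i], one element per iteration.
  Expr load_a = Load::make(a_buf, i, Expr(1));
  Expr load_b = Load::make(b_buf, i, Expr(1));
  Stmt store_c = Store::make(c_buf, i, load_a + load_b, Expr(1));
  Stmt stmt = For::make(i, 0, kTotalSize, store_c);
  // The scalar load keeps a single lane.
  EXPECT_EQ(load_a.dtype(), Dtype(kFloat32, 1));
}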

View File

@@ -1,6 +1,5 @@
#include "test/cpp/tensorexpr/test_base.h"
#include "torch/csrc/jit/tensorexpr/mem_arena.h"
#include "torch/csrc/jit/tensorexpr/types.h"
#include "torch/csrc/jit/tensorexpr/ir.h"
namespace torch {
namespace jit {

View File

@@ -9,6 +9,7 @@
namespace torch {
namespace jit {
#define TH_FORALL_TESTS(_) \
_(ExprVectorAdd01) \
_(TypeTest01) \
#define TH_FORALL_TESTS_CUDA(_) \

View File

@@ -190,13 +190,19 @@ libtorch_sources = [
"torch/csrc/jit/mobile/register_mobile_ops.cpp",
"torch/csrc/jit/mobile/interpreter.cpp",
"torch/csrc/jit/mobile/type_parser.cpp",
"torch/csrc/jit/tensorexpr/expr.cpp",
"torch/csrc/jit/tensorexpr/ir.cpp",
"torch/csrc/jit/tensorexpr/mem_arena.cpp",
"torch/csrc/jit/tensorexpr/types.cpp",
"torch/csrc/jit/tensorexpr/unique_name_manager.cpp",
"torch/csrc/utils/byte_order.cpp",
"torch/csrc/utils/tensor_flatten.cpp",
"torch/csrc/utils/variadic.cpp",
"torch/csrc/jit/tensorexpr/expr.cpp",
"torch/csrc/jit/tensorexpr/ir.cpp",
"torch/csrc/jit/tensorexpr/mem_arena.cpp",
"torch/csrc/jit/tensorexpr/types.cpp",
"torch/csrc/jit/tensorexpr/unique_name_manager.cpp",
]
libtorch_cuda_sources = [

View File

View File

@@ -0,0 +1,106 @@
#pragma once
#include "torch/csrc/jit/tensorexpr/ir.h"
namespace torch {
namespace jit {
namespace tensorexpr {
class Buffer {
public:
Buffer(const Var& data, const Dtype& dtype, const std::vector<Expr>& dims)
: data_(data), dtype_(dtype), dims_(dims), strides_(dims.size()) {
CHECK_EQ(data.dtype(), kHandle);
for (int i = ndim() - 1; i >= 0; i--) {
if (i == ndim() - 1) {
strides_[i] = 1;
} else {
strides_[i] = strides_[i + 1] * dim(i + 1);
}
}
}
Buffer(
const std::string& name,
const Dtype& dtype,
const std::vector<Expr>& dims)
: Buffer(Var(name, kHandle), dtype, dims) {}
const Var& data() const {
return data_;
}
const Dtype& dtype() const {
return dtype_;
}
int ndim() const {
return dims_.size();
}
const Expr& dim(int index) const {
return dims_[index];
}
// TODO: consider deferring the storage flattening to a later stage.
template <typename... Args>
Expr operator()(Args... args) const {
Expr index = Index(std::forward<Args>(args)...);
return LoadValue(index);
}
template <typename T>
Expr call(const std::vector<T>& args) const {
std::vector<Expr> params(args.begin(), args.end());
Expr index = Index(params);
return LoadValue(index);
}
private:
Expr Index(const Expr& x) const {
CHECK(ndim() == 1);
return x;
}
Expr Index(const Expr& x, const Expr& y) const {
CHECK(ndim() == 2);
return x * strides_[0] + y;
}
Expr Index(const Expr& x, const Expr& y, const Expr& z) const {
CHECK(ndim() == 3);
return x * strides_[0] + y * strides_[1] + z;
}
Expr Index(const Expr& x, const Expr& y, const Expr& z, const Expr& w) const {
CHECK(ndim() == 4);
return x * strides_[0] + y * strides_[1] + z * strides_[2] + w;
}
Expr Index(const std::vector<Expr>& indices) const {
CHECK(ndim() == (int)indices.size());
Expr total_index;
for (size_t i = 0; i < indices.size(); i++) {
Expr index;
if (i == indices.size() - 1) {
index = indices[i];
} else {
index = indices[i] * strides_[i];
}
if (i == 0) {
total_index = index;
} else {
total_index = total_index + index;
}
}
return total_index;
}
Expr LoadValue(const Expr& index) const;
Var data_;
Dtype dtype_;
std::vector<Expr> dims_;
std::vector<Expr> strides_;
// TODO: add strides
};
inline Expr Buffer::LoadValue(const Expr& index) const {
return Load::make(*this, index, Expr(1));
}
} // namespace tensorexpr
} // namespace jit
} // namespace torch
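
To illustrate the flattening performed by Index(), a brief sketch (illustrative only, assuming the kFloat32/kInt32 constants used in the tests; the names are made up): strides are computed row-major, so the last dimension has stride 1.

KernelScope kernel_scope;
Buffer B("B", kFloat32, {Expr(16), Expr(32)});  // strides_: {Expr(32), Expr(1)}
Var i("i", kInt32);
Var j("j", kInt32);
// B(i, j) flattens to the index i * 32 + j and wraps it in an
// unconditional (mask == 1) load via LoadValue().
Expr elem = B(i, j);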

View File

@@ -0,0 +1,59 @@
#include "torch/csrc/jit/tensorexpr/expr.h"
#include "torch/csrc/jit/tensorexpr/ir.h"
namespace torch {
namespace jit {
namespace tensorexpr {
Expr Expr::operator+(const Expr& other) const {
return Add::make(*this, other);
}
Expr Expr::operator-(const Expr& other) const {
return Sub::make(*this, other);
}
Expr Expr::operator*(const Expr& other) const {
return Mul::make(*this, other);
}
Expr Expr::operator/(const Expr& other) const {
return Div::make(*this, other);
}
Expr Expr::operator==(const Expr& other) const {
return CompareSelect::make(*this, other, CompareSelectOperation::kEQ);
}
Expr Expr::operator!=(const Expr& other) const {
return CompareSelect::make(*this, other, CompareSelectOperation::kNE);
}
Expr Expr::operator>(const Expr& other) const {
return CompareSelect::make(*this, other, CompareSelectOperation::kGT);
}
Expr Expr::operator>=(const Expr& other) const {
return CompareSelect::make(*this, other, CompareSelectOperation::kGE);
}
Expr Expr::operator<(const Expr& other) const {
return CompareSelect::make(*this, other, CompareSelectOperation::kLT);
}
Expr Expr::operator<=(const Expr& other) const {
return CompareSelect::make(*this, other, CompareSelectOperation::kLE);
}
Expr::Expr(int v) : Expr(IntImm::make(v)) {}
Expr::Expr(float v) : Expr(FloatImm::make(v)) {}
Expr ifThenElse(const Expr& c, const Expr& t, const Expr& f) {
return IfThenElse::make(c, t, f);
}
} // namespace tensorexpr
} // namespace jit
} // namespace torch
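
A small sketch of how these overloads compose (illustrative only; it assumes ToDtype<int>() resolves to kInt32, which is what IfThenElse requires of its condition):

KernelScope kernel_scope;
Var x("x", kFloat32);
// x > 0.f yields a CompareSelect node of int32 dtype, which ifThenElse
// accepts as a condition; the result takes the dtype of the branches (float).
Expr relu = ifThenElse(x > Expr(0.0f), x, Expr(0.0f));
Expr y = relu * Expr(2.0f) + Expr(1.0f);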

View File

@@ -0,0 +1,143 @@
/**
* This file implements the core classes for Tensor Expressions.
*
* The structure of the expressions is inspired by Halide/TVM IR.
*/
#pragma once
#include "torch/csrc/jit/tensorexpr/types.h"
#include "torch/csrc/jit/tensorexpr/mem_arena.h"
namespace torch {
namespace jit {
namespace tensorexpr {
// The common base of all expression nodes.
class Expr;
class BaseExprNode : public KernelScopedObject {
public:
explicit BaseExprNode(Dtype dtype) : dtype_(dtype) {}
Dtype dtype() const {
return dtype_;
}
private:
Dtype dtype_;
};
// The common base of all statement nodes.
class BaseStmtNode : public KernelScopedObject {
public:
BaseStmtNode() {}
};
// A CRTP pattern to accept visitors for child classes
// and dispatch back to the concrete child type.
template <class Op, class Base = BaseExprNode>
class ExprNode : public Base {
public:
using ExprNodeBase = ExprNode<Op>;
// pass the constructor to the base class
using Base::Base;
};
template <class Op>
class StmtNode : public BaseStmtNode {
public:
using StmtNodeBase = StmtNode<Op>;
StmtNode() {}
};
// A wrapper object for the underlying ExprNode.
// Also serves as the primary way to build and operate on expressions.
class TORCH_API Expr {
public:
Expr() {}
explicit Expr(const BaseExprNode* node)
: base_expr_node_(const_cast<BaseExprNode*>(node)) {}
BaseExprNode* node() {
return base_expr_node_;
}
const BaseExprNode* node() const {
return base_expr_node_;
}
bool empty() const {
return base_expr_node_ == nullptr;
}
Expr(int v);
Expr(float v);
template <class Op>
Op* AsNode() {
return dynamic_cast<Op*>(this->node());
}
template <class Op>
const Op* AsNode() const {
return const_cast<Expr*>(this)->AsNode<Op>();
}
Dtype dtype() const {
return node()->dtype();
}
// Handling the math operators.
Expr operator+(const Expr& other) const;
Expr operator-(const Expr& other) const;
Expr operator*(const Expr& other) const;
Expr operator/(const Expr& other) const;
Expr operator==(const Expr& other) const;
Expr operator!=(const Expr& other) const;
Expr operator>(const Expr& other) const;
Expr operator>=(const Expr& other) const;
Expr operator<(const Expr& other) const;
Expr operator<=(const Expr& other) const;
private:
BaseExprNode* base_expr_node_ = nullptr;
};
class Stmt {
public:
Stmt() {}
explicit Stmt(const BaseStmtNode* node)
: base_stmt_node_(const_cast<BaseStmtNode*>(node)) {}
BaseStmtNode* node() {
return base_stmt_node_;
}
const BaseStmtNode* node() const {
return base_stmt_node_;
}
bool empty() const {
return node() == nullptr;
}
template <class Op>
const Op* AsNode() const {
return dynamic_cast<const Op*>(this->node());
}
private:
BaseStmtNode* base_stmt_node_ = nullptr;
};
inline bool same_node(const Expr& expr1, const Expr& expr2) {
return expr1.AsNode<BaseExprNode>() == expr2.AsNode<BaseExprNode>();
}
inline bool same_node(const Stmt& stmt1, const Stmt& stmt2) {
return stmt1.AsNode<BaseStmtNode>() == stmt2.AsNode<BaseStmtNode>();
}
TORCH_API Expr ifThenElse(const Expr& c, const Expr& t, const Expr& f);
} // namespace tensorexpr
} // namespace jit
} // namespace torch
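
A short sketch of the handle semantics above (illustrative only; it assumes the concrete node classes from ir.h are visible): copies of an Expr share the underlying node, AsNode is a checked downcast, and same_node compares node identity rather than structure.

KernelScope kernel_scope;
Expr a = Expr(1) + Expr(2);            // wraps an Add node
Expr b = a;                            // copies the handle, not the node
const Add* as_add = a.AsNode<Add>();   // non-null: the root node is an Add
const Mul* as_mul = a.AsNode<Mul>();   // nullptr: dynamic_cast fails
bool identical = same_node(a, b);                    // true
bool structural = same_node(a, Expr(1) + Expr(2));   // false: distinct nodes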

View File

@@ -0,0 +1,46 @@
#include "torch/csrc/jit/tensorexpr/ir.h"
#include "torch/csrc/jit/tensorexpr/buffer.h"
namespace torch {
namespace jit {
namespace tensorexpr {
static Dtype ChooseDtype(const Dtype& buffer_dtype, const Dtype& index_dtype) {
return Dtype(buffer_dtype, index_dtype.lanes());
}
Load::Load(const Buffer& buffer, const Expr& index, const Expr& mask)
: Load(
ChooseDtype(buffer.dtype(), index.dtype()),
buffer.data(),
index,
mask) {}
Load::Load(
Dtype dtype,
const Var& base_handle,
const Expr& index,
const Expr& mask)
: ExprNodeBase(dtype),
base_handle_(base_handle),
index_(index),
mask_(mask) {
CHECK_EQ(base_handle_.dtype(), kHandle);
CHECK_EQ(index.dtype().lanes(), mask.dtype().lanes());
CHECK_EQ(index.dtype().scalar_type(), kInt32);
}
Store::Store(
const Buffer& buffer,
const Expr& index,
const Expr& value,
const Expr& mask)
: Store(buffer.data(), index, value, mask) {
CHECK_EQ(buffer.dtype().scalar_type(), value.dtype().scalar_type());
}
} // namespace tensorexpr
} // namespace jit
} // namespace torch
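
To make the effect of ChooseDtype concrete, a brief sketch (illustrative only): the load keeps the buffer's scalar type but takes its lane count from the index expression.

KernelScope kernel_scope;
Buffer a("A", kFloat32, {Expr(64)});
Expr idx = Ramp::make(Expr(0), Expr(1), 4);          // int32 with 4 lanes
Expr ld = Load::make(a, idx, Broadcast::make(1, 4));
// ld.dtype() == Dtype(kFloat32, 4): float scalar type, lanes from the index.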

View File

@@ -0,0 +1,793 @@
#pragma once
#include <string>
#include <vector>
#include "torch/csrc/jit/tensorexpr/expr.h"
namespace torch {
namespace jit {
namespace tensorexpr {
enum IRNodeType {
kAdd,
kSub,
kMul,
kDiv,
kMod,
kMax,
kMin,
kCompareSelect,
};
enum CompareSelectOperation {
kEQ,
kGT,
kGE,
kLT,
kLE,
kNE,
};
class Buffer;
class Cast : public ExprNode<Cast> {
public:
const Expr& src_value() const {
return src_value_;
}
static Expr make(Dtype dtype, const Expr& src_value) {
return Expr(new Cast(dtype, src_value));
}
private:
Cast(Dtype dtype, const Expr& src_value)
: ExprNodeBase(dtype), src_value_(src_value) {}
Expr src_value_;
};
template <typename T>
Expr cast(const Expr& src_value) {
return Cast::make(Dtype(ToDtype<T>(), src_value.dtype().lanes()), src_value);
}
// Represents the expression node for binary operators.
// A CRTP pattern to share common code among the operators.
template <typename Op>
class BinaryOpNode : public ExprNode<Op> {
public:
const Expr& lhs() const {
return this->lhs_;
}
const Expr& rhs() const {
return this->rhs_;
}
IRNodeType expr_type() const {
return expr_type_;
}
static Expr make(const Expr& lhs, const Expr& rhs) {
return Expr(new Op(lhs, rhs));
}
protected:
BinaryOpNode(
const Expr& lhs_v,
const Expr& rhs_v,
IRNodeType expr_type,
ReturnType ret_type = ReturnType::knone)
: ExprNode<Op>(BinaryOpDtype(lhs_v.dtype(), rhs_v.dtype(), ret_type)),
lhs_(CastIfNeeded(lhs_v, ExprNode<Op>::dtype())),
rhs_(CastIfNeeded(rhs_v, ExprNode<Op>::dtype())),
expr_type_(expr_type) {}
private:
static Expr CastIfNeeded(const Expr& expr, Dtype dst_dtype) {
if (expr.dtype() == dst_dtype) {
return expr;
}
return Cast::make(dst_dtype, expr);
}
Expr lhs_;
Expr rhs_;
IRNodeType expr_type_;
};
class Add : public BinaryOpNode<Add> {
private:
Add(const Expr& lhs, const Expr& rhs)
: BinaryOpNode(lhs, rhs, IRNodeType::kAdd) {}
friend class BinaryOpNode<Add>;
};
class Sub : public BinaryOpNode<Sub> {
private:
Sub(const Expr& lhs, const Expr& rhs)
: BinaryOpNode(lhs, rhs, IRNodeType::kSub) {}
friend class BinaryOpNode<Sub>;
};
class Mul : public BinaryOpNode<Mul> {
private:
Mul(const Expr& lhs, const Expr& rhs)
: BinaryOpNode(lhs, rhs, IRNodeType::kMul) {}
friend class BinaryOpNode<Mul>;
};
class Div : public BinaryOpNode<Div> {
private:
Div(const Expr& lhs, const Expr& rhs)
: BinaryOpNode(lhs, rhs, IRNodeType::kDiv) {}
friend class BinaryOpNode<Div>;
};
class Mod : public BinaryOpNode<Mod> {
private:
Mod(const Expr& lhs, const Expr& rhs)
: BinaryOpNode(lhs, rhs, IRNodeType::kMod) {}
friend class BinaryOpNode<Mod>;
};
class Max : public BinaryOpNode<Max> {
private:
bool propagate_nans_;
Max(const Expr& lhs, const Expr& rhs, bool propagate_nans)
: BinaryOpNode(lhs, rhs, IRNodeType::kMax),
propagate_nans_(propagate_nans) {}
friend class BinaryOpNode<Max>;
public:
bool propagate_nans() const {
return propagate_nans_;
}
static Expr make(const Expr& lhs, const Expr& rhs) = delete;
static Expr make(const Expr& lhs, const Expr& rhs, bool propagate_nans) {
return Expr(new Max(lhs, rhs, propagate_nans));
}
};
class Min : public BinaryOpNode<Min> {
private:
bool propagate_nans_;
Min(const Expr& lhs, const Expr& rhs, bool propagate_nans)
: BinaryOpNode(lhs, rhs, IRNodeType::kMin),
propagate_nans_(propagate_nans) {}
friend class BinaryOpNode<Min>;
public:
bool propagate_nans() const {
return propagate_nans_;
}
static Expr make(const Expr& lhs, const Expr& rhs) = delete;
static Expr make(const Expr& lhs, const Expr& rhs, bool propagate_nans) {
return Expr(new Min(lhs, rhs, propagate_nans));
}
};
class CompareSelect : public ExprNode<CompareSelect> {
public:
CompareSelectOperation compare_select_op() const {
return compare_op_;
}
const Expr& lhs() const {
return this->lhs_;
}
const Expr& rhs() const {
return this->rhs_;
}
static Expr make(const Expr& lhs, const Expr& rhs) = delete;
static Expr make(
const Expr& lhs,
const Expr& rhs,
CompareSelectOperation cmp_op) {
return Expr(new CompareSelect(lhs, rhs, cmp_op));
}
private:
Expr lhs_;
Expr rhs_;
CompareSelectOperation compare_op_;
CompareSelect(const Expr& lhs, const Expr& rhs, CompareSelectOperation cmp_op)
: ExprNodeBase(ToDtype<int>()),
lhs_(lhs),
rhs_(rhs),
compare_op_(cmp_op) {}
};
// Encode an integer immediate value.
class IntImm : public ExprNode<IntImm> {
public:
int value() const {
return value_;
}
static Expr make(int value) {
return Expr(new IntImm(value));
}
private:
IntImm(int value) : ExprNodeBase(kInt32), value_(value) {}
int value_;
};
// Encode an fp32 immediate value.
class FloatImm : public ExprNode<FloatImm> {
public:
float value() const {
return value_;
}
static Expr make(float value) {
return Expr(new FloatImm(value));
}
private:
FloatImm(float value) : ExprNodeBase(kFloat32), value_(value) {}
float value_;
};
// The underlying representation node of a Variable.
// Currently, each Variable object represents a unique variable, even though the
// names might be the same. We should consider adding a unique_name as well.
class Variable : public ExprNode<Variable> {
public:
static Expr make(const std::string& name_hint, Dtype dtype) {
return Expr(new Variable(name_hint, dtype));
}
static Expr make(Dtype dtype) {
return Expr(new Variable("", dtype));
}
// TODO: unique_name
const std::string& name_hint() const {
return name_hint_;
}
private:
Variable(const std::string& name_hint, Dtype dtype)
: ExprNodeBase(dtype), name_hint_(name_hint) {}
std::string name_hint_;
};
// An expression to construct the underlying variable node.
// Note: do not store any info here, since it is often possible to slice this
// object. For example: Var x('x'); Expr x2 = x;
class Var : public Expr {
public:
Var() : Expr(nullptr) {}
explicit Var(Dtype dtype) : Expr(Variable::make(dtype)) {}
Var(const std::string& name_hint, Dtype dtype)
: Expr(Variable::make(name_hint, dtype)) {}
explicit Var(Variable* node) : Expr(node) {}
const Variable* node() const {
return static_cast<const Variable*>(Expr::node());
}
bool operator==(const Var& other) const {
return this->node() == other.node();
}
bool operator!=(const Var& other) const {
return !(*this == other);
}
const std::string& name_hint() const {
return this->node()->name_hint();
}
bool empty() const {
return (this->node() == nullptr);
}
};
// Bind the value to the var and evaluate the body.
class Let : public ExprNode<Let> {
public:
const Expr& var() const {
return var_;
}
const Expr& value() const {
return value_;
}
const Expr& body() const {
return body_;
}
static Expr make(const Expr& var, const Expr& value, const Expr& body) {
return Expr(new Let(var, value, body));
}
private:
Let(const Expr& var, const Expr& value, const Expr& body)
: ExprNodeBase(body.dtype()), var_(var), value_(value), body_(body) {}
Expr var_;
Expr value_;
Expr body_;
};
class Block : public StmtNode<Block> {
public:
static Stmt make(const std::vector<Stmt>& stmts) {
std::vector<Stmt> valid_stmts;
for (size_t i = 0; i < stmts.size(); i++) {
if (stmts[i].empty()) {
continue;
}
valid_stmts.push_back(stmts[i]);
}
if (valid_stmts.empty()) {
return Stmt();
}
return Stmt(new Block(valid_stmts));
}
int nstmts() const {
return stmts_.size();
}
const Stmt& stmt(int index) const {
return stmts_[index];
}
private:
explicit Block(const std::vector<Stmt>& stmts) : stmts_(stmts) {}
std::vector<Stmt> stmts_;
};
class LoopOptions {
public:
// GPU Block Index
bool is_gpu_block_index() const {
return gpu_block_index_ != -1;
}
int gpu_block_index() const {
return gpu_block_index_;
}
std::string gpu_block_index_str() const {
DCHECK(is_gpu_block_index());
static const char* kBlockIndexNames[] = {
"blockIdx.x",
"blockIdx.y",
"blockIdx.z",
"blockIdx.w",
};
DCHECK(gpu_block_index_ >= 0 && gpu_block_index_ < 4);
return kBlockIndexNames[gpu_block_index_];
}
void set_gpu_block_index(int index) {
if (is_gpu_thread_index()) {
throw std::runtime_error("Cannot set both gpu block and thread index");
}
if (is_gpu_block_index() && gpu_block_index() != index) {
throw std::runtime_error(
"Cannot set a previously set block index: " +
std::to_string(gpu_block_index()) + " vs " + std::to_string(index));
}
gpu_block_index_ = index;
}
// GPU Thread Index
bool is_gpu_thread_index() const {
return gpu_thread_index() != -1;
}
int gpu_thread_index() const {
return gpu_thread_index_;
}
std::string gpu_thread_index_str() const {
DCHECK(is_gpu_thread_index());
static const char* kThreadIndexNames[] = {
"threadIdx.x", "threadIdx.y", "threadIdx.z", "threadIdx.w"};
DCHECK(gpu_thread_index_ >= 0 && gpu_thread_index_ < 4);
return kThreadIndexNames[gpu_thread_index_];
}
void set_gpu_thread_index(int index) {
if (is_gpu_block_index()) {
throw std::runtime_error("Cannot set both gpu thread and block index");
}
if (is_gpu_thread_index() && gpu_thread_index() != index) {
throw std::runtime_error(
"Cannot set a previously set thread index: " +
std::to_string(gpu_thread_index()) + " vs " + std::to_string(index));
}
gpu_thread_index_ = index;
}
std::string ToString() const {
std::ostringstream oss;
if (is_gpu_block_index()) {
oss << gpu_block_index_str();
} else if (is_gpu_thread_index()) {
oss << gpu_thread_index_str();
}
return oss.str();
}
private:
int gpu_block_index_ = -1;
int gpu_thread_index_ = -1;
};
class For : public StmtNode<For> {
public:
const Var& var() const {
return var_;
}
const Expr& start() const {
return start_;
}
const Expr& stop() const {
return stop_;
}
const Stmt& body() const {
return body_;
}
static Stmt make(
const Var& var,
const Expr& start,
const Expr& stop,
const Stmt& body) {
if (body.empty()) {
return Stmt();
}
return Stmt(new For(var, start, stop, body));
}
static Stmt make(
const Var& var,
const Expr& start,
const Expr& stop,
const Stmt& body,
const LoopOptions& loop_options) {
if (body.empty()) {
return Stmt();
}
return Stmt(new For(var, start, stop, body, loop_options));
}
const LoopOptions loop_options() const {
return loop_options_;
}
private:
For(const Var& var, const Expr& start, const Expr& stop, const Stmt& body)
: var_(var), start_(start), stop_(stop), body_(body) {}
For(const Var& var,
const Expr& start,
const Expr& stop,
const Stmt& body,
const LoopOptions& loop_options)
: var_(var),
start_(start),
stop_(stop),
body_(body),
loop_options_(loop_options) {}
Var var_;
Expr start_;
Expr stop_;
Stmt body_;
LoopOptions loop_options_;
};
// Represents a ramp vector node:
// [base, base + 1 * stride, ... , base + (lanes - 1) * stride]
class Ramp : public ExprNode<Ramp> {
public:
const Expr& base() const {
return base_;
}
const Expr& stride() const {
return stride_;
}
static Expr make(const Expr& base, const Expr& stride, int lanes) {
return Expr(new Ramp(base, stride, lanes));
}
int lanes() const {
return lanes_;
}
private:
Ramp(const Expr& base, const Expr& stride, int lanes)
: ExprNodeBase(Dtype(base.dtype(), lanes)),
base_(base),
stride_(stride),
lanes_(lanes) {
CHECK_EQ(stride.dtype(), base.dtype());
}
Expr base_;
Expr stride_;
int lanes_;
};
class TORCH_API Load : public ExprNode<Load> {
public:
const Var& base_handle() const {
return base_handle_;
}
const Expr& index() const {
return index_;
}
const Expr& mask() const {
return mask_;
}
static Expr make(const Buffer& buffer, const Expr& index, const Expr& mask) {
return Expr(new Load(buffer, index, mask));
}
static Expr make(
Dtype dtype,
const Var& base_handle,
const Expr& index,
const Expr& mask) {
return Expr(new Load(dtype, base_handle, index, mask));
}
private:
Load(const Buffer& buffer, const Expr& index, const Expr& mask);
Load(
Dtype dtype,
const Var& base_handle,
const Expr& index,
const Expr& mask);
Var base_handle_;
Expr index_;
Expr mask_;
};
class TORCH_API Store : public StmtNode<Store> {
public:
const Var& base_handle() const {
return base_handle_;
}
const Expr& index() const {
return index_;
}
const Expr& value() const {
return value_;
}
const Expr& mask() const {
return mask_;
}
static Stmt make(
const Buffer& buffer,
const Expr& index,
const Expr& value,
const Expr& mask) {
return Stmt(new Store(buffer, index, value, mask));
}
static Stmt make(
const Var& base_handle,
const Expr& index,
const Expr& value,
const Expr& mask) {
return Stmt(new Store(base_handle, index, value, mask));
}
static Stmt make(
const Var& base_handle,
const Expr& index,
const Expr& value) {
return Stmt(new Store(base_handle, index, value, Expr(1)));
}
private:
// TODO: merge this with Load.
Store(
const Buffer& buffer,
const Expr& index,
const Expr& value,
const Expr& mask);
Store(
const Var& base_handle,
const Expr& index,
const Expr& value,
const Expr& mask)
: base_handle_(base_handle), index_(index), value_(value), mask_(mask) {
CHECK_EQ(base_handle_.dtype(), kHandle);
CHECK_EQ(index.dtype().lanes(), mask.dtype().lanes());
CHECK_EQ(index.dtype().lanes(), value.dtype().lanes());
CHECK_EQ(index.dtype().scalar_type(), kInt32);
}
Var base_handle_;
Expr index_;
Expr value_;
Expr mask_;
};
class Broadcast : public ExprNode<Broadcast> {
public:
const Expr& value() const {
return value_;
}
int lanes() const {
return lanes_;
}
static Expr make(const Expr& value, int lanes) {
return Expr(new Broadcast(value, lanes));
}
private:
Broadcast(const Expr& value, int lanes)
: ExprNodeBase(Dtype(value.dtype(), lanes)),
value_(value),
lanes_(lanes) {}
Expr value_;
int lanes_;
};
class IfThenElse : public ExprNode<IfThenElse> {
public:
const Expr& condition() const {
return condition_;
}
// Lazily evaluated only if condition is true
const Expr& true_value() const {
return true_;
}
// Lazily evaluated only if condition is false
const Expr& false_value() const {
return false_;
}
static Expr make(const Expr& c, const Expr& t, const Expr& f) {
return Expr(new IfThenElse(c, t, f));
}
private:
IfThenElse(const Expr& c, const Expr& t, const Expr& f)
: ExprNodeBase(t.dtype()), condition_(c), true_(t), false_(f) {
CHECK_EQ(c.dtype().scalar_type(), kInt32);
CHECK_EQ(c.dtype().lanes(), 1);
CHECK_EQ(t.dtype(), f.dtype());
}
Expr condition_;
Expr true_;
Expr false_;
};
class BaseCallNode : public BaseExprNode {
public:
enum CallType {
kFunctionCall,
};
int nparams() const {
return params_.size();
}
Expr& param(int index) {
return params_[index];
}
const Expr& param(int index) const {
return params_[index];
}
const std::vector<Expr>& params() const {
return params_;
}
virtual std::string func_name() const = 0;
CallType call_type() const {
return call_type_;
}
protected:
BaseCallNode(Dtype dtype, CallType call_type, const std::vector<Expr>& params)
: BaseExprNode(dtype), call_type_(call_type), params_(params) {}
private:
// The handler for the default ir_mutator to make a copy of this node with new
// params.
virtual Expr DefaultMutator(const std::vector<Expr>& new_params) const = 0;
template <class U, class B>
friend class ExprNode;
friend class IRMutator;
CallType call_type_;
std::vector<Expr> params_;
};
template <typename Op>
class CallNode : public ExprNode<Op, BaseCallNode> {
public:
using BaseClass = ExprNode<Op, BaseCallNode>;
using BaseClass::BaseClass;
};
class FunctionCall;
// Allocate a buffer of the given shape and dtype and bind it to the given
// buffer var. Its lifetime lasts at most through the current program, until it
// is explicitly freed; leaving the memory unfreed is likely considered an error.
class Allocate : public StmtNode<Allocate> {
public:
static Stmt make(
const Var& buffer_var,
Dtype dtype,
const std::vector<Expr>& dims) {
return Stmt(new Allocate(buffer_var, dtype, dims));
}
const Var& buffer_var() const {
return buffer_var_;
}
Dtype dtype() const {
return dtype_;
}
const std::vector<Expr>& dims() const {
return dims_;
}
private:
Allocate(const Var& buffer_var, Dtype dtype, const std::vector<Expr>& dims)
: buffer_var_(buffer_var), dtype_(dtype), dims_(dims) {}
Var buffer_var_;
Dtype dtype_;
std::vector<Expr> dims_;
// TODO: add memory types.
};
// Free the specified buffer. Leaving an allocated buffer unfreed is considered an error.
class Free : public StmtNode<Free> {
public:
static Stmt make(const Var& buffer_var) {
return Stmt(new Free(buffer_var));
}
const Var& buffer_var() const {
return buffer_var_;
}
private:
Free(const Var& buffer_var) : buffer_var_(buffer_var) {}
Var buffer_var_;
};
class Cond : public StmtNode<Cond> {
public:
static Stmt make(
const Expr& condition,
const Stmt& true_stmt,
const Stmt& false_stmt) {
return Stmt(new Cond(condition, true_stmt, false_stmt));
}
const Expr& condition() const {
return condition_;
}
const Stmt& true_stmt() const {
return true_stmt_;
}
const Stmt& false_stmt() const {
return false_stmt_;
}
private:
Cond(const Expr& condition, const Stmt& true_stmt, const Stmt& false_stmt)
: condition_(condition), true_stmt_(true_stmt), false_stmt_(false_stmt) {}
Expr condition_;
Stmt true_stmt_;
Stmt false_stmt_;
};
} // namespace tensorexpr
} // namespace jit
} // namespace torch
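
A sketch stringing several of these statement nodes together (illustrative only; the variable names are made up): allocate a temporary, fill it in a loop bound to a GPU thread axis via LoopOptions, and free it, all wrapped in a Block.

KernelScope kernel_scope;
Var tmp("tmp", kHandle);
Var i("i", kInt32);
Stmt alloc = Allocate::make(tmp, kFloat32, {Expr(128)});
Stmt fill = Store::make(tmp, i, Expr(0.0f));   // tmp[i] = 0.f, mask defaults to 1
LoopOptions opts;
opts.set_gpu_thread_index(0);                  // renders as "threadIdx.x"
Stmt loop = For::make(i, 0, 128, fill, opts);
Stmt free_tmp = Free::make(tmp);
Stmt block = Block::make({alloc, loop, free_tmp});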

View File

@@ -0,0 +1,48 @@
#include "torch/csrc/jit/tensorexpr/unique_name_manager.h"
#include <cctype>
#include "torch/csrc/jit/tensorexpr/ir.h"
namespace torch {
namespace jit {
namespace tensorexpr {
const std::string& UniqueNameManager::get_unique_name(const Variable* v) {
// Find if we have already encountered this variable.
auto iter = unique_name_mapping_.find(v);
if (iter != unique_name_mapping_.end()) {
return iter->second;
}
// First use the name_hint as a prefix to check if there is another name
// with the same prefix.
std::string name_hint = v->name_hint();
if (name_hint == "") {
name_hint = "v";
} else if (std::isdigit(name_hint[0])) {
name_hint = "v" + name_hint;
}
int& count = unique_name_count_[name_hint];
while (true) {
// Even with a new count, this name might already be used. For example,
// ("x", 1) could collide with ("x_1", 0).
int count_v = count++;
std::string unique_name = name_hint;
if (count_v > 0) {
unique_name += "_" + std::to_string(count_v);
}
if (all_unique_names_.count(unique_name) == 0) {
all_unique_names_.insert(unique_name);
auto result = unique_name_mapping_.insert(std::make_pair(v, unique_name));
return result.first->second;
}
}
}
const std::string& UniqueNameManager::get_unique_name(const Var& v) {
return get_unique_name(v.node());
}
} // namespace tensorexpr
} // namespace jit
} // namespace torch
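
A brief sketch of the collision handling above (illustrative only): two distinct variables sharing the hint "x" resolve to "x" and "x_1", and a later hint that is literally "x_1" is pushed to "x_1_1".

KernelScope kernel_scope;
UniqueNameManager names;
Var x1("x", kInt32);
Var x2("x", kInt32);
Var x3("x_1", kInt32);
std::string n1 = names.get_unique_name(x1);  // "x"
std::string n2 = names.get_unique_name(x2);  // "x_1"
std::string n3 = names.get_unique_name(x3);  // "x_1_1", avoiding the clash with n2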

View File

@@ -0,0 +1,36 @@
#pragma once
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <torch/csrc/WindowsTorchApiMacro.h>
namespace torch {
namespace jit {
namespace tensorexpr {
class Var;
class Variable;
using VarNameMap = std::unordered_map<const Variable*, std::string>;
// A manager to get unique names from vars.
// It starts with the var's name hint and appends "_" + $counter until it
// hits a unique name.
class TORCH_API UniqueNameManager {
public:
const std::string& get_unique_name(const Var& v);
const std::string& get_unique_name(const Variable* v);
private:
friend class ScopedVarName;
VarNameMap unique_name_mapping_;
std::unordered_map<std::string, int> unique_name_count_;
std::unordered_set<std::string> all_unique_names_;
};
} // namespace tensorexpr
} // namespace jit
} // namespace torch
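
One more illustrative detail of the implementation above (a sketch, not part of this diff): empty hints fall back to "v", and hints starting with a digit are prefixed with "v" so the result is a valid identifier.

KernelScope kernel_scope;
UniqueNameManager names;
Var unnamed(kInt32);                             // empty name hint
Var digits("1f", kInt32);
std::string a = names.get_unique_name(unnamed);  // "v"
std::string b = names.get_unique_name(digits);   // "v1f"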