[TensorExpr] Add IR visitor, IR mutator, and IR evaluator. (#33219)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/33219

Test Plan: Imported from OSS

Differential Revision: D19848381

Pulled By: ZolotukhinM

fbshipit-source-id: 44ca7cd99c25e290a8ffd8146785c19f9c785dfd
Author: Mikhail Zolotukhin
Date: 2020-02-21 13:06:13 -08:00
Committed by: Facebook GitHub Bot
Parent: 49af9425a7
Commit: fc70fc3610

13 changed files with 1543 additions and 0 deletions

@@ -456,9 +456,13 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
${TORCH_SRC_DIR}/csrc/jit/vararg_functions.cpp
${TORCH_SRC_DIR}/csrc/jit/tensorexpr/mem_arena.cpp
${TORCH_SRC_DIR}/csrc/jit/tensorexpr/codegen.cpp
${TORCH_SRC_DIR}/csrc/jit/tensorexpr/expr.cpp
${TORCH_SRC_DIR}/csrc/jit/tensorexpr/eval.cpp
${TORCH_SRC_DIR}/csrc/jit/tensorexpr/ir.cpp
${TORCH_SRC_DIR}/csrc/jit/tensorexpr/ir_visitor.cpp
${TORCH_SRC_DIR}/csrc/jit/tensorexpr/types.cpp
${TORCH_SRC_DIR}/csrc/jit/tensorexpr/ir_mutator.cpp
${TORCH_SRC_DIR}/csrc/jit/tensorexpr/unique_name_manager.cpp
)

@@ -2,6 +2,7 @@
#include "test/cpp/tensorexpr/test_utils.h"
#include "torch/csrc/jit/tensorexpr/buffer.h"
#include "torch/csrc/jit/tensorexpr/eval.h"
#include "torch/csrc/jit/tensorexpr/ir.h"
#include <cmath>
@@ -14,6 +15,53 @@ namespace torch {
namespace jit {
using namespace torch::jit::tensorexpr;
using SimpleIRExprEval = ExprEval<SimpleIREvaluator>;
void testExprBasicValueTest() {
KernelScope kernel_scope;
Expr a = IntImm::make(2), b = IntImm::make(3);
Expr c = Add::make(a, b);
SimpleIRExprEval eval(c);
EXPECT_EQ(eval.value<int>(), 5);
}
void testExprBasicValueTest02() {
KernelScope kernel_scope;
Expr a(2.0f);
Expr b(3.0f);
Expr c(4.0f);
Expr d(5.0f);
Expr f = (a + b) - (c + d);
SimpleIRExprEval eval(f);
EXPECT_EQ(eval.value<float>(), -4.0f);
}
void testExprLetTest01() {
KernelScope kernel_scope;
Var x("x", kFloat32);
Expr value = Expr(3.f);
Expr body = Expr(2.f) + (x * Expr(3.f) + Expr(4.f));
Expr result = Let::make(x, Expr(3.f), body);
SimpleIRExprEval eval(result);
EXPECT_EQ(eval.value<float>(), 2 + (3 * 3 + 4));
}
void testExprLetTest02() {
KernelScope kernel_scope;
Var x("x", kFloat32);
Var y("y", kFloat32);
Expr value = Expr(3.f);
Expr body = Expr(2.f) + (x * Expr(3.f) + Expr(4.f) * y);
Expr e1 = Let::make(x, Expr(3.f), body);
Expr e2 = Let::make(y, Expr(6.f), e1);
SimpleIRExprEval eval(e2);
EXPECT_EQ(eval.value<float>(), 2 + (3 * 3 + 4 * 6));
}
static Expr test_01(const Expr& expr) {
return expr;
}
void testExprVectorAdd01() {
KernelScope kernel_scope;
const int kVectorSize = 8;
@@ -54,5 +102,63 @@ void testExprVectorAdd01() {
EXPECT_EQ(value.dtype(), Dtype(kFloat32, kVectorSize));
}
void testExprCompareSelectEQ() {
KernelScope kernel_scope;
constexpr int N = 1024;
Buffer a(Var("A", kHandle), kInt32, {N});
Buffer b(Var("B", kHandle), kInt32, {N});
Buffer c(Var("C", kHandle), kInt32, {N});
std::vector<int> a_buffer(N, 1);
std::vector<int> b_buffer(N, 1);
std::vector<int> c_buffer(N, 0);
std::vector<int> c_ref(N, 0);
auto mask = IntImm::make(1);
Var i("i", kInt32);
auto memcpy_expr = For::make(
i,
0,
N,
Store::make(
c,
i,
CompareSelect::make(
Load::make(a, i, mask),
Load::make(b, i, mask),
CompareSelectOperation::kEQ),
mask));
SimpleIREvaluator ir_eval(memcpy_expr, a, b, c);
ir_eval(a_buffer, b_buffer, c_buffer);
ASSERT_EQ(a_buffer.size(), N);
ASSERT_EQ(b_buffer.size(), N);
ASSERT_EQ(c_buffer.size(), N);
assertAllEqual(a_buffer, 1);
assertAllEqual(b_buffer, 1);
assertAllEqual(c_buffer, 1);
}
void testExprDynamicShapeAdd() {
KernelScope kernel_scope;
auto testWithSize = [](int32_t size) {
Var n("n", kInt32);
Buffer a(Var("a", kHandle), kFloat32, {n});
Buffer b(Var("b", kHandle), kFloat32, {n});
Buffer c(Var("c", kHandle), kFloat32, {n});
Var i("i", kInt32);
Stmt s = For::make(i, 0, n, Store::make(c, i, a(i) + b(i), 1));
std::vector<float> aData(size, 1.0f);
std::vector<float> bData(size, 2.0f);
std::vector<float> cData(size, 0.0f);
SimpleIREvaluator(s, a, b, c, n)(aData, bData, cData, size);
ExpectAllNear(cData, std::vector<float>(size, 3.0f), 1e-7);
};
testWithSize(1);
testWithSize(16);
testWithSize(37);
}
} // namespace jit
} // namespace torch

@@ -9,7 +9,13 @@
namespace torch {
namespace jit {
#define TH_FORALL_TESTS(_) \
_(ExprBasicValueTest) \
_(ExprBasicValueTest02) \
_(ExprLetTest01) \
_(ExprLetTest02) \
_(ExprVectorAdd01) \
_(ExprCompareSelectEQ) \
_(ExprDynamicShapeAdd) \
_(TypeTest01) \
#define TH_FORALL_TESTS_CUDA(_) \

@@ -190,16 +190,24 @@ libtorch_sources = [
"torch/csrc/jit/mobile/register_mobile_ops.cpp",
"torch/csrc/jit/mobile/interpreter.cpp",
"torch/csrc/jit/mobile/type_parser.cpp",
"torch/csrc/jit/tensorexpr/codegen.cpp",
"torch/csrc/jit/tensorexpr/eval.cpp",
"torch/csrc/jit/tensorexpr/expr.cpp",
"torch/csrc/jit/tensorexpr/ir.cpp",
"torch/csrc/jit/tensorexpr/ir_mutator.cpp",
"torch/csrc/jit/tensorexpr/ir_visitor.cpp",
"torch/csrc/jit/tensorexpr/mem_arena.cpp",
"torch/csrc/jit/tensorexpr/types.cpp",
"torch/csrc/jit/tensorexpr/unique_name_manager.cpp",
"torch/csrc/utils/byte_order.cpp",
"torch/csrc/utils/tensor_flatten.cpp",
"torch/csrc/utils/variadic.cpp",
"torch/csrc/jit/tensorexpr/codegen.cpp",
"torch/csrc/jit/tensorexpr/eval.cpp",
"torch/csrc/jit/tensorexpr/expr.cpp",
"torch/csrc/jit/tensorexpr/ir.cpp",
"torch/csrc/jit/tensorexpr/ir_mutator.cpp",
"torch/csrc/jit/tensorexpr/ir_visitor.cpp",
"torch/csrc/jit/tensorexpr/mem_arena.cpp",
"torch/csrc/jit/tensorexpr/types.cpp",
"torch/csrc/jit/tensorexpr/unique_name_manager.cpp",

@@ -0,0 +1,51 @@
#include "torch/csrc/jit/tensorexpr/codegen.h"
#include <sstream>
namespace torch {
namespace jit {
namespace tensorexpr {
RegisterCodeGenList::StmtFactoryMethod RegisterCodeGenList::
FindStmtFactoryMethod(const std::string& name) {
auto iter = stmt_factory_methods_.find(name);
if (iter == stmt_factory_methods_.end()) {
std::ostringstream oss;
oss << "Invalid stmt codegen name: " << name << ". ";
oss << "Existing codegen names: [";
int index = 0;
for (const auto& entry : stmt_factory_methods_) {
if (index != 0) {
oss << ", ";
}
oss << entry.first;
index++;
}
oss << "]";
throw std::runtime_error(oss.str());
}
return iter->second;
}
void RegisterCodeGenList::AddStmtFactoryMethod(
const std::string& name,
const StmtFactoryMethod& stmt_factory_method) {
auto insert_ret =
stmt_factory_methods_.insert(std::make_pair(name, stmt_factory_method));
if (!insert_ret.second) {
throw std::runtime_error("Duplicated CodeGen names: " + name);
}
}
std::unique_ptr<CodeGen> CreateCodeGen(
const std::string& name,
const Stmt& stmt,
const std::vector<CodeGen::BufferArg>& params) {
RegisterCodeGenList::StmtFactoryMethod method =
RegisterCodeGenList::GetInstance().FindStmtFactoryMethod(name);
return method(stmt, params);
}
} // namespace tensorexpr
} // namespace jit
} // namespace torch
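
FindStmtFactoryMethod throws a std::runtime_error that lists every registered name when a lookup fails, which makes a misconfigured backend easy to diagnose. A small hedged sketch of the error path (the backend name here is deliberately bogus):

try {
  auto cg = CreateCodeGen("no_such_backend", stmt, params);
} catch (const std::runtime_error& e) {
  // e.what() reads roughly:
  //   "Invalid stmt codegen name: no_such_backend. Existing codegen names: [simple_ir_eval]"
}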

@@ -0,0 +1,157 @@
#pragma once
#include "torch/csrc/jit/tensorexpr/buffer.h"
#include "torch/csrc/jit/tensorexpr/ir.h"
namespace torch {
namespace jit {
namespace tensorexpr {
class CodeGen {
public:
class BufferArg;
class CallArg;
template <typename... Ts>
CodeGen(const Stmt& stmt, Ts... ts)
: stmt_(stmt), buffer_args_({BufferArg(ts)...}) {}
CodeGen(const Stmt& stmt, const std::vector<BufferArg>& buffer_args)
: stmt_(stmt), buffer_args_(buffer_args) {}
virtual ~CodeGen() {}
const Stmt& stmt() const {
return stmt_;
}
std::vector<BufferArg>& buffer_args() {
return buffer_args_;
}
const std::vector<BufferArg>& buffer_args() const {
return buffer_args_;
}
TORCH_API virtual void call(const std::vector<CallArg>& args) {
LOG(FATAL) << "unimplemented call";
}
private:
Stmt stmt_;
std::vector<BufferArg> buffer_args_;
};
class CodeGen::BufferArg {
public:
BufferArg(const Buffer& buffer)
: var_(buffer.data()), dtype_(buffer.dtype()) {}
BufferArg(const Var& var) : var_(var), dtype_(var.dtype()), isVar_(true) {}
const Var& var() const {
return var_;
}
Var& var() {
return var_;
}
Dtype dtype() const {
return dtype_;
}
bool isVar() const {
return isVar_;
}
private:
Var var_;
Dtype dtype_;
bool isVar_{false};
};
class CodeGen::CallArg {
public:
template <typename T>
CallArg(const std::vector<T>& buffer) : ptr_(const_cast<T*>(buffer.data())) {}
CallArg(void* ptr) : ptr_(ptr) {}
CallArg(int32_t i) : ival_(i) {}
CallArg(float f) : fval_(f) {}
void* data() const {
return ptr_;
}
int32_t intData() const {
return ival_;
}
float floatData() const {
return fval_;
}
int* intPtr() const {
return const_cast<int*>(&ival_);
}
float* floatPtr() const {
return const_cast<float*>(&fval_);
}
private:
union {
void* ptr_;
float fval_;
int32_t ival_;
};
};
class RegisterCodeGenList {
public:
TORCH_API static RegisterCodeGenList& GetInstance() {
static RegisterCodeGenList codegen_list;
return codegen_list;
}
using StmtFactoryMethod = std::function<std::unique_ptr<CodeGen>(
const Stmt& stmt,
const std::vector<CodeGen::BufferArg>&)>;
TORCH_API StmtFactoryMethod FindStmtFactoryMethod(const std::string& name);
private:
template <class CodeGenType>
friend class RegisterCodeGen;
RegisterCodeGenList() {}
TORCH_API void AddStmtFactoryMethod(
const std::string& name,
const StmtFactoryMethod& stmt_factory_method);
RegisterCodeGenList(const RegisterCodeGenList&) = delete;
RegisterCodeGenList& operator=(const RegisterCodeGenList&) = delete;
std::unordered_map<std::string, StmtFactoryMethod> stmt_factory_methods_;
};
template <class CodeGenType>
class RegisterCodeGen {
public:
explicit RegisterCodeGen(const std::string& name) {
RegisterCodeGenList& codegen_list = RegisterCodeGenList::GetInstance();
codegen_list.AddStmtFactoryMethod(
name,
[](const Stmt& stmt, const std::vector<CodeGen::BufferArg>& params) {
std::unique_ptr<CodeGen> method(new CodeGenType(stmt, params));
return method;
});
}
};
TORCH_API std::unique_ptr<CodeGen> CreateCodeGen(
const std::string& name,
const Stmt& stmt,
const std::vector<CodeGen::BufferArg>& params);
} // namespace tensorexpr
} // namespace jit
} // namespace torch
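
The registry above is the extension point for pluggable backends: a static RegisterCodeGen<T> instance adds a factory under a name during static initialization, and CreateCodeGen resolves it later. A minimal sketch of a hypothetical backend (MyCodeGen and "my_codegen" are illustrative names, not part of this commit):

class MyCodeGen : public CodeGen {
 public:
  using CodeGen::CodeGen;  // inherit the (stmt, buffer_args) constructors
  void call(const std::vector<CallArg>& args) override {
    // ... interpret or compile stmt() against args here ...
  }
};

// Registers the factory under "my_codegen" at static-initialization time.
static RegisterCodeGen<MyCodeGen> my_codegen_reg("my_codegen");

// Clients then resolve the backend by name:
//   std::unique_ptr<CodeGen> cg = CreateCodeGen("my_codegen", stmt, params);
//   cg->call(call_args);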

@@ -0,0 +1,11 @@
#include "torch/csrc/jit/tensorexpr/eval.h"
namespace torch {
namespace jit {
namespace tensorexpr {
static RegisterCodeGen<SimpleIREvaluator> reg("simple_ir_eval");
} // namespace tensorexpr
} // namespace jit
} // namespace torch
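
With this one-line registration, the evaluator is reachable through the generic factory rather than by constructing SimpleIREvaluator directly. A hedged sketch, assuming the tensorexpr headers above; the buffer setup and names are illustrative:

// (inside some driver or test function)
KernelScope kernel_scope;
Buffer out(Var("out", kHandle), kFloat32, {1});
Stmt store = Store::make(
    out, IntImm::make(0), FloatImm::make(3.0f), IntImm::make(1));
std::unique_ptr<CodeGen> cg = CreateCodeGen("simple_ir_eval", store, {out});
std::vector<float> out_data(1, 0.0f);
cg->call({CodeGen::CallArg(out_data)});  // out_data[0] is now 3.0f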

@@ -0,0 +1,606 @@
#pragma once
#include <cmath>
#include <unordered_map>
#include <vector>
#include <c10/util/Logging.h>
#include "torch/csrc/jit/tensorexpr/buffer.h"
#include "torch/csrc/jit/tensorexpr/codegen.h"
#include "torch/csrc/jit/tensorexpr/ir.h"
#include "torch/csrc/jit/tensorexpr/types.h"
namespace torch {
namespace jit {
namespace tensorexpr {
class Value {
public:
Value() : dtype_(kInt32) {
i32_values.push_back(0);
}
Value(int v) : dtype_(kInt32) {
i32_values.push_back(v);
}
Value(float v) : dtype_(kFloat32) {
f32_values.push_back(v);
}
Value(const std::vector<int>& v)
: dtype_(Dtype(kInt32, v.size())), i32_values(v) {}
Value(const std::vector<float>& v)
: dtype_(Dtype(kFloat32, v.size())), f32_values(v) {}
template <typename T>
T as() const;
template <typename T>
const std::vector<T>& as_vec() const;
Dtype dtype() const {
return dtype_;
}
private:
Dtype dtype_;
std::vector<int> i32_values;
std::vector<float> f32_values;
void* ptr;
};
template <>
inline int Value::as<int>() const {
CHECK_EQ(dtype_, kInt32) << "invalid dtype";
return i32_values[0];
}
template <>
inline float Value::as<float>() const {
CHECK_EQ(dtype_, kFloat32) << "invalid dtype";
return f32_values[0];
}
template <>
inline const std::vector<float>& Value::as_vec<float>() const {
CHECK_EQ(dtype_.scalar_type(), kFloat32) << "invalid dtype";
return f32_values;
}
template <>
inline const std::vector<int>& Value::as_vec<int>() const {
CHECK_EQ(dtype_.scalar_type(), kInt32) << "invalid dtype";
return i32_values;
}
inline int mod_value(int lhs, int rhs) {
return lhs % rhs;
}
inline float mod_value(float lhs, float rhs) {
return std::fmod(lhs, rhs);
}
class SimpleIREvaluator : public CodeGen, public IRVisitor {
public:
using CodeGen::CodeGen;
~SimpleIREvaluator() override {}
TORCH_API void call(const std::vector<CallArg>& args) override {
CHECK_EQ(args.size(), buffer_args().size());
for (size_t i = 0; i < args.size(); i++) {
bind(buffer_args()[i], args[i]);
}
stmt().accept(this);
eval_context_.clear();
buffer_mapping_.clear();
internal_buffers_.clear();
}
void bind(const BufferArg& buf, const CallArg& data) {
if (buf.isVar()) {
if (buf.dtype() == kInt32) {
eval_context_[buf.var().node()] = data.intData();
} else if (buf.dtype() == kFloat32) {
eval_context_[buf.var().node()] = data.floatData();
} else {
LOG(FATAL) << "Unhandled dtype for argument " << buf.var().name_hint()
<< ": " << buf.dtype();
}
} else {
buffer_mapping_[buf.var().node()] = data.data();
}
}
template <typename... Ts>
void operator()(const Ts&... ts) {
std::vector<CallArg> args({CallArg(ts)...});
call(args);
}
TORCH_API void visit(const Add* v) override {
visit_binary_op(v);
}
TORCH_API void visit(const Sub* v) override {
visit_binary_op(v);
}
TORCH_API void visit(const Mul* v) override {
visit_binary_op(v);
}
TORCH_API void visit(const Div* v) override {
visit_binary_op(v);
}
TORCH_API void visit(const Mod* v) override {
visit_binary_op(v);
}
TORCH_API void visit(const Max* v) override {
visit_binary_op(v, v->propagate_nans());
}
TORCH_API void visit(const Min* v) override {
visit_binary_op(v, v->propagate_nans());
}
void visit(const CompareSelect* v) override {
visit_compare_select_op(v, v->compare_select_op());
}
template <typename T>
Value binary_op(
const Value& lhs,
const Value& rhs,
IRNodeType op_type,
bool option = false) {
std::vector<T> lhs_v = lhs.as_vec<T>();
std::vector<T> rhs_v = rhs.as_vec<T>();
std::vector<T> result_v(lhs_v.size());
for (size_t i = 0; i < lhs_v.size(); i++) {
switch (op_type) {
case IRNodeType::kAdd:
result_v[i] = lhs_v[i] + rhs_v[i];
break;
case IRNodeType::kSub:
result_v[i] = lhs_v[i] - rhs_v[i];
break;
case IRNodeType::kMul:
result_v[i] = lhs_v[i] * rhs_v[i];
break;
case IRNodeType::kDiv:
result_v[i] = lhs_v[i] / rhs_v[i];
break;
case IRNodeType::kMod:
result_v[i] = mod_value(lhs_v[i], rhs_v[i]);
break;
case IRNodeType::kMax:
if (lhs.dtype() == kFloat32 && rhs.dtype() == kFloat32 && option) {
// Propagate NaNs
if (std::isnan((float)lhs_v[i])) {
result_v[i] = lhs_v[i];
} else if (std::isnan((float)rhs_v[i])) {
result_v[i] = rhs_v[i];
} else {
result_v[i] = lhs_v[i] > rhs_v[i] ? lhs_v[i] : rhs_v[i];
}
} else {
result_v[i] = lhs_v[i] > rhs_v[i] ? lhs_v[i] : rhs_v[i];
}
break;
case IRNodeType::kMin:
if (lhs.dtype() == kFloat32 && rhs.dtype() == kFloat32 && option) {
// Propagate NaNs
if (std::isnan((float)lhs_v[i])) {
result_v[i] = lhs_v[i];
} else if (std::isnan((float)rhs_v[i])) {
result_v[i] = rhs_v[i];
} else {
result_v[i] = lhs_v[i] < rhs_v[i] ? lhs_v[i] : rhs_v[i];
}
} else {
result_v[i] = lhs_v[i] < rhs_v[i] ? lhs_v[i] : rhs_v[i];
}
break;
default:
// TODO: change to a proper error report
throw std::runtime_error("invalid operator type");
}
}
return Value(result_v);
}
template <typename T>
Value compare_select_op(
const Value& lhs,
const Value& rhs,
CompareSelectOperation cmp_op) {
std::vector<T> lhs_v = lhs.as_vec<T>();
std::vector<T> rhs_v = rhs.as_vec<T>();
std::vector<int> result_v(lhs_v.size());
for (size_t i = 0; i < lhs_v.size(); i++) {
switch (cmp_op) {
case CompareSelectOperation::kEQ:
result_v[i] = (lhs_v[i] == rhs_v[i]) ? 1 : 0;
break;
case CompareSelectOperation::kNE:
result_v[i] = (lhs_v[i] != rhs_v[i]) ? 1 : 0;
break;
case CompareSelectOperation::kGT:
result_v[i] = (lhs_v[i] > rhs_v[i]) ? 1 : 0;
break;
case CompareSelectOperation::kGE:
result_v[i] = (lhs_v[i] >= rhs_v[i]) ? 1 : 0;
break;
case CompareSelectOperation::kLT:
result_v[i] = (lhs_v[i] < rhs_v[i]) ? 1 : 0;
break;
case CompareSelectOperation::kLE:
result_v[i] = (lhs_v[i] <= rhs_v[i]) ? 1 : 0;
break;
default:
// TODO: change to a proper error report
throw std::runtime_error("invalid operator type");
}
}
return Value(result_v);
}
template <typename Op>
void visit_binary_op(const BinaryOpNode<Op>* v, bool option = false) {
v->lhs().accept(this);
Value lhs_v = value_;
v->rhs().accept(this);
Value rhs_v = value_;
CHECK_EQ(lhs_v.dtype(), rhs_v.dtype());
IRNodeType expr_type = v->expr_type();
if (lhs_v.dtype().scalar_type() == kFloat32) {
value_ = binary_op<float>(lhs_v, rhs_v, expr_type);
} else if (lhs_v.dtype().scalar_type() == kInt32) {
value_ = binary_op<int>(lhs_v, rhs_v, expr_type);
} else {
LOG(FATAL) << "invalid dtype: " << lhs_v.dtype();
}
}
void visit_compare_select_op(
const CompareSelect* v,
CompareSelectOperation cmp_op) {
v->lhs().accept(this);
Value lhs_v = value_;
v->rhs().accept(this);
Value rhs_v = value_;
CHECK_EQ(lhs_v.dtype(), rhs_v.dtype());
if (lhs_v.dtype().scalar_type() == kFloat32) {
value_ = compare_select_op<float>(lhs_v, rhs_v, cmp_op);
} else if (lhs_v.dtype().scalar_type() == kInt32) {
value_ = compare_select_op<int>(lhs_v, rhs_v, cmp_op);
} else {
LOG(FATAL) << "invalid dtype: " << lhs_v.dtype();
}
}
TORCH_API void visit(const IntImm* v) override {
value_ = Value(v->value());
}
TORCH_API void visit(const FloatImm* v) override {
value_ = Value(v->value());
}
TORCH_API void visit(const Let* v) override {
const Variable* var = v->var().AsNode<Variable>();
CHECK(var != nullptr);
v->value().accept(this);
Value value = value_;
auto iter = eval_context_.find(var);
// TODO: make the same value settable multiple times.
CHECK(iter == eval_context_.end())
<< "var must not already exist in the context";
eval_context_[var] = value;
v->body().accept(this);
eval_context_.erase(var);
}
TORCH_API void visit(const Variable* v) override {
auto iter = eval_context_.find(v);
CHECK(iter != eval_context_.end())
<< "var must be defined in the context before";
value_ = iter->second;
}
TORCH_API void visit(const Cast* v) override {
const Expr& src_value = v->src_value();
src_value.accept(this);
Dtype dst_dtype = v->dtype();
Dtype src_dtype = src_value.dtype();
CHECK_EQ(src_dtype.lanes(), dst_dtype.lanes());
if (src_dtype != dst_dtype) {
if (src_dtype == kFloat32 && dst_dtype == kInt32) {
const std::vector<float>& src_values = value_.as_vec<float>();
std::vector<int> dst_values(src_values.size());
for (int i = 0; i < src_dtype.lanes(); ++i) {
dst_values[i] = static_cast<int>(src_values[i]);
}
this->value_ = Value(dst_values);
} else if (src_dtype == kInt32 && dst_dtype == kFloat32) {
const std::vector<int>& src_values = value_.as_vec<int>();
std::vector<float> dst_values(src_values.size());
for (int i = 0; i < src_dtype.lanes(); ++i) {
dst_values[i] = static_cast<float>(src_values[i]);
}
this->value_ = Value(dst_values);
} else {
LOG(FATAL) << "unsupported cast from " << src_dtype << " to " << dst_dtype;
}
}
}
TORCH_API void visit(const For* v) override {
const BaseExprNode* var_node = v->var().node();
v->start().accept(this);
int start = value_.as<int>();
v->stop().accept(this);
int stop = value_.as<int>();
auto iter = eval_context_.find(var_node);
CHECK(iter == eval_context_.end())
<< "var in For must not exist in eval context";
for (int i = start; i < stop; i++) {
eval_context_[var_node] = Value(i);
v->body().accept(this);
}
eval_context_.erase(var_node);
}
TORCH_API void visit(const Ramp* v) override {
v->base().accept(this);
int base = value().as<int>();
v->stride().accept(this);
int stride = value().as<int>();
int lanes = v->lanes();
std::vector<int> values(lanes);
for (int i = 0; i < lanes; i++) {
values[i] = base + i * stride;
}
value_ = Value(values);
}
TORCH_API void visit(const Broadcast* v) override {
v->value().accept(this);
Value value = this->value();
int lanes = v->lanes();
if (value.dtype() == kInt32) {
std::vector<int> v(lanes, value.as<int>());
value_ = Value(v);
} else if (value.dtype() == kFloat32) {
std::vector<float> v(lanes, value.as<float>());
value_ = Value(v);
} else {
LOG(FATAL) << "invalid dtype: " << value.dtype();
}
}
TORCH_API void visit(const IfThenElse* v) override {
v->condition().accept(this);
if (value_.as<int>()) {
v->true_value().accept(this);
} else {
v->false_value().accept(this);
}
}
TORCH_API void visit(const Load* v) override {
const Variable* base_node = v->base_handle().node();
auto iter = buffer_mapping_.find(base_node);
CHECK(iter != buffer_mapping_.end())
<< "missing buffer binding: " << base_node->name_hint();
void* ptr = iter->second;
v->index().accept(this);
std::vector<int> index = value().as_vec<int>();
v->mask().accept(this);
std::vector<int> mask = value().as_vec<int>();
Dtype v_sdtype = v->dtype().scalar_type();
if (v_sdtype == kFloat32) {
float* ptr_f = static_cast<float*>(ptr);
std::vector<float> v(index.size());
for (size_t i = 0; i < index.size(); i++) {
if (mask[i]) {
v[i] = ptr_f[index[i]];
}
}
value_ = Value(v);
} else if (v_sdtype == kInt32) {
int* ptr_i = static_cast<int*>(ptr);
std::vector<int> v(index.size());
for (size_t i = 0; i < index.size(); i++) {
if (mask[i]) {
v[i] = ptr_i[index[i]];
}
}
value_ = Value(v);
} else {
LOG(FATAL) << "Invalid dtype: " << v_sdtype;
}
}
TORCH_API void visit(const Store* v) override {
const Variable* base_node = v->base_handle().node();
auto iter = buffer_mapping_.find(base_node);
CHECK(iter != buffer_mapping_.end());
void* ptr = iter->second;
v->index().accept(this);
std::vector<int> index = value().as_vec<int>();
v->mask().accept(this);
std::vector<int> mask = value().as_vec<int>();
CHECK_EQ(index.size(), mask.size());
Dtype v_sdtype = v->value().dtype().scalar_type();
if (v_sdtype == kFloat32) {
v->value().accept(this);
std::vector<float> value = this->value().as_vec<float>();
CHECK_EQ(index.size(), value.size());
float* ptr_f = static_cast<float*>(ptr);
for (size_t i = 0; i < index.size(); i++) {
if (mask[i]) {
ptr_f[index[i]] = value[i];
}
}
} else if (v_sdtype == kInt32) {
v->value().accept(this);
std::vector<int> value = this->value().as_vec<int>();
CHECK_EQ(index.size(), value.size());
int* ptr_i = static_cast<int*>(ptr);
for (size_t i = 0; i < index.size(); i++) {
if (mask[i]) {
ptr_i[index[i]] = value[i];
}
}
} else {
LOG(FATAL) << "Invalid dtype: " << v_sdtype;
}
}
void visit(const Allocate* v) override {
const Variable* buffer_var = v->buffer_var().AsNode<Variable>();
std::vector<Expr> dims = v->dims();
int total_byte_size = v->dtype().byte_size();
for (size_t i = 0; i < dims.size(); i++) {
dims[i].accept(this);
total_byte_size *= value_.as<int>();
}
int int_count = (total_byte_size + sizeof(int) - 1) / sizeof(int);
std::unique_ptr<std::vector<int>> buffer(new std::vector<int>(int_count));
auto iter = buffer_mapping_.find(buffer_var);
if (iter != buffer_mapping_.end() && iter->second != nullptr) {
throw std::runtime_error(
"Allocate a buffer that has already been allocated: " +
buffer_var->name_hint());
}
buffer_mapping_[buffer_var] = buffer->data();
internal_buffers_.insert(std::make_pair(buffer_var, std::move(buffer)));
}
void visit(const Free* v) override {
const Variable* buffer_var = v->buffer_var().AsNode<Variable>();
int count = internal_buffers_.erase(buffer_var);
if (count == 0) {
throw std::runtime_error(
"Free a buffer that is not currently bound: " +
buffer_var->name_hint());
}
}
void visit(const Cond* v) override {
v->condition().accept(this);
if (value().as<int>()) {
v->true_stmt().accept(this);
} else {
v->false_stmt().accept(this);
}
}
Value value() const {
return value_;
}
private:
Value value_;
std::unordered_map<const BaseExprNode*, Value> eval_context_;
std::unordered_map<const BaseExprNode*, void*> buffer_mapping_;
std::unordered_map<const Variable*, std::unique_ptr<std::vector<int>>>
internal_buffers_;
};
using VarMapping = std::vector<std::pair<Expr, Expr>>;
class VarSubMutator : public IRMutator {
public:
VarSubMutator(const VarMapping& var_mapping) {
for (const auto& entry : var_mapping) {
const Expr& key = entry.first;
const Expr& value = entry.second;
const Variable* key_var = key.AsNode<Variable>();
CHECK(key_var != nullptr);
var_mapping_[key_var] = value;
}
}
Expr mutate(const Variable* var) override {
auto iter = var_mapping_.find(var);
if (iter == var_mapping_.end()) {
return Expr(const_cast<Variable*>(var));
}
return iter->second;
}
private:
std::unordered_map<const Variable*, Expr> var_mapping_;
};
template <class CodeGenType>
class ExprEval {
public:
using BufferArg = CodeGen::BufferArg;
using CallArg = CodeGen::CallArg;
template <typename... Ts>
ExprEval(const Expr& expr, Ts... ts) : ExprEval(expr, {BufferArg(ts)...}) {}
ExprEval(const Expr& expr, const std::vector<BufferArg>& buffer_args)
: dtype_(expr.dtype()) {
std::vector<BufferArg> buffer_args_extended = buffer_args;
Buffer ret_buf("ret_val", dtype_, {1});
Stmt store_stmt = Store::make(ret_buf.data(), 0, expr);
buffer_args_extended.push_back(ret_buf);
codegen_.reset(new CodeGenType(store_stmt, buffer_args_extended));
}
template <typename... Ts>
void operator()(Ts... ts) {
call(ts...);
}
void operator()(const std::vector<CallArg>& call_args) {
call(call_args);
}
template <typename... Ts>
void call(Ts... ts) {
call({CallArg(ts)...});
}
void call(const std::vector<CallArg>& call_args) {
std::vector<CallArg> call_args_extended = call_args;
if (dtype_ == kFloat32) {
std::vector<float> ret_val_arg(1);
call_args_extended.push_back(CallArg(ret_val_arg));
codegen_->call(call_args_extended);
ret_value_ = Value(ret_val_arg[0]);
} else if (dtype_ == kInt32) {
std::vector<int> ret_val_arg(1);
call_args_extended.push_back(CallArg(ret_val_arg));
codegen_->call(call_args_extended);
ret_value_ = Value(ret_val_arg[0]);
} else {
throw std::runtime_error("Invalid dtype");
}
}
template <typename T, typename... Ts>
T value(Ts... ts) {
call(std::forward<Ts>(ts)...);
return ret_value_.as<T>();
}
private:
Dtype dtype_;
std::unique_ptr<CodeGenType> codegen_;
Value ret_value_;
};
inline Expr Substitute(Expr* expr, const VarMapping& var_mapping) {
VarSubMutator var_sub(var_mapping);
return expr->accept_mutator(&var_sub);
}
inline Stmt Substitute(Stmt* stmt, const VarMapping& var_mapping) {
VarSubMutator var_sub(var_mapping);
return stmt->accept_mutator(&var_sub);
}
} // namespace tensorexpr
} // namespace jit
} // namespace torch
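
VarSubMutator and the Substitute helpers are the first concrete use of the mutator: rewriting free variables into other expressions. A small sketch, assuming the classes defined above (the concrete values are illustrative):

// (inside some driver or test function)
KernelScope kernel_scope;
Var x("x", kFloat32);
Expr body = x * Expr(3.f) + Expr(4.f);
// Replace every occurrence of x with the constant 2.f.
Expr substituted = Substitute(&body, {{x, Expr(2.f)}});
ExprEval<SimpleIREvaluator> eval(substituted);
float result = eval.value<float>();  // 2 * 3 + 4 == 10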

@@ -5,6 +5,8 @@
*/
#pragma once
#include "torch/csrc/jit/tensorexpr/ir_mutator.h"
#include "torch/csrc/jit/tensorexpr/ir_visitor.h"
#include "torch/csrc/jit/tensorexpr/types.h"
#include "torch/csrc/jit/tensorexpr/mem_arena.h"
@@ -20,6 +22,8 @@ class BaseExprNode : public KernelScopedObject {
Dtype dtype() const {
return dtype_;
}
TORCH_API virtual void accept(IRVisitor* visitor) const = 0;
virtual Expr accept_mutator(IRMutator* mutator) = 0;
private:
Dtype dtype_;
@@ -29,6 +33,8 @@ class BaseExprNode : public KernelScopedObject {
class BaseStmtNode : public KernelScopedObject {
public:
BaseStmtNode() {}
TORCH_API virtual void accept(IRVisitor* visitor) const = 0;
virtual Stmt accept_mutator(IRMutator* mutator) = 0;
};
// A CRTP pattern to accept visitors for child classes,
@@ -37,6 +43,10 @@ template <class Op, class Base = BaseExprNode>
class ExprNode : public Base {
public:
using ExprNodeBase = ExprNode<Op>;
void accept(IRVisitor* visitor) const override {
visitor->visit(static_cast<const Op*>(this));
}
Expr accept_mutator(IRMutator* mutator) override;
// pass the constructor to the base class
using Base::Base;
};
@@ -45,6 +55,10 @@ template <class Op>
class StmtNode : public BaseStmtNode {
public:
using StmtNodeBase = StmtNode<Op>;
void accept(IRVisitor* visitor) const override {
visitor->visit(static_cast<const Op*>(this));
}
Stmt accept_mutator(IRMutator* mutator) override;
StmtNode() {}
};
@@ -68,6 +82,23 @@ class TORCH_API Expr {
return base_expr_node_ == nullptr;
}
void accept(IRVisitor* visitor) const {
// TODO: Consider implementing this without recursion. Otherwise, a
// degenerate, very deep expression tree could cause a stack
// overflow.
if (node() == nullptr) {
return;
}
node()->accept(visitor);
}
Expr accept_mutator(IRMutator* mutator) {
if (node() == nullptr) {
return Expr();
}
return node()->accept_mutator(mutator);
}
Expr(int v);
Expr(float v);
@@ -115,6 +146,20 @@ class Stmt {
return base_stmt_node_;
}
void accept(IRVisitor* visitor) const {
if (node() == nullptr) {
return;
}
node()->accept(visitor);
}
Stmt accept_mutator(IRMutator* mutator) {
if (node() == nullptr) {
return Stmt();
}
return node()->accept_mutator(mutator);
}
bool empty() const {
return node() == nullptr;
}
@@ -128,6 +173,18 @@
BaseStmtNode* base_stmt_node_ = nullptr;
};
template <class Op, class Base>
Expr ExprNode<Op, Base>::accept_mutator(IRMutator* mutator) {
ExprNode* this_mutable = const_cast<ExprNode*>(this);
return mutator->mutate(static_cast<Op*>(this_mutable));
}
template <class Op>
Stmt StmtNode<Op>::accept_mutator(IRMutator* mutator) {
StmtNode* this_mutable = const_cast<StmtNode*>(this);
return mutator->mutate(static_cast<Op*>(this_mutable));
}
inline bool same_node(const Expr& expr1, const Expr& expr2) {
return expr1.AsNode<BaseExprNode>() == expr2.AsNode<BaseExprNode>();
}
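
Taken together, these hooks implement classic double dispatch: the Expr/Stmt handle forwards to the node, the node's virtual accept recovers its concrete type through the CRTP base, and the visitor receives a fully typed pointer. Traced on a tiny expression (a sketch, assuming ir.h and ir_visitor.h are included):

KernelScope kernel_scope;
Expr e = Add::make(Expr(1), Expr(2));
IRVisitor v;
e.accept(&v);
// 1. Expr::accept null-checks, then calls node()->accept(&v)   (virtual)
// 2. ExprNode<Add>::accept calls visitor->visit(static_cast<const Add*>(this))
// 3. IRVisitor::visit(const Add*) recurses into lhs() and rhs()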

@@ -0,0 +1,272 @@
#include "torch/csrc/jit/tensorexpr/ir_mutator.h"
#include "torch/csrc/jit/tensorexpr/eval.h"
#include "torch/csrc/jit/tensorexpr/ir.h"
namespace torch {
namespace jit {
namespace tensorexpr {
template <typename Op>
static Expr mutate_binary_op(
const BinaryOpNode<Op>* v,
IRMutator* mutator,
bool option = false) {
Expr lhs = v->lhs();
Expr rhs = v->rhs();
Expr lhs_new = lhs.accept_mutator(mutator);
Expr rhs_new = rhs.accept_mutator(mutator);
if (same_node(lhs, lhs_new) && same_node(rhs, rhs_new)) {
return Expr(v);
}
IRNodeType expr_type = v->expr_type();
switch (expr_type) {
case IRNodeType::kAdd:
return Add::make(lhs_new, rhs_new);
case IRNodeType::kSub:
return Sub::make(lhs_new, rhs_new);
case IRNodeType::kMul:
return Mul::make(lhs_new, rhs_new);
case IRNodeType::kDiv:
return Div::make(lhs_new, rhs_new);
case IRNodeType::kMod:
return Mod::make(lhs_new, rhs_new);
case IRNodeType::kMax:
return Max::make(lhs_new, rhs_new, option);
case IRNodeType::kMin:
return Min::make(lhs_new, rhs_new, option);
default:
LOG(FATAL) << "unsupported expr_type: " << static_cast<int>(expr_type);
return Expr();
}
}
Expr IRMutator::mutate(const Add* v) {
return mutate_binary_op(v, this);
}
Expr IRMutator::mutate(const Sub* v) {
return mutate_binary_op(v, this);
}
Expr IRMutator::mutate(const Mul* v) {
return mutate_binary_op(v, this);
}
Expr IRMutator::mutate(const Div* v) {
return mutate_binary_op(v, this);
}
Expr IRMutator::mutate(const Mod* v) {
return mutate_binary_op(v, this);
}
Expr IRMutator::mutate(const Max* v) {
return mutate_binary_op(v, this, v->propagate_nans());
}
Expr IRMutator::mutate(const Min* v) {
return mutate_binary_op(v, this, v->propagate_nans());
}
Expr IRMutator::mutate(const CompareSelect* v) {
Expr lhs = v->lhs();
Expr rhs = v->rhs();
Expr lhs_new = lhs.accept_mutator(this);
Expr rhs_new = rhs.accept_mutator(this);
if (same_node(lhs, lhs_new) && same_node(rhs, rhs_new)) {
return Expr(v);
}
return CompareSelect::make(lhs_new, rhs_new, v->compare_select_op());
}
Expr IRMutator::mutate(const IntImm* v) {
return Expr(v);
}
Expr IRMutator::mutate(const FloatImm* v) {
return Expr(v);
}
Expr IRMutator::mutate(const Cast* v) {
Expr src_value = v->src_value();
Expr src_value_new = src_value.accept_mutator(this);
if (same_node(src_value_new, v->src_value())) {
return Expr(v);
}
return Cast::make(v->dtype(), src_value_new);
}
Expr IRMutator::mutate(const Variable* v) {
return Expr(v);
}
Expr IRMutator::mutate(const Let* v) {
Expr var = v->var();
Expr value = v->value();
Expr body = v->body();
Expr var_new = var.accept_mutator(this);
Expr value_new = value.accept_mutator(this);
Expr body_new = body.accept_mutator(this);
if (same_node(var, var_new) && same_node(value, value_new) &&
same_node(body, body_new)) {
return Expr(v);
}
return Let::make(var_new, value_new, body_new);
}
Expr IRMutator::mutate(const Ramp* v) {
Expr base = v->base();
Expr stride = v->stride();
Expr base_new = base.accept_mutator(this);
Expr stride_new = stride.accept_mutator(this);
if (same_node(base, base_new) && same_node(stride, stride_new)) {
return Expr(v);
}
return Ramp::make(base_new, stride_new, v->lanes());
}
Expr IRMutator::mutate(const Load* v) {
Dtype dtype = v->dtype();
Var base_handle = v->base_handle();
Expr index = v->index();
Expr mask = v->mask();
Expr base_handle_expr = base_handle.accept_mutator(this);
Var base_handle_new = Var(base_handle_expr.AsNode<Variable>());
Expr index_new = index.accept_mutator(this);
Expr mask_new = mask.accept_mutator(this);
if (same_node(base_handle, base_handle_new) && same_node(index, index_new) &&
same_node(mask, mask_new)) {
return Expr(v);
}
return Load::make(dtype, base_handle_new, index_new, mask_new);
}
Expr IRMutator::mutate(const Broadcast* v) {
Expr value = v->value();
int lanes = v->lanes();
Expr value_new = value.accept_mutator(this);
if (same_node(value, value_new)) {
return Expr(v);
}
return Broadcast::make(value_new, lanes);
}
Expr IRMutator::mutate(const IfThenElse* v) {
Expr condition = v->condition();
Expr true_value = v->true_value();
Expr false_value = v->false_value();
Expr condition_new = condition.accept_mutator(this);
Expr true_value_new = true_value.accept_mutator(this);
Expr false_value_new = false_value.accept_mutator(this);
if (same_node(condition, condition_new) &&
same_node(true_value, true_value_new) &&
same_node(false_value, false_value_new)) {
return Expr(v);
}
return IfThenElse::make(condition_new, true_value_new, false_value_new);
}
Stmt IRMutator::mutate(const For* v) {
Var var = v->var();
Expr start = v->start();
Expr stop = v->stop();
Stmt body = v->body();
LoopOptions loop_options = v->loop_options();
Expr var_new_expr = var.accept_mutator(this);
Var var_new = Var(var_new_expr.AsNode<Variable>());
Expr start_new = start.accept_mutator(this);
Expr stop_new = stop.accept_mutator(this);
Stmt body_new = body.accept_mutator(this);
if (same_node(var, var_new) && same_node(start, start_new) &&
same_node(stop, stop_new) && same_node(body, body_new)) {
return Stmt(v);
}
return For::make(var_new, start_new, stop_new, body_new, loop_options);
}
Stmt IRMutator::mutate(const Block* v) {
bool any_change = false;
std::vector<Stmt> stmts;
for (int i = 0; i < v->nstmts(); i++) {
Stmt stmt = v->stmt(i);
Stmt stmt_new = stmt.accept_mutator(this);
if (!same_node(stmt, stmt_new)) {
any_change = true;
}
stmts.push_back(stmt_new);
}
if (!any_change) {
return Stmt(v);
}
return Block::make(stmts);
}
Stmt IRMutator::mutate(const Store* v) {
Var base_handle = v->base_handle();
Expr index = v->index();
Expr value = v->value();
Expr mask = v->mask();
Expr base_handle_expr = base_handle.accept_mutator(this);
Var base_handle_new = Var(base_handle_expr.AsNode<Variable>());
Expr index_new = index.accept_mutator(this);
Expr value_new = value.accept_mutator(this);
Expr mask_new = mask.accept_mutator(this);
if (same_node(base_handle, base_handle_new) && same_node(index, index_new) &&
same_node(value, value_new) && same_node(mask, mask_new)) {
return Stmt(v);
}
return Store::make(base_handle_new, index_new, value_new, mask_new);
}
Stmt IRMutator::mutate(const Allocate* v) {
Var buffer_var_old = v->buffer_var();
Var buffer_var_new =
Var(buffer_var_old.accept_mutator(this).AsNode<Variable>());
bool any_change = !same_node(buffer_var_new, buffer_var_old);
std::vector<Expr> dims_old = v->dims();
std::vector<Expr> dims_new(dims_old.size());
for (size_t i = 0; i < dims_old.size(); i++) {
dims_new[i] = dims_old[i].accept_mutator(this);
any_change |= !same_node(dims_new[i], dims_old[i]);
}
if (!any_change) {
return Stmt(v);
}
return Allocate::make(buffer_var_new, v->dtype(), dims_new);
}
Stmt IRMutator::mutate(const Free* v) {
Var buffer_var_old = v->buffer_var();
Var buffer_var_new =
Var(buffer_var_old.accept_mutator(this).AsNode<Variable>());
if (same_node(buffer_var_new, buffer_var_old)) {
return Stmt(v);
}
return Free::make(buffer_var_new);
}
Stmt IRMutator::mutate(const Cond* v) {
Expr cond_old = v->condition();
Stmt true_old = v->true_stmt();
Stmt false_old = v->false_stmt();
Expr cond_new = cond_old.accept_mutator(this);
Stmt true_new = true_old.accept_mutator(this);
Stmt false_new = false_old.accept_mutator(this);
if (same_node(cond_old, cond_new) && same_node(true_old, true_new) &&
same_node(false_old, false_new)) {
return Stmt(v);
}
return Cond::make(cond_new, true_new, false_new);
}
} // namespace tensorexpr
} // namespace jit
} // namespace torch

@@ -0,0 +1,68 @@
#pragma once
#include <torch/csrc/WindowsTorchApiMacro.h>
namespace torch {
namespace jit {
namespace tensorexpr {
class Add;
class Sub;
class Mul;
class Div;
class Mod;
class Max;
class Min;
class CompareSelect;
class IntImm;
class FloatImm;
class Cast;
class Variable;
class Let;
class Ramp;
class Load;
class For;
class Block;
class Store;
class Broadcast;
class IfThenElse;
class Expr;
class Stmt;
class BaseCallNode;
class FunctionCall;
class Allocate;
class Free;
class Cond;
class TORCH_API IRMutator {
public:
virtual ~IRMutator() {}
virtual Expr mutate(const Add* v);
virtual Expr mutate(const Sub* v);
virtual Expr mutate(const Mul* v);
virtual Expr mutate(const Div* v);
virtual Expr mutate(const Mod* v);
virtual Expr mutate(const Max* v);
virtual Expr mutate(const Min* v);
virtual Expr mutate(const CompareSelect* v);
virtual Expr mutate(const IntImm* v);
virtual Expr mutate(const FloatImm* v);
virtual Expr mutate(const Cast* v);
virtual Expr mutate(const Variable* v);
virtual Expr mutate(const Let* v);
virtual Expr mutate(const Ramp* v);
virtual Expr mutate(const Load* v);
virtual Expr mutate(const Broadcast* v);
virtual Expr mutate(const IfThenElse* v);
virtual Stmt mutate(const For* v);
virtual Stmt mutate(const Block* v);
virtual Stmt mutate(const Store* v);
virtual Stmt mutate(const Allocate* v);
virtual Stmt mutate(const Free* v);
virtual Stmt mutate(const Cond* v);
};
} // namespace tensorexpr
} // namespace jit
} // namespace torch
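
A subclass overrides only the node types it cares about; the defaults rebuild everything else and reuse unchanged subtrees. A sketch of a constant folder for integer additions, assuming torch/csrc/jit/tensorexpr/ir.h is included (FoldIntAdd is an illustrative name, not part of this commit):

class FoldIntAdd : public IRMutator {
 public:
  Expr mutate(const Add* v) override {
    Expr lhs = v->lhs();
    Expr rhs = v->rhs();
    Expr lhs_new = lhs.accept_mutator(this);
    Expr rhs_new = rhs.accept_mutator(this);
    const IntImm* l = lhs_new.AsNode<IntImm>();
    const IntImm* r = rhs_new.AsNode<IntImm>();
    if (l && r) {
      return IntImm::make(l->value() + r->value());  // fold 1 + 2 -> 3
    }
    if (same_node(lhs, lhs_new) && same_node(rhs, rhs_new)) {
      return Expr(v);  // nothing changed: reuse the original node
    }
    return Add::make(lhs_new, rhs_new);
  }
};

// Usage:
//   KernelScope kernel_scope;
//   FoldIntAdd folder;
//   Expr folded = Add::make(Expr(1), Expr(2)).accept_mutator(&folder);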

@@ -0,0 +1,126 @@
#include "torch/csrc/jit/tensorexpr/ir_visitor.h"
#include "torch/csrc/jit/tensorexpr/ir.h"
namespace torch {
namespace jit {
namespace tensorexpr {
template <typename Op>
static void visit_binary_op(const BinaryOpNode<Op>* v, IRVisitor* visitor) {
v->lhs().accept(visitor);
v->rhs().accept(visitor);
}
void IRVisitor::visit(const Add* v) {
visit_binary_op(v, this);
}
void IRVisitor::visit(const Sub* v) {
visit_binary_op(v, this);
}
void IRVisitor::visit(const Mul* v) {
visit_binary_op(v, this);
}
void IRVisitor::visit(const Div* v) {
visit_binary_op(v, this);
}
void IRVisitor::visit(const Mod* v) {
visit_binary_op(v, this);
}
void IRVisitor::visit(const Max* v) {
visit_binary_op(v, this);
}
void IRVisitor::visit(const Min* v) {
visit_binary_op(v, this);
}
void IRVisitor::visit(const CompareSelect* v) {
v->lhs().accept(this);
v->rhs().accept(this);
}
void IRVisitor::visit(const IntImm* v) {}
void IRVisitor::visit(const FloatImm* v) {}
void IRVisitor::visit(const Cast* v) {
v->src_value().accept(this);
}
void IRVisitor::visit(const Variable* v) {}
void IRVisitor::visit(const Let* v) {
v->var().accept(this);
v->value().accept(this);
v->body().accept(this);
}
void IRVisitor::visit(const Ramp* v) {
v->base().accept(this);
v->stride().accept(this);
}
void IRVisitor::visit(const Load* v) {
v->base_handle().accept(this);
v->index().accept(this);
v->mask().accept(this);
}
void IRVisitor::visit(const Store* v) {
v->base_handle().accept(this);
v->index().accept(this);
v->value().accept(this);
v->mask().accept(this);
}
void IRVisitor::visit(const Block* v) {
for (int i = 0; i < v->nstmts(); i++) {
v->stmt(i).accept(this);
}
}
void IRVisitor::visit(const For* v) {
v->var().accept(this);
v->start().accept(this);
v->stop().accept(this);
v->body().accept(this);
}
void IRVisitor::visit(const Broadcast* v) {
v->value().accept(this);
}
void IRVisitor::visit(const IfThenElse* v) {
v->condition().accept(this);
v->true_value().accept(this);
v->false_value().accept(this);
}
void IRVisitor::visit(const Allocate* v) {
Var buffer_var = v->buffer_var();
buffer_var.accept(this);
std::vector<Expr> dims = v->dims();
for (Expr& dim : dims) {
dim.accept(this);
}
}
void IRVisitor::visit(const Free* v) {
Var buffer_var = v->buffer_var();
buffer_var.accept(this);
}
void IRVisitor::visit(const Cond* v) {
Expr condition = v->condition();
Stmt true_stmt = v->true_stmt();
Stmt false_stmt = v->false_stmt();
condition.accept(this);
true_stmt.accept(this);
false_stmt.accept(this);
}
} // namespace tensorexpr
} // namespace jit
} // namespace torch

@@ -0,0 +1,71 @@
#pragma once
#include <torch/csrc/WindowsTorchApiMacro.h>
namespace torch {
namespace jit {
namespace tensorexpr {
class Add;
class Sub;
class Mul;
class Div;
class Mod;
class Max;
class Min;
class CompareSelect;
class IntImm;
class FloatImm;
class Cast;
class Variable;
class Let;
class Ramp;
class Load;
class For;
class Block;
class Store;
class Broadcast;
class IfThenElse;
class BaseCallNode;
class FunctionCall;
class Allocate;
class Free;
class Cond;
class TORCH_API IRVisitor {
public:
virtual ~IRVisitor() {}
virtual void visit(const Add* v);
virtual void visit(const Sub* v);
virtual void visit(const Mul* v);
virtual void visit(const Div* v);
virtual void visit(const Mod* v);
virtual void visit(const Max* v);
virtual void visit(const Min* v);
virtual void visit(const CompareSelect* v);
virtual void visit(const IntImm* v);
virtual void visit(const FloatImm* v);
virtual void visit(const Cast* v);
virtual void visit(const Variable* v);
virtual void visit(const Let* v);
virtual void visit(const Ramp* v);
virtual void visit(const Load* v);
virtual void visit(const For* v);
virtual void visit(const Block* v);
virtual void visit(const Store* v);
virtual void visit(const Broadcast* v);
virtual void visit(const IfThenElse* v);
// BaseCallNode is the base class for all call nodes.
// For a visitor that only needs the common behavior, overriding the
// BaseCallNode handler alone is enough, because all derived-class
// handlers call it by default.
// Override a derived-class handler only when the logic is specific to
// that node.
virtual void visit(const Allocate* v);
virtual void visit(const Free* v);
virtual void visit(const Cond* v);
};
} // namespace tensorexpr
} // namespace jit
} // namespace torch
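
Because every default method recurses, a subclass only overrides the nodes it is interested in. A sketch of a visitor that counts Load nodes in a tree (LoadCounter is an illustrative name, not part of this commit):

class LoadCounter : public IRVisitor {
 public:
  void visit(const Load* v) override {
    count_++;
    IRVisitor::visit(v);  // keep walking the index and mask subexpressions
  }
  int count() const {
    return count_;
  }
 private:
  int count_ = 0;
};

// Usage:
//   LoadCounter counter;
//   stmt.accept(&counter);  // works for Expr trees as well
//   int num_loads = counter.count();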