mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[TensorExpr] Add IR visitor, IR mutator, and IR evaluator. (#33219)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/33219 Test Plan: Imported from OSS Differential Revision: D19848381 Pulled By: ZolotukhinM fbshipit-source-id: 44ca7cd99c25e290a8ffd8146785c19f9c785dfd
This commit is contained in:
committed by
Facebook Github Bot
parent
49af9425a7
commit
fc70fc3610
@ -456,9 +456,13 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
|
||||
${TORCH_SRC_DIR}/csrc/jit/vararg_functions.cpp
|
||||
|
||||
${TORCH_SRC_DIR}/csrc/jit/tensorexpr/mem_arena.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/tensorexpr/codegen.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/tensorexpr/expr.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/tensorexpr/eval.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/tensorexpr/ir.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/tensorexpr/ir_visitor.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/tensorexpr/types.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/tensorexpr/ir_mutator.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/tensorexpr/unique_name_manager.cpp
|
||||
)
|
||||
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
#include "test/cpp/tensorexpr/test_utils.h"
|
||||
#include "torch/csrc/jit/tensorexpr/buffer.h"
|
||||
#include "torch/csrc/jit/tensorexpr/eval.h"
|
||||
#include "torch/csrc/jit/tensorexpr/ir.h"
|
||||
|
||||
#include <cmath>
|
||||
@ -14,6 +15,53 @@ namespace torch {
|
||||
namespace jit {
|
||||
using namespace torch::jit::tensorexpr;
|
||||
|
||||
using SimpleIRExprEval = ExprEval<SimpleIREvaluator>;
|
||||
|
||||
void testExprBasicValueTest() {
|
||||
KernelScope kernel_scope;
|
||||
Expr a = IntImm::make(2), b = IntImm::make(3);
|
||||
Expr c = Add::make(a, b);
|
||||
SimpleIRExprEval eval(c);
|
||||
EXPECT_EQ(eval.value<int>(), 5);
|
||||
}
|
||||
|
||||
void testExprBasicValueTest02() {
|
||||
KernelScope kernel_scope;
|
||||
Expr a(2.0f);
|
||||
Expr b(3.0f);
|
||||
Expr c(4.0f);
|
||||
Expr d(5.0f);
|
||||
Expr f = (a + b) - (c + d);
|
||||
SimpleIRExprEval eval(f);
|
||||
EXPECT_EQ(eval.value<float>(), -4.0f);
|
||||
}
|
||||
|
||||
void testExprLetTest01() {
|
||||
KernelScope kernel_scope;
|
||||
Var x("x", kFloat32);
|
||||
Expr value = Expr(3.f);
|
||||
Expr body = Expr(2.f) + (x * Expr(3.f) + Expr(4.f));
|
||||
Expr result = Let::make(x, Expr(3.f), body);
|
||||
SimpleIRExprEval eval(result);
|
||||
EXPECT_EQ(eval.value<float>(), 2 + (3 * 3 + 4));
|
||||
}
|
||||
|
||||
void testExprLetTest02() {
|
||||
KernelScope kernel_scope;
|
||||
Var x("x", kFloat32);
|
||||
Var y("y", kFloat32);
|
||||
Expr value = Expr(3.f);
|
||||
Expr body = Expr(2.f) + (x * Expr(3.f) + Expr(4.f) * y);
|
||||
Expr e1 = Let::make(x, Expr(3.f), body);
|
||||
Expr e2 = Let::make(y, Expr(6.f), e1);
|
||||
SimpleIRExprEval eval(e2);
|
||||
EXPECT_EQ(eval.value<float>(), 2 + (3 * 3 + 4 * 6));
|
||||
}
|
||||
|
||||
static Expr test_01(const Expr& expr) {
|
||||
return expr;
|
||||
}
|
||||
|
||||
void testExprVectorAdd01() {
|
||||
KernelScope kernel_scope;
|
||||
const int kVectorSize = 8;
|
||||
@ -54,5 +102,63 @@ void testExprVectorAdd01() {
|
||||
EXPECT_EQ(value.dtype(), Dtype(kFloat32, kVectorSize));
|
||||
}
|
||||
|
||||
void testExprCompareSelectEQ() {
|
||||
KernelScope kernel_scope;
|
||||
constexpr int N = 1024;
|
||||
Buffer a(Var("A", kHandle), kInt32, {N});
|
||||
Buffer b(Var("B", kHandle), kInt32, {N});
|
||||
Buffer c(Var("C", kHandle), kInt32, {N});
|
||||
std::vector<int> a_buffer(N, 1);
|
||||
std::vector<int> b_buffer(N, 1);
|
||||
std::vector<int> c_buffer(N, 0);
|
||||
std::vector<int> c_ref(N, 0);
|
||||
|
||||
auto mask = IntImm::make(1);
|
||||
Var i("i", kInt32);
|
||||
auto memcpy_expr = For::make(
|
||||
i,
|
||||
0,
|
||||
N,
|
||||
Store::make(
|
||||
c,
|
||||
i,
|
||||
CompareSelect::make(
|
||||
Load::make(a, i, mask),
|
||||
Load::make(b, i, mask),
|
||||
CompareSelectOperation::kEQ),
|
||||
mask));
|
||||
|
||||
SimpleIREvaluator ir_eval(memcpy_expr, a, b, c);
|
||||
ir_eval(a_buffer, b_buffer, c_buffer);
|
||||
|
||||
ASSERT_EQ(a_buffer.size(), N);
|
||||
ASSERT_EQ(b_buffer.size(), N);
|
||||
ASSERT_EQ(c_buffer.size(), N);
|
||||
|
||||
assertAllEqual(a_buffer, 1);
|
||||
assertAllEqual(b_buffer, 1);
|
||||
assertAllEqual(c_buffer, 1);
|
||||
}
|
||||
|
||||
void testExprDynamicShapeAdd() {
|
||||
KernelScope kernel_scope;
|
||||
auto testWithSize = [](int32_t size) {
|
||||
Var n("n", kInt32);
|
||||
Buffer a(Var("a", kHandle), kFloat32, {n});
|
||||
Buffer b(Var("b", kHandle), kFloat32, {n});
|
||||
Buffer c(Var("c", kHandle), kFloat32, {n});
|
||||
Var i("i", kInt32);
|
||||
Stmt s = For::make(i, 0, n, Store::make(c, i, a(i) + b(i), 1));
|
||||
std::vector<float> aData(size, 1.0f);
|
||||
std::vector<float> bData(size, 2.0f);
|
||||
std::vector<float> cData(size, 0.0f);
|
||||
SimpleIREvaluator(s, a, b, c, n)(aData, bData, cData, size);
|
||||
ExpectAllNear(cData, std::vector<float>(size, 3.0f), 1e-7);
|
||||
};
|
||||
testWithSize(1);
|
||||
testWithSize(16);
|
||||
testWithSize(37);
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
@ -9,7 +9,13 @@
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
#define TH_FORALL_TESTS(_) \
|
||||
_(ExprBasicValueTest) \
|
||||
_(ExprBasicValueTest02) \
|
||||
_(ExprLetTest01) \
|
||||
_(ExprLetTest02) \
|
||||
_(ExprVectorAdd01) \
|
||||
_(ExprCompareSelectEQ) \
|
||||
_(ExprDynamicShapeAdd) \
|
||||
_(TypeTest01) \
|
||||
|
||||
#define TH_FORALL_TESTS_CUDA(_) \
|
||||
|
@ -190,16 +190,24 @@ libtorch_sources = [
|
||||
"torch/csrc/jit/mobile/register_mobile_ops.cpp",
|
||||
"torch/csrc/jit/mobile/interpreter.cpp",
|
||||
"torch/csrc/jit/mobile/type_parser.cpp",
|
||||
"torch/csrc/jit/tensorexpr/codegen.cpp",
|
||||
"torch/csrc/jit/tensorexpr/eval.cpp",
|
||||
"torch/csrc/jit/tensorexpr/expr.cpp",
|
||||
"torch/csrc/jit/tensorexpr/ir.cpp",
|
||||
"torch/csrc/jit/tensorexpr/ir_mutator.cpp",
|
||||
"torch/csrc/jit/tensorexpr/ir_visitor.cpp",
|
||||
"torch/csrc/jit/tensorexpr/mem_arena.cpp",
|
||||
"torch/csrc/jit/tensorexpr/types.cpp",
|
||||
"torch/csrc/jit/tensorexpr/unique_name_manager.cpp",
|
||||
"torch/csrc/utils/byte_order.cpp",
|
||||
"torch/csrc/utils/tensor_flatten.cpp",
|
||||
"torch/csrc/utils/variadic.cpp",
|
||||
"torch/csrc/jit/tensorexpr/codegen.cpp",
|
||||
"torch/csrc/jit/tensorexpr/eval.cpp",
|
||||
"torch/csrc/jit/tensorexpr/expr.cpp",
|
||||
"torch/csrc/jit/tensorexpr/ir.cpp",
|
||||
"torch/csrc/jit/tensorexpr/ir_mutator.cpp",
|
||||
"torch/csrc/jit/tensorexpr/ir_visitor.cpp",
|
||||
"torch/csrc/jit/tensorexpr/mem_arena.cpp",
|
||||
"torch/csrc/jit/tensorexpr/types.cpp",
|
||||
"torch/csrc/jit/tensorexpr/unique_name_manager.cpp",
|
||||
|
51
torch/csrc/jit/tensorexpr/codegen.cpp
Normal file
51
torch/csrc/jit/tensorexpr/codegen.cpp
Normal file
@ -0,0 +1,51 @@
|
||||
#include "torch/csrc/jit/tensorexpr/codegen.h"
|
||||
|
||||
#include <sstream>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
namespace tensorexpr {
|
||||
|
||||
RegisterCodeGenList::StmtFactoryMethod RegisterCodeGenList::
|
||||
FindStmtFactoryMethod(const std::string& name) {
|
||||
auto iter = stmt_factory_methods_.find(name);
|
||||
if (iter == stmt_factory_methods_.end()) {
|
||||
std::ostringstream oss;
|
||||
oss << "Invalid stmt codegen name: " << name << ". ";
|
||||
oss << "Existing codegen names: [";
|
||||
int index = 0;
|
||||
for (const auto& entry : stmt_factory_methods_) {
|
||||
if (index != 0) {
|
||||
oss << ", ";
|
||||
}
|
||||
oss << entry.first;
|
||||
index++;
|
||||
}
|
||||
oss << "]";
|
||||
throw std::runtime_error(oss.str());
|
||||
}
|
||||
return iter->second;
|
||||
}
|
||||
|
||||
void RegisterCodeGenList::AddStmtFactoryMethod(
|
||||
const std::string& name,
|
||||
const StmtFactoryMethod& stmt_factory_method) {
|
||||
auto insert_ret =
|
||||
stmt_factory_methods_.insert(std::make_pair(name, stmt_factory_method));
|
||||
if (!insert_ret.second) {
|
||||
throw std::runtime_error("Duplicated CodeGen names: " + name);
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<CodeGen> CreateCodeGen(
|
||||
const std::string& name,
|
||||
const Stmt& stmt,
|
||||
const std::vector<CodeGen::BufferArg>& params) {
|
||||
RegisterCodeGenList::StmtFactoryMethod method =
|
||||
RegisterCodeGenList::GetInstance().FindStmtFactoryMethod(name);
|
||||
return method(stmt, params);
|
||||
}
|
||||
|
||||
} // namespace tensorexpr
|
||||
} // namespace jit
|
||||
} // namespace torch
|
157
torch/csrc/jit/tensorexpr/codegen.h
Normal file
157
torch/csrc/jit/tensorexpr/codegen.h
Normal file
@ -0,0 +1,157 @@
|
||||
#pragma once
|
||||
|
||||
#include "torch/csrc/jit/tensorexpr/buffer.h"
|
||||
#include "torch/csrc/jit/tensorexpr/ir.h"
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
namespace tensorexpr {
|
||||
|
||||
class CodeGen {
|
||||
public:
|
||||
class BufferArg;
|
||||
class CallArg;
|
||||
|
||||
template <typename... Ts>
|
||||
CodeGen(const Stmt& stmt, Ts... ts)
|
||||
: stmt_(stmt), buffer_args_({BufferArg(ts)...}) {}
|
||||
|
||||
CodeGen(const Stmt& stmt, const std::vector<BufferArg>& buffer_args)
|
||||
: stmt_(stmt), buffer_args_(buffer_args) {}
|
||||
|
||||
virtual ~CodeGen() {}
|
||||
|
||||
const Stmt& stmt() const {
|
||||
return stmt_;
|
||||
}
|
||||
|
||||
std::vector<BufferArg>& buffer_args() {
|
||||
return buffer_args_;
|
||||
}
|
||||
|
||||
const std::vector<BufferArg>& buffer_args() const {
|
||||
return buffer_args_;
|
||||
}
|
||||
|
||||
TORCH_API virtual void call(const std::vector<CallArg>& args) {
|
||||
LOG(FATAL) << "unimplemented call";
|
||||
}
|
||||
|
||||
private:
|
||||
Stmt stmt_;
|
||||
std::vector<BufferArg> buffer_args_;
|
||||
};
|
||||
|
||||
class CodeGen::BufferArg {
|
||||
public:
|
||||
BufferArg(const Buffer& buffer)
|
||||
: var_(buffer.data()), dtype_(buffer.dtype()) {}
|
||||
BufferArg(const Var& var) : var_(var), dtype_(var.dtype()), isVar_(true) {}
|
||||
|
||||
const Var& var() const {
|
||||
return var_;
|
||||
}
|
||||
Var& var() {
|
||||
return var_;
|
||||
}
|
||||
Dtype dtype() const {
|
||||
return dtype_;
|
||||
}
|
||||
|
||||
bool isVar() const {
|
||||
return isVar_;
|
||||
}
|
||||
|
||||
private:
|
||||
Var var_;
|
||||
Dtype dtype_;
|
||||
bool isVar_{false};
|
||||
};
|
||||
|
||||
class CodeGen::CallArg {
|
||||
public:
|
||||
template <typename T>
|
||||
CallArg(const std::vector<T>& buffer) : ptr_(const_cast<T*>(buffer.data())) {}
|
||||
|
||||
CallArg(void* ptr) : ptr_(ptr) {}
|
||||
|
||||
CallArg(int32_t i) : ival_(i) {}
|
||||
|
||||
CallArg(float f) : fval_(f) {}
|
||||
|
||||
void* data() const {
|
||||
return ptr_;
|
||||
}
|
||||
|
||||
int32_t intData() const {
|
||||
return ival_;
|
||||
}
|
||||
|
||||
float floatData() const {
|
||||
return fval_;
|
||||
}
|
||||
|
||||
int* intPtr() const {
|
||||
return const_cast<int*>(&ival_);
|
||||
}
|
||||
|
||||
float* floatPtr() const {
|
||||
return const_cast<float*>(&fval_);
|
||||
}
|
||||
|
||||
private:
|
||||
union {
|
||||
void* ptr_;
|
||||
float fval_;
|
||||
int32_t ival_;
|
||||
};
|
||||
};
|
||||
|
||||
class RegisterCodeGenList {
|
||||
public:
|
||||
TORCH_API static RegisterCodeGenList& GetInstance() {
|
||||
static RegisterCodeGenList codegen_list;
|
||||
return codegen_list;
|
||||
}
|
||||
|
||||
using StmtFactoryMethod = std::function<std::unique_ptr<CodeGen>(
|
||||
const Stmt& stmt,
|
||||
const std::vector<CodeGen::BufferArg>&)>;
|
||||
|
||||
TORCH_API StmtFactoryMethod FindStmtFactoryMethod(const std::string& name);
|
||||
|
||||
private:
|
||||
template <class CodeGenType>
|
||||
friend class RegisterCodeGen;
|
||||
RegisterCodeGenList() {}
|
||||
TORCH_API void AddStmtFactoryMethod(
|
||||
const std::string& name,
|
||||
const StmtFactoryMethod& stmt_factory_method);
|
||||
RegisterCodeGenList(const RegisterCodeGenList&) = delete;
|
||||
RegisterCodeGenList& operator=(const RegisterCodeGenList&) = delete;
|
||||
|
||||
std::unordered_map<std::string, StmtFactoryMethod> stmt_factory_methods_;
|
||||
};
|
||||
|
||||
template <class CodeGenType>
|
||||
class RegisterCodeGen {
|
||||
public:
|
||||
explicit RegisterCodeGen(const std::string& name) {
|
||||
RegisterCodeGenList& codegen_list = RegisterCodeGenList::GetInstance();
|
||||
codegen_list.AddStmtFactoryMethod(
|
||||
name,
|
||||
[](const Stmt& stmt, const std::vector<CodeGen::BufferArg>& params) {
|
||||
std::unique_ptr<CodeGen> method(new CodeGenType(stmt, params));
|
||||
return method;
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
TORCH_API std::unique_ptr<CodeGen> CreateCodeGen(
|
||||
const std::string& name,
|
||||
const Stmt& stmt,
|
||||
const std::vector<CodeGen::BufferArg>& params);
|
||||
|
||||
} // namespace tensorexpr
|
||||
} // namespace jit
|
||||
} // namespace torch
|
11
torch/csrc/jit/tensorexpr/eval.cpp
Normal file
11
torch/csrc/jit/tensorexpr/eval.cpp
Normal file
@ -0,0 +1,11 @@
|
||||
#include "torch/csrc/jit/tensorexpr/eval.h"
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
namespace tensorexpr {
|
||||
|
||||
RegisterCodeGen<SimpleIREvaluator> reg("simple_ir_eval");
|
||||
|
||||
} // namespace tensorexpr
|
||||
} // namespace jit
|
||||
} // namespace torch
|
606
torch/csrc/jit/tensorexpr/eval.h
Normal file
606
torch/csrc/jit/tensorexpr/eval.h
Normal file
@ -0,0 +1,606 @@
|
||||
#pragma once
|
||||
|
||||
#include <cmath>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include <c10/util/Logging.h>
|
||||
#include "torch/csrc/jit/tensorexpr/buffer.h"
|
||||
#include "torch/csrc/jit/tensorexpr/codegen.h"
|
||||
#include "torch/csrc/jit/tensorexpr/ir.h"
|
||||
#include "torch/csrc/jit/tensorexpr/types.h"
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
namespace tensorexpr {
|
||||
|
||||
class Value {
|
||||
public:
|
||||
Value() : dtype_(kInt32) {
|
||||
i32_values.push_back(0);
|
||||
}
|
||||
Value(int v) : dtype_(kInt32) {
|
||||
i32_values.push_back(v);
|
||||
}
|
||||
Value(float v) : dtype_(kFloat32) {
|
||||
f32_values.push_back(v);
|
||||
}
|
||||
Value(const std::vector<int>& v)
|
||||
: dtype_(Dtype(kInt32, v.size())), i32_values(v) {}
|
||||
Value(const std::vector<float>& v)
|
||||
: dtype_(Dtype(kFloat32, v.size())), f32_values(v) {}
|
||||
|
||||
template <typename T>
|
||||
T as() const;
|
||||
|
||||
template <typename T>
|
||||
const std::vector<T>& as_vec() const;
|
||||
|
||||
Dtype dtype() const {
|
||||
return dtype_;
|
||||
}
|
||||
|
||||
private:
|
||||
Dtype dtype_;
|
||||
std::vector<int32> i32_values;
|
||||
std::vector<float> f32_values;
|
||||
void* ptr;
|
||||
};
|
||||
|
||||
template <>
|
||||
inline int Value::as<int>() const {
|
||||
CHECK_EQ(dtype_, kInt32) << "invalid dtype";
|
||||
return i32_values[0];
|
||||
}
|
||||
|
||||
template <>
|
||||
inline float Value::as<float>() const {
|
||||
CHECK_EQ(dtype_, kFloat32) << "invalid dtype";
|
||||
return f32_values[0];
|
||||
}
|
||||
|
||||
template <>
|
||||
inline const std::vector<float>& Value::as_vec<float>() const {
|
||||
CHECK_EQ(dtype_.scalar_type(), kFloat32) << "invalid dtype";
|
||||
return f32_values;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline const std::vector<int>& Value::as_vec<int>() const {
|
||||
CHECK_EQ(dtype_.scalar_type(), kInt32) << "invalid dtype";
|
||||
return i32_values;
|
||||
}
|
||||
|
||||
inline int mod_value(int lhs, int rhs) {
|
||||
return lhs % rhs;
|
||||
}
|
||||
|
||||
inline float mod_value(float lhs, float rhs) {
|
||||
return std::fmod(lhs, rhs);
|
||||
}
|
||||
|
||||
class SimpleIREvaluator : public CodeGen, public IRVisitor {
|
||||
public:
|
||||
using CodeGen::CodeGen;
|
||||
|
||||
~SimpleIREvaluator() override {}
|
||||
|
||||
TORCH_API void call(const std::vector<CallArg>& args) override {
|
||||
CHECK_EQ(args.size(), buffer_args().size());
|
||||
for (size_t i = 0; i < args.size(); i++) {
|
||||
bind(buffer_args()[i], args[i]);
|
||||
}
|
||||
stmt().accept(this);
|
||||
eval_context_.clear();
|
||||
buffer_mapping_.clear();
|
||||
internal_buffers_.clear();
|
||||
}
|
||||
|
||||
void bind(const BufferArg& buf, const CallArg& data) {
|
||||
if (buf.isVar()) {
|
||||
if (buf.dtype() == kInt32) {
|
||||
eval_context_[buf.var().node()] = data.intData();
|
||||
} else if (buf.dtype() == kFloat32) {
|
||||
eval_context_[buf.var().node()] = data.floatData();
|
||||
} else {
|
||||
LOG(FATAL) << "Unhandled dtype for argument " << buf.var().name_hint()
|
||||
<< ": " << buf.dtype();
|
||||
}
|
||||
} else {
|
||||
buffer_mapping_[buf.var().node()] = data.data();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename... Ts>
|
||||
void operator()(const Ts&... ts) {
|
||||
std::vector<CallArg> args({CallArg(ts)...});
|
||||
call(args);
|
||||
}
|
||||
|
||||
TORCH_API void visit(const Add* v) override {
|
||||
visit_binary_op(v);
|
||||
}
|
||||
TORCH_API void visit(const Sub* v) override {
|
||||
visit_binary_op(v);
|
||||
}
|
||||
TORCH_API void visit(const Mul* v) override {
|
||||
visit_binary_op(v);
|
||||
}
|
||||
TORCH_API void visit(const Div* v) override {
|
||||
visit_binary_op(v);
|
||||
}
|
||||
TORCH_API void visit(const Mod* v) override {
|
||||
visit_binary_op(v);
|
||||
}
|
||||
TORCH_API void visit(const Max* v) override {
|
||||
visit_binary_op(v, v->propagate_nans());
|
||||
}
|
||||
TORCH_API void visit(const Min* v) override {
|
||||
visit_binary_op(v, v->propagate_nans());
|
||||
}
|
||||
|
||||
void visit(const CompareSelect* v) override {
|
||||
visit_compare_select_op(v, v->compare_select_op());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Value binary_op(
|
||||
const Value& lhs,
|
||||
const Value& rhs,
|
||||
IRNodeType op_type,
|
||||
bool option = false) {
|
||||
std::vector<T> lhs_v = lhs.as_vec<T>();
|
||||
std::vector<T> rhs_v = rhs.as_vec<T>();
|
||||
std::vector<T> result_v(lhs_v.size());
|
||||
for (size_t i = 0; i < lhs_v.size(); i++) {
|
||||
switch (op_type) {
|
||||
case IRNodeType::kAdd:
|
||||
result_v[i] = lhs_v[i] + rhs_v[i];
|
||||
break;
|
||||
case IRNodeType::kSub:
|
||||
result_v[i] = lhs_v[i] - rhs_v[i];
|
||||
break;
|
||||
case IRNodeType::kMul:
|
||||
result_v[i] = lhs_v[i] * rhs_v[i];
|
||||
break;
|
||||
case IRNodeType::kDiv:
|
||||
result_v[i] = lhs_v[i] / rhs_v[i];
|
||||
break;
|
||||
case IRNodeType::kMod:
|
||||
result_v[i] = mod_value(lhs_v[i], rhs_v[i]);
|
||||
break;
|
||||
case IRNodeType::kMax:
|
||||
if (lhs.dtype() == kFloat32 && rhs.dtype() == kFloat32 && option) {
|
||||
// Propagate NaNs
|
||||
if (std::isnan((float)lhs_v[i])) {
|
||||
result_v[i] = lhs_v[i];
|
||||
} else if (std::isnan((float)rhs_v[i])) {
|
||||
result_v[i] = rhs_v[i];
|
||||
}
|
||||
} else {
|
||||
result_v[i] = lhs_v[i] > rhs_v[i] ? lhs_v[i] : rhs_v[i];
|
||||
}
|
||||
break;
|
||||
case IRNodeType::kMin:
|
||||
if (lhs.dtype() == kFloat32 && rhs.dtype() == kFloat32 && option) {
|
||||
// Propagate NaNs
|
||||
if (std::isnan((float)lhs_v[i])) {
|
||||
result_v[i] = lhs_v[i];
|
||||
} else if (std::isnan((float)rhs_v[i])) {
|
||||
result_v[i] = rhs_v[i];
|
||||
}
|
||||
} else {
|
||||
result_v[i] = lhs_v[i] < rhs_v[i] ? lhs_v[i] : rhs_v[i];
|
||||
}
|
||||
break;
|
||||
default:
|
||||
// TODO: change to a proper error report
|
||||
throw std::runtime_error("invalid operator type");
|
||||
}
|
||||
}
|
||||
return Value(result_v);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Value compare_select_op(
|
||||
const Value& lhs,
|
||||
const Value& rhs,
|
||||
CompareSelectOperation cmp_op) {
|
||||
std::vector<T> lhs_v = lhs.as_vec<T>();
|
||||
std::vector<T> rhs_v = rhs.as_vec<T>();
|
||||
std::vector<int> result_v(lhs_v.size());
|
||||
for (size_t i = 0; i < lhs_v.size(); i++) {
|
||||
switch (cmp_op) {
|
||||
case CompareSelectOperation::kEQ:
|
||||
result_v[i] = (lhs_v[i] == rhs_v[i]) ? 1 : 0;
|
||||
break;
|
||||
case CompareSelectOperation::kNE:
|
||||
result_v[i] = (lhs_v[i] != rhs_v[i]) ? 1 : 0;
|
||||
break;
|
||||
case CompareSelectOperation::kGT:
|
||||
result_v[i] = (lhs_v[i] > rhs_v[i]) ? 1 : 0;
|
||||
break;
|
||||
case CompareSelectOperation::kGE:
|
||||
result_v[i] = (lhs_v[i] >= rhs_v[i]) ? 1 : 0;
|
||||
break;
|
||||
case CompareSelectOperation::kLT:
|
||||
result_v[i] = (lhs_v[i] < rhs_v[i]) ? 1 : 0;
|
||||
break;
|
||||
case CompareSelectOperation::kLE:
|
||||
result_v[i] = (lhs_v[i] <= rhs_v[i]) ? 1 : 0;
|
||||
break;
|
||||
default:
|
||||
// TODO: change to a proper error report
|
||||
throw std::runtime_error("invalid operator type");
|
||||
}
|
||||
}
|
||||
return Value(result_v);
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void visit_binary_op(const BinaryOpNode<Op>* v, bool option = false) {
|
||||
v->lhs().accept(this);
|
||||
Value lhs_v = value_;
|
||||
v->rhs().accept(this);
|
||||
Value rhs_v = value_;
|
||||
CHECK_EQ(lhs_v.dtype(), rhs_v.dtype());
|
||||
IRNodeType expr_type = v->expr_type();
|
||||
if (lhs_v.dtype().scalar_type() == kFloat32) {
|
||||
value_ = binary_op<float>(lhs_v, rhs_v, expr_type);
|
||||
} else if (lhs_v.dtype().scalar_type() == kInt32) {
|
||||
value_ = binary_op<int>(lhs_v, rhs_v, expr_type);
|
||||
} else {
|
||||
LOG(FATAL) << "invalid dtype: " << lhs_v.dtype();
|
||||
}
|
||||
}
|
||||
|
||||
void visit_compare_select_op(
|
||||
const CompareSelect* v,
|
||||
CompareSelectOperation cmp_op) {
|
||||
v->lhs().accept(this);
|
||||
Value lhs_v = value_;
|
||||
v->rhs().accept(this);
|
||||
Value rhs_v = value_;
|
||||
CHECK_EQ(lhs_v.dtype(), rhs_v.dtype());
|
||||
if (lhs_v.dtype().scalar_type() == kFloat32) {
|
||||
value_ = compare_select_op<float>(lhs_v, rhs_v, cmp_op);
|
||||
} else if (lhs_v.dtype().scalar_type() == kInt32) {
|
||||
value_ = compare_select_op<int>(lhs_v, rhs_v, cmp_op);
|
||||
} else {
|
||||
LOG(FATAL) << "invalid dtype: " << lhs_v.dtype();
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_API void visit(const IntImm* v) override {
|
||||
value_ = Value(v->value());
|
||||
}
|
||||
TORCH_API void visit(const FloatImm* v) override {
|
||||
value_ = Value(v->value());
|
||||
}
|
||||
|
||||
TORCH_API void visit(const Let* v) override {
|
||||
const Variable* var = v->var().AsNode<Variable>();
|
||||
CHECK(var != nullptr);
|
||||
v->value().accept(this);
|
||||
Value value = value_;
|
||||
auto iter = eval_context_.find(var);
|
||||
// TODO: make the same value settable multiple times.
|
||||
CHECK(iter == eval_context_.end())
|
||||
<< "var must not exist in the context before";
|
||||
eval_context_[var] = value_;
|
||||
|
||||
v->body().accept(this);
|
||||
|
||||
eval_context_.erase(var);
|
||||
}
|
||||
|
||||
TORCH_API void visit(const Variable* v) override {
|
||||
auto iter = eval_context_.find(v);
|
||||
CHECK(iter != eval_context_.end())
|
||||
<< "var must be defined in the context before";
|
||||
value_ = iter->second;
|
||||
}
|
||||
|
||||
TORCH_API void visit(const Cast* v) override {
|
||||
const Expr& src_value = v->src_value();
|
||||
src_value.accept(this);
|
||||
Dtype dst_dtype = v->dtype();
|
||||
Dtype src_dtype = src_value.dtype();
|
||||
CHECK_EQ(src_dtype.lanes(), dst_dtype.lanes());
|
||||
if (src_dtype != dst_dtype) {
|
||||
if (src_dtype == kFloat32 && dst_dtype == kInt32) {
|
||||
const std::vector<float>& src_values = value_.as_vec<float>();
|
||||
std::vector<int> dst_values(src_values.size());
|
||||
for (int i = 0; i < src_dtype.lanes(); ++i) {
|
||||
dst_values[i] = static_cast<int>(src_values[i]);
|
||||
}
|
||||
this->value_ = Value(dst_values);
|
||||
} else if (src_dtype == kInt32 && dst_dtype == kFloat32) {
|
||||
const std::vector<int>& src_values = value_.as_vec<int>();
|
||||
std::vector<float> dst_values(src_values.size());
|
||||
for (int i = 0; i < src_dtype.lanes(); ++i) {
|
||||
dst_values[i] = static_cast<float>(src_values[i]);
|
||||
}
|
||||
this->value_ = Value(dst_values);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_API void visit(const For* v) override {
|
||||
const BaseExprNode* var_node = v->var().node();
|
||||
v->start().accept(this);
|
||||
int start = value_.as<int>();
|
||||
v->stop().accept(this);
|
||||
int stop = value_.as<int>();
|
||||
auto iter = eval_context_.find(var_node);
|
||||
CHECK(iter == eval_context_.end())
|
||||
<< "var in For must not exist in eval context";
|
||||
for (int i = start; i < stop; i++) {
|
||||
eval_context_[var_node] = Value(i);
|
||||
v->body().accept(this);
|
||||
}
|
||||
eval_context_.erase(var_node);
|
||||
}
|
||||
|
||||
TORCH_API void visit(const Ramp* v) override {
|
||||
v->base().accept(this);
|
||||
int base = value().as<int>();
|
||||
v->stride().accept(this);
|
||||
int stride = value().as<int>();
|
||||
int lanes = v->lanes();
|
||||
|
||||
std::vector<int> values(lanes);
|
||||
for (int i = 0; i < lanes; i++) {
|
||||
values[i] = base + i * stride;
|
||||
}
|
||||
|
||||
value_ = Value(values);
|
||||
}
|
||||
|
||||
TORCH_API void visit(const Broadcast* v) override {
|
||||
v->value().accept(this);
|
||||
Value value = this->value();
|
||||
int lanes = v->lanes();
|
||||
if (value.dtype() == kInt32) {
|
||||
std::vector<int> v(lanes, value.as<int>());
|
||||
value_ = Value(v);
|
||||
} else if (value.dtype() == kFloat32) {
|
||||
std::vector<float> v(lanes, value.as<float>());
|
||||
value_ = Value(v);
|
||||
} else {
|
||||
LOG(FATAL) << "invalid dtype: " << value.dtype();
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_API void visit(const IfThenElse* v) override {
|
||||
v->condition().accept(this);
|
||||
if (value_.as<int>()) {
|
||||
v->true_value().accept(this);
|
||||
} else {
|
||||
v->false_value().accept(this);
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_API void visit(const Load* v) override {
|
||||
const Variable* base_node = v->base_handle().node();
|
||||
auto iter = buffer_mapping_.find(base_node);
|
||||
CHECK(iter != buffer_mapping_.end())
|
||||
<< "missing buffer binding: " << base_node->name_hint();
|
||||
void* ptr = iter->second;
|
||||
|
||||
v->index().accept(this);
|
||||
std::vector<int> index = value().as_vec<int>();
|
||||
v->mask().accept(this);
|
||||
std::vector<int> mask = value().as_vec<int>();
|
||||
Dtype v_sdtype = v->dtype().scalar_type();
|
||||
if (v_sdtype == kFloat32) {
|
||||
float* ptr_f = static_cast<float*>(ptr);
|
||||
std::vector<float> v(index.size());
|
||||
for (size_t i = 0; i < index.size(); i++) {
|
||||
if (mask[i]) {
|
||||
v[i] = ptr_f[index[i]];
|
||||
}
|
||||
}
|
||||
value_ = Value(v);
|
||||
} else if (v_sdtype == kInt32) {
|
||||
int* ptr_i = static_cast<int*>(ptr);
|
||||
std::vector<int> v(index.size());
|
||||
for (size_t i = 0; i < index.size(); i++) {
|
||||
if (mask[i]) {
|
||||
v[i] = ptr_i[index[i]];
|
||||
}
|
||||
}
|
||||
value_ = Value(v);
|
||||
} else {
|
||||
LOG(FATAL) << "Invalid dtype: " << v_sdtype;
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_API void visit(const Store* v) override {
|
||||
const Variable* base_node = v->base_handle().node();
|
||||
auto iter = buffer_mapping_.find(base_node);
|
||||
CHECK(iter != buffer_mapping_.end());
|
||||
void* ptr = iter->second;
|
||||
|
||||
v->index().accept(this);
|
||||
std::vector<int> index = value().as_vec<int>();
|
||||
v->mask().accept(this);
|
||||
std::vector<int> mask = value().as_vec<int>();
|
||||
CHECK_EQ(index.size(), mask.size());
|
||||
Dtype v_sdtype = v->value().dtype().scalar_type();
|
||||
if (v_sdtype == kFloat32) {
|
||||
v->value().accept(this);
|
||||
std::vector<float> value = this->value().as_vec<float>();
|
||||
CHECK_EQ(index.size(), value.size());
|
||||
float* ptr_f = static_cast<float*>(ptr);
|
||||
for (size_t i = 0; i < index.size(); i++) {
|
||||
if (mask[i]) {
|
||||
ptr_f[index[i]] = value[i];
|
||||
}
|
||||
}
|
||||
} else if (v_sdtype == kInt32) {
|
||||
v->value().accept(this);
|
||||
std::vector<int> value = this->value().as_vec<int>();
|
||||
CHECK_EQ(index.size(), value.size());
|
||||
int* ptr_i = static_cast<int*>(ptr);
|
||||
for (size_t i = 0; i < index.size(); i++) {
|
||||
if (mask[i]) {
|
||||
ptr_i[index[i]] = value[i];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
LOG(FATAL) << "Invalid dtype: " << v_sdtype;
|
||||
}
|
||||
}
|
||||
|
||||
void visit(const Allocate* v) override {
|
||||
const Variable* buffer_var = v->buffer_var().AsNode<Variable>();
|
||||
std::vector<Expr> dims = v->dims();
|
||||
int total_byte_size = v->dtype().byte_size();
|
||||
for (size_t i = 0; i < dims.size(); i++) {
|
||||
dims[i].accept(this);
|
||||
total_byte_size *= value_.as<int>();
|
||||
}
|
||||
int int_count = (total_byte_size + sizeof(int) - 1) / sizeof(int);
|
||||
std::unique_ptr<std::vector<int>> buffer(new std::vector<int>(int_count));
|
||||
auto iter = buffer_mapping_.find(buffer_var);
|
||||
if (iter != buffer_mapping_.end() && iter->second != nullptr) {
|
||||
throw std::runtime_error(
|
||||
"Allocate a buffer that has already been allocated: " +
|
||||
buffer_var->name_hint());
|
||||
}
|
||||
buffer_mapping_[buffer_var] = buffer->data();
|
||||
internal_buffers_.insert(std::make_pair(buffer_var, std::move(buffer)));
|
||||
}
|
||||
|
||||
void visit(const Free* v) override {
|
||||
const Variable* buffer_var = v->buffer_var().AsNode<Variable>();
|
||||
int count = internal_buffers_.erase(buffer_var);
|
||||
if (count == 0) {
|
||||
throw std::runtime_error(
|
||||
"Free a buffer that is not currently bound: " +
|
||||
buffer_var->name_hint());
|
||||
}
|
||||
}
|
||||
|
||||
void visit(const Cond* v) override {
|
||||
v->condition().accept(this);
|
||||
if (value().as<int>()) {
|
||||
v->true_stmt().accept(this);
|
||||
} else {
|
||||
v->false_stmt().accept(this);
|
||||
}
|
||||
}
|
||||
|
||||
Value value() const {
|
||||
return value_;
|
||||
}
|
||||
|
||||
private:
|
||||
Value value_;
|
||||
std::unordered_map<const BaseExprNode*, Value> eval_context_;
|
||||
std::unordered_map<const BaseExprNode*, void*> buffer_mapping_;
|
||||
std::unordered_map<const Variable*, std::unique_ptr<std::vector<int>>>
|
||||
internal_buffers_;
|
||||
};
|
||||
|
||||
using VarMapping = std::vector<std::pair<Expr, Expr>>;
|
||||
|
||||
class VarSubMutator : public IRMutator {
|
||||
public:
|
||||
VarSubMutator(const VarMapping& var_mapping) {
|
||||
for (const auto& entry : var_mapping) {
|
||||
const Expr& key = entry.first;
|
||||
const Expr& value = entry.second;
|
||||
const Variable* key_var = key.AsNode<Variable>();
|
||||
CHECK(key_var != nullptr);
|
||||
var_mapping_[key_var] = value;
|
||||
}
|
||||
}
|
||||
|
||||
Expr mutate(const Variable* var) override {
|
||||
auto iter = var_mapping_.find(var);
|
||||
if (iter == var_mapping_.end()) {
|
||||
return Expr(const_cast<Variable*>(var));
|
||||
}
|
||||
return iter->second;
|
||||
}
|
||||
|
||||
private:
|
||||
std::unordered_map<const Variable*, Expr> var_mapping_;
|
||||
};
|
||||
|
||||
template <class CodeGenType>
|
||||
class ExprEval {
|
||||
public:
|
||||
using BufferArg = CodeGen::BufferArg;
|
||||
using CallArg = CodeGen::CallArg;
|
||||
|
||||
template <typename... Ts>
|
||||
ExprEval(const Expr& expr, Ts... ts) : ExprEval(expr, {BufferArg(ts)...}) {}
|
||||
|
||||
ExprEval(const Expr& expr, const std::vector<BufferArg>& buffer_args)
|
||||
: dtype_(expr.dtype()) {
|
||||
std::vector<BufferArg> buffer_args_extended = buffer_args;
|
||||
Buffer ret_buf("ret_val", dtype_, {1});
|
||||
Stmt store_stmt = Store::make(ret_buf.data(), 0, expr);
|
||||
buffer_args_extended.push_back(ret_buf);
|
||||
codegen_.reset(new CodeGenType(store_stmt, buffer_args_extended));
|
||||
}
|
||||
|
||||
template <typename... Ts>
|
||||
void operator()(Ts... ts) {
|
||||
call(ts...);
|
||||
}
|
||||
|
||||
void operator()(const std::vector<CallArg>& call_args) {
|
||||
call(call_args);
|
||||
}
|
||||
|
||||
template <typename... Ts>
|
||||
void call(Ts... ts) {
|
||||
call({CallArg(ts)...});
|
||||
}
|
||||
|
||||
void call(const std::vector<CallArg>& call_args) {
|
||||
std::vector<CallArg> call_args_extended = call_args;
|
||||
if (dtype_ == kFloat32) {
|
||||
std::vector<float> ret_val_arg(1);
|
||||
call_args_extended.push_back(CallArg(ret_val_arg));
|
||||
codegen_->call(call_args_extended);
|
||||
ret_value_ = Value(ret_val_arg[0]);
|
||||
} else if (dtype_ == kInt32) {
|
||||
std::vector<int> ret_val_arg(1);
|
||||
call_args_extended.push_back(CallArg(ret_val_arg));
|
||||
codegen_->call(call_args_extended);
|
||||
ret_value_ = Value(ret_val_arg[0]);
|
||||
} else {
|
||||
throw std::runtime_error("Invalid dtype");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename... Ts>
|
||||
T value(Ts... ts) {
|
||||
call(std::forward<Ts>(ts)...);
|
||||
return ret_value_.as<T>();
|
||||
}
|
||||
|
||||
private:
|
||||
Dtype dtype_;
|
||||
std::unique_ptr<CodeGenType> codegen_;
|
||||
Value ret_value_;
|
||||
};
|
||||
|
||||
inline Expr Substitute(Expr* expr, const VarMapping& var_mapping) {
|
||||
VarSubMutator var_sub(var_mapping);
|
||||
return expr->accept_mutator(&var_sub);
|
||||
}
|
||||
|
||||
inline Stmt Substitute(Stmt* stmt, const VarMapping& var_mapping) {
|
||||
VarSubMutator var_sub(var_mapping);
|
||||
return stmt->accept_mutator(&var_sub);
|
||||
}
|
||||
|
||||
} // namespace tensorexpr
|
||||
} // namespace jit
|
||||
} // namespace torch
|
@ -5,6 +5,8 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "torch/csrc/jit/tensorexpr/ir_mutator.h"
|
||||
#include "torch/csrc/jit/tensorexpr/ir_visitor.h"
|
||||
#include "torch/csrc/jit/tensorexpr/types.h"
|
||||
#include "torch/csrc/jit/tensorexpr/mem_arena.h"
|
||||
|
||||
@ -20,6 +22,8 @@ class BaseExprNode : public KernelScopedObject {
|
||||
Dtype dtype() const {
|
||||
return dtype_;
|
||||
}
|
||||
TORCH_API virtual void accept(IRVisitor* visitor) const = 0;
|
||||
virtual Expr accept_mutator(IRMutator* mutator) = 0;
|
||||
|
||||
private:
|
||||
Dtype dtype_;
|
||||
@ -29,6 +33,8 @@ class BaseExprNode : public KernelScopedObject {
|
||||
class BaseStmtNode : public KernelScopedObject {
|
||||
public:
|
||||
BaseStmtNode() {}
|
||||
TORCH_API virtual void accept(IRVisitor* visitor) const = 0;
|
||||
virtual Stmt accept_mutator(IRMutator* mutator) = 0;
|
||||
};
|
||||
|
||||
// A CRTP pattern to accept visitors for children class,
|
||||
@ -37,6 +43,10 @@ template <class Op, class Base = BaseExprNode>
|
||||
class ExprNode : public Base {
|
||||
public:
|
||||
using ExprNodeBase = ExprNode<Op>;
|
||||
void accept(IRVisitor* visitor) const override {
|
||||
visitor->visit(static_cast<const Op*>(this));
|
||||
}
|
||||
Expr accept_mutator(IRMutator* mutator) override;
|
||||
// pass the constructor to the base class
|
||||
using Base::Base;
|
||||
};
|
||||
@ -45,6 +55,10 @@ template <class Op>
|
||||
class StmtNode : public BaseStmtNode {
|
||||
public:
|
||||
using StmtNodeBase = StmtNode<Op>;
|
||||
void accept(IRVisitor* visitor) const override {
|
||||
visitor->visit(static_cast<const Op*>(this));
|
||||
}
|
||||
Stmt accept_mutator(IRMutator* mutator) override;
|
||||
StmtNode() {}
|
||||
};
|
||||
|
||||
@ -68,6 +82,23 @@ class TORCH_API Expr {
|
||||
return base_expr_node_ == nullptr;
|
||||
}
|
||||
|
||||
void accept(IRVisitor* visitor) const {
|
||||
// TODO: Consider implement this without using recursion. Otherwise,
|
||||
// if the expression tree is degenerate and too long, it could cause a
|
||||
// stack overflow.
|
||||
if (node() == nullptr) {
|
||||
return;
|
||||
}
|
||||
node()->accept(visitor);
|
||||
}
|
||||
|
||||
Expr accept_mutator(IRMutator* mutator) {
|
||||
if (node() == nullptr) {
|
||||
return Expr();
|
||||
}
|
||||
return node()->accept_mutator(mutator);
|
||||
}
|
||||
|
||||
Expr(int v);
|
||||
Expr(float v);
|
||||
|
||||
@ -115,6 +146,20 @@ class Stmt {
|
||||
return base_stmt_node_;
|
||||
}
|
||||
|
||||
void accept(IRVisitor* visitor) const {
|
||||
if (node() == nullptr) {
|
||||
return;
|
||||
}
|
||||
node()->accept(visitor);
|
||||
}
|
||||
|
||||
Stmt accept_mutator(IRMutator* mutator) {
|
||||
if (node() == nullptr) {
|
||||
return Stmt();
|
||||
}
|
||||
return node()->accept_mutator(mutator);
|
||||
}
|
||||
|
||||
bool empty() const {
|
||||
return node() == nullptr;
|
||||
}
|
||||
@ -128,6 +173,18 @@ class Stmt {
|
||||
BaseStmtNode* base_stmt_node_ = nullptr;
|
||||
};
|
||||
|
||||
template <class Op, class Base>
|
||||
Expr ExprNode<Op, Base>::accept_mutator(IRMutator* mutator) {
|
||||
ExprNode* this_mutable = const_cast<ExprNode*>(this);
|
||||
return mutator->mutate(static_cast<Op*>(this_mutable));
|
||||
}
|
||||
|
||||
template <class Op>
|
||||
Stmt StmtNode<Op>::accept_mutator(IRMutator* mutator) {
|
||||
StmtNode* this_mutable = const_cast<StmtNode*>(this);
|
||||
return mutator->mutate(static_cast<Op*>(this_mutable));
|
||||
}
|
||||
|
||||
inline bool same_node(const Expr& expr1, const Expr& expr2) {
|
||||
return expr1.AsNode<BaseExprNode>() == expr2.AsNode<BaseExprNode>();
|
||||
}
|
||||
|
272
torch/csrc/jit/tensorexpr/ir_mutator.cpp
Normal file
272
torch/csrc/jit/tensorexpr/ir_mutator.cpp
Normal file
@ -0,0 +1,272 @@
|
||||
#include "torch/csrc/jit/tensorexpr/ir_mutator.h"
|
||||
|
||||
#include "torch/csrc/jit/tensorexpr/eval.h"
|
||||
#include "torch/csrc/jit/tensorexpr/ir.h"
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
namespace tensorexpr {
|
||||
|
||||
template <typename Op>
|
||||
static Expr mutate_binary_op(
|
||||
const BinaryOpNode<Op>* v,
|
||||
IRMutator* mutator,
|
||||
bool option = false) {
|
||||
Expr lhs = v->lhs();
|
||||
Expr rhs = v->rhs();
|
||||
Expr lhs_new = lhs.accept_mutator(mutator);
|
||||
Expr rhs_new = rhs.accept_mutator(mutator);
|
||||
if (same_node(lhs, lhs_new) && same_node(rhs, rhs_new)) {
|
||||
return Expr(v);
|
||||
}
|
||||
IRNodeType expr_type = v->expr_type();
|
||||
switch (expr_type) {
|
||||
case IRNodeType::kAdd:
|
||||
return Add::make(lhs_new, rhs_new);
|
||||
case IRNodeType::kSub:
|
||||
return Sub::make(lhs_new, rhs_new);
|
||||
case IRNodeType::kMul:
|
||||
return Mul::make(lhs_new, rhs_new);
|
||||
case IRNodeType::kDiv:
|
||||
return Div::make(lhs_new, rhs_new);
|
||||
case IRNodeType::kMod:
|
||||
return Mod::make(lhs_new, rhs_new);
|
||||
case IRNodeType::kMax:
|
||||
return Max::make(lhs_new, rhs_new, option);
|
||||
case IRNodeType::kMin:
|
||||
return Min::make(lhs_new, rhs_new, option);
|
||||
default:
|
||||
LOG(FATAL) << "unsupported expr_type: " << static_cast<int>(expr_type);
|
||||
return Expr();
|
||||
}
|
||||
}
|
||||
|
||||
Expr IRMutator::mutate(const Add* v) {
|
||||
return mutate_binary_op(v, this);
|
||||
}
|
||||
|
||||
Expr IRMutator::mutate(const Sub* v) {
|
||||
return mutate_binary_op(v, this);
|
||||
}
|
||||
|
||||
Expr IRMutator::mutate(const Mul* v) {
|
||||
return mutate_binary_op(v, this);
|
||||
}
|
||||
|
||||
Expr IRMutator::mutate(const Div* v) {
|
||||
return mutate_binary_op(v, this);
|
||||
}
|
||||
|
||||
Expr IRMutator::mutate(const Mod* v) {
|
||||
return mutate_binary_op(v, this);
|
||||
}
|
||||
|
||||
Expr IRMutator::mutate(const Max* v) {
|
||||
return mutate_binary_op(v, this, v->propagate_nans());
|
||||
}
|
||||
|
||||
Expr IRMutator::mutate(const Min* v) {
|
||||
return mutate_binary_op(v, this, v->propagate_nans());
|
||||
}
|
||||
|
||||
Expr IRMutator::mutate(const CompareSelect* v) {
|
||||
Expr lhs = v->lhs();
|
||||
Expr rhs = v->rhs();
|
||||
Expr lhs_new = lhs.accept_mutator(this);
|
||||
Expr rhs_new = rhs.accept_mutator(this);
|
||||
if (same_node(lhs, lhs_new) && same_node(rhs, rhs_new)) {
|
||||
return Expr(v);
|
||||
}
|
||||
return CompareSelect::make(lhs_new, rhs_new, v->compare_select_op());
|
||||
}
|
||||
|
||||
Expr IRMutator::mutate(const IntImm* v) {
|
||||
return Expr(v);
|
||||
}
|
||||
|
||||
Expr IRMutator::mutate(const FloatImm* v) {
|
||||
return Expr(v);
|
||||
}
|
||||
|
||||
Expr IRMutator::mutate(const Cast* v) {
|
||||
Expr src_value = v->src_value();
|
||||
Expr src_value_new = src_value.accept_mutator(this);
|
||||
if (same_node(src_value_new, v->src_value())) {
|
||||
return Expr(v);
|
||||
}
|
||||
return Cast::make(v->dtype(), src_value_new);
|
||||
}
|
||||
|
||||
Expr IRMutator::mutate(const Variable* v) {
|
||||
return Expr(v);
|
||||
}
|
||||
|
||||
Expr IRMutator::mutate(const Let* v) {
|
||||
Expr var = v->var();
|
||||
Expr value = v->value();
|
||||
Expr body = v->body();
|
||||
Expr var_new = var.accept_mutator(this);
|
||||
Expr value_new = value.accept_mutator(this);
|
||||
Expr body_new = body.accept_mutator(this);
|
||||
if (same_node(var, var_new) && same_node(value, value_new) &&
|
||||
same_node(body, body_new)) {
|
||||
return Expr(v);
|
||||
}
|
||||
return Let::make(var_new, value_new, body_new);
|
||||
}
|
||||
|
||||
Expr IRMutator::mutate(const Ramp* v) {
|
||||
Expr base = v->base();
|
||||
Expr stride = v->stride();
|
||||
Expr base_new = base.accept_mutator(this);
|
||||
Expr stride_new = stride.accept_mutator(this);
|
||||
if (same_node(base, base_new) && same_node(stride, stride_new)) {
|
||||
return Expr(v);
|
||||
}
|
||||
return Ramp::make(base_new, stride_new, v->lanes());
|
||||
}
|
||||
|
||||
Expr IRMutator::mutate(const Load* v) {
|
||||
Dtype dtype = v->dtype();
|
||||
Var base_handle = v->base_handle();
|
||||
Expr index = v->index();
|
||||
Expr mask = v->mask();
|
||||
Expr base_handle_expr = base_handle.accept_mutator(this);
|
||||
Var base_handle_new = Var(base_handle_expr.AsNode<Variable>());
|
||||
Expr index_new = index.accept_mutator(this);
|
||||
Expr mask_new = mask.accept_mutator(this);
|
||||
if (same_node(base_handle, base_handle_new) && same_node(index, index_new) &&
|
||||
same_node(mask, mask_new)) {
|
||||
return Expr(v);
|
||||
}
|
||||
return Load::make(dtype, base_handle_new, index_new, mask_new);
|
||||
}
|
||||
|
||||
Expr IRMutator::mutate(const Broadcast* v) {
|
||||
Expr value = v->value();
|
||||
int lanes = v->lanes();
|
||||
Expr value_new = value.accept_mutator(this);
|
||||
if (same_node(value, value_new)) {
|
||||
return Expr(v);
|
||||
}
|
||||
return Broadcast::make(value_new, lanes);
|
||||
}
|
||||
|
||||
Expr IRMutator::mutate(const IfThenElse* v) {
|
||||
Expr condition = v->condition();
|
||||
Expr true_value = v->true_value();
|
||||
Expr false_value = v->false_value();
|
||||
Expr condition_new = condition.accept_mutator(this);
|
||||
Expr true_value_new = true_value.accept_mutator(this);
|
||||
Expr false_value_new = false_value.accept_mutator(this);
|
||||
if (same_node(condition, condition_new) &&
|
||||
same_node(true_value, true_value_new) &&
|
||||
same_node(false_value, false_value_new)) {
|
||||
return Expr(v);
|
||||
}
|
||||
|
||||
return IfThenElse::make(condition_new, true_value_new, false_value_new);
|
||||
}
|
||||
|
||||
Stmt IRMutator::mutate(const For* v) {
|
||||
Var var = v->var();
|
||||
Expr start = v->start();
|
||||
Expr stop = v->stop();
|
||||
Stmt body = v->body();
|
||||
LoopOptions loop_options = v->loop_options();
|
||||
Expr var_new_expr = var.accept_mutator(this);
|
||||
Var var_new = Var(var_new_expr.AsNode<Variable>());
|
||||
Expr start_new = start.accept_mutator(this);
|
||||
Expr stop_new = stop.accept_mutator(this);
|
||||
Stmt body_new = body.accept_mutator(this);
|
||||
if (same_node(var, var_new) && same_node(start, start_new) &&
|
||||
same_node(stop, stop_new) && same_node(body, body_new)) {
|
||||
return Stmt(v);
|
||||
}
|
||||
return For::make(var_new, start_new, stop_new, body_new, loop_options);
|
||||
}
|
||||
|
||||
Stmt IRMutator::mutate(const Block* v) {
|
||||
bool any_change = false;
|
||||
std::vector<Stmt> stmts;
|
||||
for (int i = 0; i < v->nstmts(); i++) {
|
||||
Stmt stmt = v->stmt(i);
|
||||
Stmt stmt_new = stmt.accept_mutator(this);
|
||||
if (!same_node(stmt, stmt_new)) {
|
||||
any_change = true;
|
||||
}
|
||||
stmts.push_back(stmt_new);
|
||||
}
|
||||
if (!any_change) {
|
||||
return Stmt(v);
|
||||
}
|
||||
return Block::make(stmts);
|
||||
}
|
||||
|
||||
Stmt IRMutator::mutate(const Store* v) {
|
||||
Var base_handle = v->base_handle();
|
||||
Expr index = v->index();
|
||||
Expr value = v->value();
|
||||
Expr mask = v->mask();
|
||||
Expr base_handle_expr = base_handle.accept_mutator(this);
|
||||
Var base_handle_new = Var(base_handle_expr.AsNode<Variable>());
|
||||
Expr index_new = index.accept_mutator(this);
|
||||
Expr value_new = value.accept_mutator(this);
|
||||
Expr mask_new = mask.accept_mutator(this);
|
||||
if (same_node(base_handle, base_handle_new) && same_node(index, index_new) &&
|
||||
same_node(value, value_new) && same_node(mask, mask_new)) {
|
||||
return Stmt(v);
|
||||
}
|
||||
return Store::make(base_handle_new, index_new, value_new, mask_new);
|
||||
}
|
||||
|
||||
Stmt IRMutator::mutate(const Allocate* v) {
|
||||
Var buffer_var_old = v->buffer_var();
|
||||
Var buffer_var_new =
|
||||
Var(buffer_var_old.accept_mutator(this).AsNode<Variable>());
|
||||
bool any_change = same_node(buffer_var_new, buffer_var_old);
|
||||
|
||||
std::vector<Expr> dims_old = v->dims();
|
||||
std::vector<Expr> dims_new(dims_old.size());
|
||||
for (size_t i = 0; i < dims_old.size(); i++) {
|
||||
dims_new[i] = dims_old[i].accept_mutator(this);
|
||||
any_change |= same_node(dims_new[i], dims_old[i]);
|
||||
}
|
||||
|
||||
if (!any_change) {
|
||||
return Stmt(v);
|
||||
}
|
||||
|
||||
return Allocate::make(buffer_var_new, v->dtype(), dims_new);
|
||||
}
|
||||
|
||||
Stmt IRMutator::mutate(const Free* v) {
|
||||
Var buffer_var_old = v->buffer_var();
|
||||
Var buffer_var_new =
|
||||
Var(buffer_var_old.accept_mutator(this).AsNode<Variable>());
|
||||
if (same_node(buffer_var_new, buffer_var_old)) {
|
||||
return Stmt(v);
|
||||
}
|
||||
|
||||
return Free::make(buffer_var_new);
|
||||
}
|
||||
|
||||
Stmt IRMutator::mutate(const Cond* v) {
|
||||
Expr cond_old = v->condition();
|
||||
Stmt true_old = v->true_stmt();
|
||||
Stmt false_old = v->false_stmt();
|
||||
|
||||
Expr cond_new = cond_old.accept_mutator(this);
|
||||
Stmt true_new = true_old.accept_mutator(this);
|
||||
Stmt false_new = false_old.accept_mutator(this);
|
||||
|
||||
if (same_node(cond_old, cond_new) && same_node(true_old, true_new) &&
|
||||
same_node(false_old, false_new)) {
|
||||
return Stmt(v);
|
||||
}
|
||||
return Cond::make(cond_new, true_new, false_new);
|
||||
}
|
||||
|
||||
} // namespace tensorexpr
|
||||
} // namespace jit
|
||||
} // namespace torch
|
68
torch/csrc/jit/tensorexpr/ir_mutator.h
Normal file
68
torch/csrc/jit/tensorexpr/ir_mutator.h
Normal file
@ -0,0 +1,68 @@
|
||||
#pragma once
|
||||
#include <torch/csrc/WindowsTorchApiMacro.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
namespace tensorexpr {
|
||||
|
||||
class Add;
|
||||
class Sub;
|
||||
class Mul;
|
||||
class Div;
|
||||
class Mod;
|
||||
class Max;
|
||||
class Min;
|
||||
class CompareSelect;
|
||||
class IntImm;
|
||||
class FloatImm;
|
||||
class Cast;
|
||||
class Variable;
|
||||
class Let;
|
||||
class Ramp;
|
||||
class Load;
|
||||
class For;
|
||||
class Block;
|
||||
class Store;
|
||||
class Broadcast;
|
||||
class IfThenElse;
|
||||
class Expr;
|
||||
class Stmt;
|
||||
class BaseCallNode;
|
||||
class FunctionCall;
|
||||
class Allocate;
|
||||
class Free;
|
||||
class Cond;
|
||||
|
||||
class TORCH_API IRMutator {
|
||||
public:
|
||||
virtual ~IRMutator() {}
|
||||
virtual Expr mutate(const Add* v);
|
||||
virtual Expr mutate(const Sub* v);
|
||||
virtual Expr mutate(const Mul* v);
|
||||
virtual Expr mutate(const Div* v);
|
||||
virtual Expr mutate(const Mod* v);
|
||||
virtual Expr mutate(const Max* v);
|
||||
virtual Expr mutate(const Min* v);
|
||||
virtual Expr mutate(const CompareSelect* v);
|
||||
virtual Expr mutate(const IntImm* v);
|
||||
virtual Expr mutate(const FloatImm* v);
|
||||
virtual Expr mutate(const Cast* v);
|
||||
virtual Expr mutate(const Variable* v);
|
||||
virtual Expr mutate(const Let* v);
|
||||
virtual Expr mutate(const Ramp* v);
|
||||
virtual Expr mutate(const Load* v);
|
||||
virtual Expr mutate(const Broadcast* v);
|
||||
virtual Expr mutate(const IfThenElse* v);
|
||||
|
||||
virtual Stmt mutate(const For* v);
|
||||
virtual Stmt mutate(const Block* v);
|
||||
virtual Stmt mutate(const Store* v);
|
||||
|
||||
virtual Stmt mutate(const Allocate* v);
|
||||
virtual Stmt mutate(const Free* v);
|
||||
virtual Stmt mutate(const Cond* v);
|
||||
};
|
||||
|
||||
} // namespace tensorexpr
|
||||
} // namespace jit
|
||||
} // namespace torch
|
126
torch/csrc/jit/tensorexpr/ir_visitor.cpp
Normal file
126
torch/csrc/jit/tensorexpr/ir_visitor.cpp
Normal file
@ -0,0 +1,126 @@
|
||||
#include "torch/csrc/jit/tensorexpr/ir_visitor.h"
|
||||
|
||||
#include "torch/csrc/jit/tensorexpr/ir.h"
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
namespace tensorexpr {
|
||||
|
||||
template <typename Op>
|
||||
static void visit_binary_op(const BinaryOpNode<Op>* v, IRVisitor* visitor) {
|
||||
v->lhs().accept(visitor);
|
||||
v->rhs().accept(visitor);
|
||||
}
|
||||
|
||||
void IRVisitor::visit(const Add* v) {
|
||||
visit_binary_op(v, this);
|
||||
}
|
||||
|
||||
void IRVisitor::visit(const Sub* v) {
|
||||
visit_binary_op(v, this);
|
||||
}
|
||||
|
||||
void IRVisitor::visit(const Mul* v) {
|
||||
visit_binary_op(v, this);
|
||||
}
|
||||
|
||||
void IRVisitor::visit(const Div* v) {
|
||||
visit_binary_op(v, this);
|
||||
}
|
||||
|
||||
void IRVisitor::visit(const Mod* v) {
|
||||
visit_binary_op(v, this);
|
||||
}
|
||||
|
||||
void IRVisitor::visit(const Max* v) {
|
||||
visit_binary_op(v, this);
|
||||
}
|
||||
|
||||
void IRVisitor::visit(const Min* v) {
|
||||
visit_binary_op(v, this);
|
||||
}
|
||||
|
||||
void IRVisitor::visit(const CompareSelect* v) {
|
||||
v->lhs().accept(this);
|
||||
v->rhs().accept(this);
|
||||
}
|
||||
|
||||
void IRVisitor::visit(const IntImm* v) {}
|
||||
void IRVisitor::visit(const FloatImm* v) {}
|
||||
void IRVisitor::visit(const Cast* v) {
|
||||
v->src_value().accept(this);
|
||||
}
|
||||
void IRVisitor::visit(const Variable* v) {}
|
||||
void IRVisitor::visit(const Let* v) {
|
||||
v->var().accept(this);
|
||||
v->value().accept(this);
|
||||
v->body().accept(this);
|
||||
}
|
||||
|
||||
void IRVisitor::visit(const Ramp* v) {
|
||||
v->base().accept(this);
|
||||
v->stride().accept(this);
|
||||
}
|
||||
|
||||
void IRVisitor::visit(const Load* v) {
|
||||
v->base_handle().accept(this);
|
||||
v->index().accept(this);
|
||||
v->mask().accept(this);
|
||||
}
|
||||
|
||||
void IRVisitor::visit(const Store* v) {
|
||||
v->base_handle().accept(this);
|
||||
v->index().accept(this);
|
||||
v->value().accept(this);
|
||||
v->mask().accept(this);
|
||||
}
|
||||
|
||||
void IRVisitor::visit(const Block* v) {
|
||||
for (int i = 0; i < v->nstmts(); i++) {
|
||||
v->stmt(i).accept(this);
|
||||
}
|
||||
}
|
||||
|
||||
void IRVisitor::visit(const For* v) {
|
||||
v->var().accept(this);
|
||||
v->start().accept(this);
|
||||
v->stop().accept(this);
|
||||
v->body().accept(this);
|
||||
}
|
||||
|
||||
void IRVisitor::visit(const Broadcast* v) {
|
||||
v->value().accept(this);
|
||||
}
|
||||
|
||||
void IRVisitor::visit(const IfThenElse* v) {
|
||||
v->condition().accept(this);
|
||||
v->true_value().accept(this);
|
||||
v->false_value().accept(this);
|
||||
}
|
||||
|
||||
void IRVisitor::visit(const Allocate* v) {
|
||||
Var buffer_var = v->buffer_var();
|
||||
buffer_var.accept(this);
|
||||
std::vector<Expr> dims = v->dims();
|
||||
for (Expr& dim : dims) {
|
||||
dim.accept(this);
|
||||
}
|
||||
}
|
||||
|
||||
void IRVisitor::visit(const Free* v) {
|
||||
Var buffer_var = v->buffer_var();
|
||||
buffer_var.accept(this);
|
||||
}
|
||||
|
||||
void IRVisitor::visit(const Cond* v) {
|
||||
Expr condition = v->condition();
|
||||
Stmt true_stmt = v->true_stmt();
|
||||
Stmt false_stmt = v->false_stmt();
|
||||
condition.accept(this);
|
||||
true_stmt.accept(this);
|
||||
false_stmt.accept(this);
|
||||
}
|
||||
|
||||
} // namespace tensorexpr
|
||||
} // namespace jit
|
||||
} // namespace torch
|
71
torch/csrc/jit/tensorexpr/ir_visitor.h
Normal file
71
torch/csrc/jit/tensorexpr/ir_visitor.h
Normal file
@ -0,0 +1,71 @@
|
||||
#pragma once
|
||||
#include <torch/csrc/WindowsTorchApiMacro.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
namespace tensorexpr {
|
||||
|
||||
class Add;
|
||||
class Sub;
|
||||
class Mul;
|
||||
class Div;
|
||||
class Mod;
|
||||
class Max;
|
||||
class Min;
|
||||
class CompareSelect;
|
||||
class IntImm;
|
||||
class FloatImm;
|
||||
class Cast;
|
||||
class Variable;
|
||||
class Let;
|
||||
class Ramp;
|
||||
class Load;
|
||||
class For;
|
||||
class Block;
|
||||
class Store;
|
||||
class Broadcast;
|
||||
class IfThenElse;
|
||||
class BaseCallNode;
|
||||
class FunctionCall;
|
||||
class Allocate;
|
||||
class Free;
|
||||
class Cond;
|
||||
|
||||
class TORCH_API IRVisitor {
|
||||
public:
|
||||
virtual ~IRVisitor() {}
|
||||
virtual void visit(const Add* v);
|
||||
virtual void visit(const Sub* v);
|
||||
virtual void visit(const Mul* v);
|
||||
virtual void visit(const Div* v);
|
||||
virtual void visit(const Mod* v);
|
||||
virtual void visit(const Max* v);
|
||||
virtual void visit(const Min* v);
|
||||
virtual void visit(const CompareSelect* v);
|
||||
virtual void visit(const IntImm* v);
|
||||
virtual void visit(const FloatImm* v);
|
||||
virtual void visit(const Cast* v);
|
||||
virtual void visit(const Variable* v);
|
||||
virtual void visit(const Let* v);
|
||||
virtual void visit(const Ramp* v);
|
||||
virtual void visit(const Load* v);
|
||||
virtual void visit(const For* v);
|
||||
virtual void visit(const Block* v);
|
||||
virtual void visit(const Store* v);
|
||||
virtual void visit(const Broadcast* v);
|
||||
virtual void visit(const IfThenElse* v);
|
||||
|
||||
// BaseCallNode is the base class for all call nodes.
|
||||
// For any visitors that only needs the common behavior, only override this
|
||||
// function is enough. This is because all derived class handlers will call
|
||||
// this function by default.
|
||||
// Override the derived class handler only if the logic is more specific to
|
||||
// that.
|
||||
virtual void visit(const Allocate* v);
|
||||
virtual void visit(const Free* v);
|
||||
virtual void visit(const Cond* v);
|
||||
};
|
||||
|
||||
} // namespace tensorexpr
|
||||
} // namespace jit
|
||||
} // namespace torch
|
Reference in New Issue
Block a user