Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Summary: A new version of the IR simplifier used by the jit/tensorexpr fuser. This one is capable of simplifying expressions containing (shock) multiple variables, e.g.:

```
(m * (1 * n_1) + (n + 1)) - (m * (1 * n_1) + n)  =>  1
```

Like the previous IR simplifier, it uses a two-stage approach:

1. Traverse the tree, combining subtrees of commutable operations into a flat structure. This implementation has two intermediate Exprs: Term (expressing products of sub-expressions) and Polynomial (expressing sums of sub-expressions).
2. Traverse the tree, expanding Terms and Polynomials into their component operators.

Using the example above, simplification proceeds like this:

```
(m * (1 * n_1) + (n + 1)) - (m * (1 * n_1) + n)

# Using PolynomialTransformer:
=> Sub(Add(Mul(m, Mul(1, n_1)), Add(n, 1)), Add(Mul(m, Mul(1, n_1)), n))
=> Sub(Polynomial(Term(m, n_1), n, 1), Polynomial(Term(m, n_1), n))
=> Polynomial(Term(m, n_1), Term(-1, m, n_1), n, -n, 1)
=> Polynomial(1)

# Using TermExpander:
=> 1
```

The IRSimplifier supports arithmetic simplification of the Add, Sub and Mul operators, and constant folding of all binary Exprs and Intrinsics. It does not attempt to expand multiplications of Polynomials into canonical form, since that generally leads to less efficient representations. It will do scalar factorization when it results in removal of operators, and will merge chains of multilane primitives (such as Broadcast and Ramp) down into a single operator.

The ir_simplifier unit tests are a short tour of its capabilities.

The existing simplifier has a bug where it will sometimes reorder operations on floating-point types, which are not associative. This causes (at least) the pyhpc equation_of_state benchmark to produce incorrect results. I have fixed that issue in this version and verified that the benchmark produces the same results with and without the simplifier.

Tests: all cpp & py tensorexpr tests, and the pyhpc benchmark:

```
benchmarks.equation_of_state
============================
Running on CPU

size        backend    calls    mean     stdev    min      25%      median   75%      max      Δ
------------------------------------------------------------------------------------------------------------------
4,194,304   pytorch    10       0.246    0.002    0.243    0.245    0.246    0.248    0.250    1.000
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/35127

Differential Revision: D20624571

Pulled By: nickgg

fbshipit-source-id: e49049377beee69e02dcf26eb922bef1447ae776
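To make the two-stage approach concrete, here is a minimal standalone sketch of the flatten-and-cancel idea. The `Poly` type, `add`, and `sub` are toy illustrations only, not the actual `Term`/`Polynomial` IR nodes:

```
#include <map>
#include <string>

// Stage 1 analogue: flatten sums into {canonical term -> coefficient} plus a
// constant, so commutable terms can cancel regardless of tree shape.
struct Poly {
  std::map<std::string, int> coeffs; // key is a canonical term, e.g. "m*n_1"
  int scalar = 0;
};

Poly add(Poly a, const Poly& b) {
  for (const auto& kv : b.coeffs) {
    a.coeffs[kv.first] += kv.second;
  }
  a.scalar += b.scalar;
  return a;
}

Poly sub(Poly a, const Poly& b) {
  for (const auto& kv : b.coeffs) {
    a.coeffs[kv.first] -= kv.second;
  }
  a.scalar -= b.scalar;
  return a;
}

// Stage 2 analogue: expansion emits only the terms with nonzero coefficients,
// so sub({m*n_1: 1, n: 1, scalar: 1}, {m*n_1: 1, n: 1}) leaves just 1.
```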
879 lines · 21 KiB · C++
#pragma once

#include <string>
#include <vector>

#include <torch/csrc/jit/tensorexpr/exceptions.h>
#include <torch/csrc/jit/tensorexpr/expr.h>
#include <torch/csrc/jit/tensorexpr/stmt.h>

namespace torch {
namespace jit {
namespace tensorexpr {

enum CompareSelectOperation {
  kEQ = 0,
  kGT,
  kGE,
  kLT,
  kLE,
  kNE,
};

inline int getPrecedence(IRNodeType ty) {
  // Match C++ operator precedence rules, since some backends pretty-print
  // expressions to C++ code.
  // SEE: https://en.cppreference.com/w/cpp/language/operator_precedence
  switch (ty) {
    case kPrimitive:
      return 0;
    case kCast:
      return 2;
    case kAdd:
    case kSub:
      return 6;
    case kMul:
    case kDiv:
    case kMod:
      return 5;
    case kMax:
    case kMin:
      return 99;
    case kAnd:
      return 11;
    case kOr:
      return 13;
    case kLshift:
    case kRshift:
      return 7;
    case kXor:
      return 12;
    case kCompareSelect:
    case kLet:
      return 16;
    default:
      return 99;
  }
}

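// Illustrative: a printer can decide parenthesization by comparing
// precedences, e.g. a child sub-expression needs parentheses when
//   getPrecedence(child_type) > getPrecedence(parent_type)
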
class Buffer;

class Cast : public ExprNode<Cast> {
 public:
  const Expr* src_value() const {
    return src_value_;
  }
  static ExprHandle make(Dtype dtype, const ExprHandle& src_value) {
    return ExprHandle(new Cast(dtype, src_value.node()));
  }
  Cast(Dtype dtype, const Expr* src_value)
      : ExprNodeBase(dtype, kCast), src_value_(src_value) {}

  bool isConstant() const override {
    return src_value_->isConstant();
  }

 private:
  const Expr* src_value_;
};

template <typename T>
ExprHandle cast(const ExprHandle& src_value) {
  return Cast::make(Dtype(ToDtype<T>(), src_value.dtype().lanes()), src_value);
}
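
// Illustrative: cast<int>(x) wraps x in a Cast node targeting Int while
// preserving the lane count of x's dtype.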

// Represents the expression node for binary operators.
// A CRTP pattern to share common code among the operators.
template <typename Op>
class BinaryOpNode : public ExprNode<Op> {
 public:
  const Expr* lhs() const {
    return this->lhs_;
  }
  const Expr* rhs() const {
    return this->rhs_;
  }

  static ExprHandle make(const ExprHandle& lhs, const ExprHandle& rhs) {
    return ExprHandle(new Op(lhs.node(), rhs.node()));
  }

  BinaryOpNode(
      const Expr* lhs_v,
      const Expr* rhs_v,
      IRNodeType expr_type,
      ScalarType ret_type = ScalarType::None)
      : ExprNode<Op>(
            BinaryOpDtype(lhs_v->dtype(), rhs_v->dtype(), ret_type),
            expr_type),
        lhs_(CastIfNeeded(lhs_v, ExprNode<Op>::dtype())),
        rhs_(CastIfNeeded(rhs_v, ExprNode<Op>::dtype())) {}

 private:
  static const Expr* CastIfNeeded(const Expr* expr, Dtype dst_dtype) {
    if (expr->dtype() == dst_dtype) {
      return expr;
    }
    return Cast::make(dst_dtype, ExprHandle(expr)).node();
  }

  const Expr* lhs_;
  const Expr* rhs_;
};
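
// Illustrative: each concrete operator below inherits lhs()/rhs() and make()
// from this CRTP base, so building a sum is just:
//   ExprHandle sum = Add::make(a, b); // constructs Add(a.node(), b.node())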

class Add : public BinaryOpNode<Add> {
 public:
  Add(const Expr* lhs, const Expr* rhs)
      : BinaryOpNode(lhs, rhs, IRNodeType::kAdd) {}
};

class Sub : public BinaryOpNode<Sub> {
 public:
  Sub(const Expr* lhs, const Expr* rhs)
      : BinaryOpNode(lhs, rhs, IRNodeType::kSub) {}
};

class Mul : public BinaryOpNode<Mul> {
 public:
  Mul(const Expr* lhs, const Expr* rhs)
      : BinaryOpNode(lhs, rhs, IRNodeType::kMul) {}
};

class Div : public BinaryOpNode<Div> {
 public:
  Div(const Expr* lhs, const Expr* rhs)
      : BinaryOpNode(lhs, rhs, IRNodeType::kDiv) {}
};

class Mod : public BinaryOpNode<Mod> {
 public:
  Mod(const Expr* lhs, const Expr* rhs)
      : BinaryOpNode(lhs, rhs, IRNodeType::kMod) {}
};

class And : public BinaryOpNode<And> {
 public:
  And(const Expr* lhs, const Expr* rhs)
      : BinaryOpNode(lhs, rhs, IRNodeType::kAnd) {
    if (lhs->dtype().scalar_type() != ScalarType::Int) {
      throw unsupported_dtype();
    }
    if (lhs->dtype() != rhs->dtype()) {
      throw malformed_input();
    }
  }
};

class Or : public BinaryOpNode<Or> {
 public:
  Or(const Expr* lhs, const Expr* rhs)
      : BinaryOpNode(lhs, rhs, IRNodeType::kOr) {
    if (lhs->dtype().scalar_type() != ScalarType::Int) {
      throw unsupported_dtype();
    }
    if (lhs->dtype() != rhs->dtype()) {
      throw malformed_input();
    }
  }
};

class Xor : public BinaryOpNode<Xor> {
 public:
  Xor(const Expr* lhs, const Expr* rhs)
      : BinaryOpNode(lhs, rhs, IRNodeType::kXor) {
    if (lhs->dtype().scalar_type() != ScalarType::Int) {
      throw unsupported_dtype();
    }
    if (lhs->dtype() != rhs->dtype()) {
      throw malformed_input();
    }
  }
};

class Lshift : public BinaryOpNode<Lshift> {
 public:
  Lshift(const Expr* lhs, const Expr* rhs)
      : BinaryOpNode(lhs, rhs, IRNodeType::kLshift) {
    if (lhs->dtype().scalar_type() != ScalarType::Int) {
      throw unsupported_dtype();
    }
    if (lhs->dtype() != rhs->dtype()) {
      throw malformed_input();
    }
  }
};

class Rshift : public BinaryOpNode<Rshift> {
 public:
  Rshift(const Expr* lhs, const Expr* rhs)
      : BinaryOpNode(lhs, rhs, IRNodeType::kRshift) {
    if (lhs->dtype().scalar_type() != ScalarType::Int) {
      throw unsupported_dtype();
    }
    if (lhs->dtype() != rhs->dtype()) {
      throw malformed_input();
    }
  }
};

class Max : public BinaryOpNode<Max> {
 private:
  bool propagate_nans_;

 public:
  Max(const Expr* lhs, const Expr* rhs, bool propagate_nans)
      : BinaryOpNode(lhs, rhs, IRNodeType::kMax),
        propagate_nans_(propagate_nans) {}

  bool propagate_nans() const {
    return propagate_nans_;
  }

  static ExprHandle make(const ExprHandle& lhs, const ExprHandle& rhs) = delete;
  static ExprHandle make(
      const ExprHandle& lhs,
      const ExprHandle& rhs,
      bool propagate_nans) {
    return ExprHandle(new Max(lhs.node(), rhs.node(), propagate_nans));
  }
};

class Min : public BinaryOpNode<Min> {
 private:
  bool propagate_nans_;

 public:
  Min(const Expr* lhs, const Expr* rhs, bool propagate_nans)
      : BinaryOpNode(lhs, rhs, IRNodeType::kMin),
        propagate_nans_(propagate_nans) {}

  bool propagate_nans() const {
    return propagate_nans_;
  }

  static ExprHandle make(const ExprHandle& lhs, const ExprHandle& rhs) = delete;
  static ExprHandle make(
      const ExprHandle& lhs,
      const ExprHandle& rhs,
      bool propagate_nans) {
    return ExprHandle(new Min(lhs.node(), rhs.node(), propagate_nans));
  }
};
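
// Note: the inherited two-argument make() is deleted for Max and Min, so
// callers must spell out NaN semantics explicitly, e.g. (illustrative):
//   ExprHandle m = Max::make(a, b, /*propagate_nans=*/true);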

// Encode typed immediate values e.g. IntImm, FloatImm.
#define IMM_DECLARE(Type, Name)                              \
  class Name##Imm : public ExprNode<Name##Imm> {             \
   public:                                                   \
    Name##Imm(Type value)                                    \
        : ExprNodeBase(k##Name, kPrimitive), value_(value) {} \
    bool isConstant() const override {                       \
      return true;                                           \
    }                                                        \
    Type value() const {                                     \
      return value_;                                         \
    }                                                        \
    static ExprHandle make(Type value) {                     \
      return ExprHandle(new Name##Imm(value));               \
    }                                                        \
                                                             \
   private:                                                  \
    Type value_;                                             \
  };
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_DECLARE);
#undef IMM_DECLARE
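
// Illustrative: the expansion for (float, Float) defines FloatImm, so a
// constant leaf can be built with:
//   ExprHandle half = FloatImm::make(0.5f);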

// Get immediate by ScalarType.
template <typename T>
Expr* getImmediateByType(ScalarType immType, T initialVal) {
  switch (immType) {
#define TYPE_CASE(Type, Name) \
  case ScalarType::Name:      \
    return new Name##Imm(initialVal);
    AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE);
#undef TYPE_CASE
    default:
      throw unsupported_dtype();
  }
  return nullptr;
}

template <typename T>
Expr* getImmediateByType(Dtype dtype, T initialVal) {
  return getImmediateByType<T>(dtype.scalar_type(), initialVal);
}

template <typename T>
T immediateAs(const Expr* e) {
#define TYPE_CASE(Type, Name)                                     \
  if (const Name##Imm* imm = dynamic_cast<const Name##Imm*>(e)) { \
    return imm->value();                                          \
  }
  AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE);
#undef TYPE_CASE
  throw unsupported_dtype();
  return 0;
}
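
// Illustrative round trip between C++ values and immediate nodes:
//   Expr* imm = getImmediateByType(ScalarType::Int, 7); // new IntImm(7)
//   int v = immediateAs<int>(imm);                      // v == 7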

template <typename T>
bool immediateEquals(const Expr* e, T val) {
#define TYPE_CASE(Type, Name)                                     \
  if (const Name##Imm* imm = dynamic_cast<const Name##Imm*>(e)) { \
    return imm->value() == val;                                   \
  }
  AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE);
#undef TYPE_CASE
  throw unsupported_dtype();
  return false;
}

template <typename T>
bool immediateIsNegative(const T* e) {
#define TYPE_CASE(Type, Name)                                     \
  if (const Name##Imm* imm = dynamic_cast<const Name##Imm*>(e)) { \
    return imm->value() < 0;                                      \
  }
  AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE);
#undef TYPE_CASE
  return false;
}

// Creates a new Expr of the given type with the provided lhs and rhs.
static const Expr* newBinaryOpOfType(
    IRNodeType expr_type,
    const Expr* lhs,
    const Expr* rhs,
    bool option) {
  switch (expr_type) {
    case IRNodeType::kAdd:
      return new Add(lhs, rhs);
    case IRNodeType::kSub:
      return new Sub(lhs, rhs);
    case IRNodeType::kMul:
      return new Mul(lhs, rhs);
    case IRNodeType::kDiv:
      return new Div(lhs, rhs);
    case IRNodeType::kMod:
      return new Mod(lhs, rhs);
    case IRNodeType::kMax:
      return new Max(lhs, rhs, option);
    case IRNodeType::kMin:
      return new Min(lhs, rhs, option);
    case IRNodeType::kAnd:
      return new And(lhs, rhs);
    case IRNodeType::kXor:
      return new Xor(lhs, rhs);
    case IRNodeType::kLshift:
      return new Lshift(lhs, rhs);
    case IRNodeType::kRshift:
      return new Rshift(lhs, rhs);
    default:
      LOG(FATAL) << "unsupported expr_type: " << static_cast<int>(expr_type);
      return nullptr;
  }
}

// Bind the value to the var and evaluate the body.
class Let : public ExprNode<Let> {
 public:
  const Expr* var() const {
    return var_;
  }
  const Expr* value() const {
    return value_;
  }
  const Expr* body() const {
    return body_;
  }

  static ExprHandle make(
      const ExprHandle& var,
      const ExprHandle& value,
      const ExprHandle& body) {
    return ExprHandle(new Let(var.node(), value.node(), body.node()));
  }

  Let(const Expr* var, const Expr* value, const Expr* body)
      : ExprNodeBase(body->dtype(), kLet),
        var_(var),
        value_(value),
        body_(body) {}

 private:
  const Expr* var_;
  const Expr* value_;
  const Expr* body_;
};
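
// Illustrative: the node's dtype is the body's dtype, and evaluation binds
// value to var within body, e.g.
//   VarHandle x("x", kInt);
//   ExprHandle e = Let::make(x, IntImm::make(1), x + 4); // evaluates to 5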

// Represents a ramp vector node:
//     [base, base + 1 * stride, ... , base + (lanes - 1) * stride]
class Ramp : public ExprNode<Ramp> {
 public:
  const Expr* base() const {
    return base_;
  }
  const Expr* stride() const {
    return stride_;
  }
  static ExprHandle make(
      const ExprHandle& base,
      const ExprHandle& stride,
      int lanes) {
    return ExprHandle(new Ramp(base.node(), stride.node(), lanes));
  }
  int lanes() const {
    return lanes_;
  }

  Ramp(const Expr* base, const Expr* stride, int lanes)
      : ExprNodeBase(Dtype(base->dtype(), lanes), kRamp),
        base_(base),
        stride_(stride),
        lanes_(lanes) {
    if (stride->dtype() != base->dtype()) {
      throw malformed_input();
    }
  }

 private:
  const Expr* base_;
  const Expr* stride_;
  int lanes_;
};
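
// Illustrative: Ramp::make(IntImm::make(0), IntImm::make(2), 4) represents
// the 4-lane vector [0, 2, 4, 6].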

class TORCH_API Load : public ExprNode<Load> {
 public:
  const Var* base_handle() const {
    return base_handle_;
  }
  const Expr* index() const {
    return index_;
  }
  const Expr* mask() const {
    return mask_;
  }
  static ExprHandle make(
      const Buffer& buffer,
      const ExprHandle& index,
      const ExprHandle& mask) {
    return ExprHandle(new Load(buffer, index.node(), mask.node()));
  }
  static ExprHandle make(
      Dtype dtype,
      const VarHandle& base_handle,
      const ExprHandle& index,
      const ExprHandle& mask) {
    return ExprHandle(
        new Load(dtype, base_handle.node(), index.node(), mask.node()));
  }

  Load(const Buffer& buffer, const Expr* index, const Expr* mask);
  Load(
      Dtype dtype,
      const Var* base_handle,
      const Expr* index,
      const Expr* mask);

 private:
  const Var* base_handle_;
  const Expr* index_;
  const Expr* mask_;
};

class Broadcast : public ExprNode<Broadcast> {
 public:
  const Expr* value() const {
    return value_;
  }
  int lanes() const {
    return lanes_;
  }
  static ExprHandle make(const ExprHandle& value, int lanes) {
    return ExprHandle(new Broadcast(value.node(), lanes));
  }
  Broadcast(const Expr* value, int lanes)
      : ExprNodeBase(Dtype(value->dtype(), lanes), kBroadcast),
        value_(value),
        lanes_(lanes) {}

 private:
  const Expr* value_;
  int lanes_;
};
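
// Illustrative: Broadcast::make(IntImm::make(3), 4) represents the 4-lane
// vector [3, 3, 3, 3].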

class IfThenElse : public ExprNode<IfThenElse> {
 public:
  const Expr* condition() const {
    return condition_;
  }

  // Lazily evaluated only if condition is true
  const Expr* true_value() const {
    return true_;
  }

  // Lazily evaluated only if condition is false
  const Expr* false_value() const {
    return false_;
  }

  static ExprHandle make(
      const ExprHandle& c,
      const ExprHandle& t,
      const ExprHandle& f) {
    return ExprHandle(new IfThenElse(c.node(), t.node(), f.node()));
  }

  IfThenElse(const Expr* c, const Expr* t, const Expr* f)
      : ExprNodeBase(t->dtype()), condition_(c), true_(t), false_(f) {
    if (c->dtype().scalar_type() != ScalarType::Int) {
      throw unsupported_dtype();
    }
    if (c->dtype().lanes() != 1) {
      throw unsupported_dtype();
    }
    if (t->dtype() != f->dtype()) {
      throw malformed_input();
    }
  }

 private:
  const Expr* condition_;
  const Expr* true_;
  const Expr* false_;
};
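
// Illustrative: the condition must be a scalar (single-lane) Int, and only
// the selected branch is evaluated:
//   ExprHandle e = IfThenElse::make(cond, true_val, false_val);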

class BaseCallNode : public Expr {
 public:
  enum CallType {
    kIntrinsics,
    kFunctionCall,
  };

  int nparams() const {
    return params_.size();
  }

  const Expr* param(int index) const {
    return params_[index];
  }
  const std::vector<const Expr*>& params() const {
    return params_;
  }

  virtual std::string func_name() const = 0;

  CallType call_type() const {
    return call_type_;
  }

 protected:
  BaseCallNode(
      Dtype dtype,
      CallType call_type,
      const std::vector<const Expr*>& params)
      : Expr(dtype), call_type_(call_type), params_(params) {}

 private:
  // The handler for the default ir_mutator to make a copy of this node with
  // new params.
  virtual const Expr* DefaultMutator(
      const std::vector<const Expr*>& new_params) const = 0;

  template <class U, class B>
  friend class ExprNode;
  friend class IRMutator;

  CallType call_type_;
  std::vector<const Expr*> params_;
};

template <typename Op>
class CallNode : public ExprNode<Op, BaseCallNode> {
 public:
  using BaseClass = ExprNode<Op, BaseCallNode>;
  using BaseClass::BaseClass;
};

class TORCH_API CompareSelect : public ExprNode<CompareSelect> {
 public:
  CompareSelectOperation compare_select_op() const {
    return compare_op_;
  }
  const Expr* lhs() const {
    return this->lhs_;
  }
  const Expr* rhs() const {
    return this->rhs_;
  }
  const Expr* ret_val1() const {
    return this->ret_val1_;
  }
  const Expr* ret_val2() const {
    return this->ret_val2_;
  }

  static ExprHandle make(
      const ExprHandle& lhs,
      const ExprHandle& rhs,
      CompareSelectOperation cmp_op) {
    if (lhs.dtype() != rhs.dtype()) {
      throw malformed_input();
    }
    return ExprHandle(new CompareSelect(
        lhs.node(),
        rhs.node(),
        IntImm::make(1).node(),
        IntImm::make(0).node(),
        cmp_op));
  }

  static ExprHandle make(
      const ExprHandle& lhs,
      const ExprHandle& rhs,
      const ExprHandle& ret_val1,
      const ExprHandle& ret_val2,
      CompareSelectOperation cmp_op) {
    if (lhs.dtype() != rhs.dtype() || ret_val1.dtype() != ret_val2.dtype()) {
      throw malformed_input();
    }
    return ExprHandle(new CompareSelect(
        lhs.node(), rhs.node(), ret_val1.node(), ret_val2.node(), cmp_op));
  }

 private:
  const Expr* lhs_;
  const Expr* rhs_;
  const Expr* ret_val1_;
  const Expr* ret_val2_;
  CompareSelectOperation compare_op_;

  CompareSelect(
      const Expr* lhs,
      const Expr* rhs,
      const Expr* ret_val1,
      const Expr* ret_val2,
      CompareSelectOperation cmp_op)
      : ExprNodeBase(ret_val1->dtype()),
        lhs_(lhs),
        rhs_(rhs),
        ret_val1_(ret_val1),
        ret_val2_(ret_val2),
        compare_op_(cmp_op) {
    if (ret_val1->dtype() != ret_val2->dtype()) {
      throw malformed_input();
    }
  }
};
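
// Illustrative: the three-argument make() defaults the returned values to
// IntImm 1 and 0, so an integer (a < b) predicate is:
//   ExprHandle lt = CompareSelect::make(a, b, kLT); // 1 if a < b, else 0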

enum IntrinsicsOp {
  kSin,
  kCos,
  kTan,
  kAsin,
  kAcos,
  kAtan,
  kAtan2,
  kSinh,
  kCosh,
  kTanh,
  kExp,
  kExpm1,
  kFabs,
  kLog,
  kLog2,
  kLog10,
  kLog1p,
  kErf,
  kErfc,
  kSqrt,
  kRsqrt,
  kPow,
  kCeil,
  kFloor,
  kRound,
  kTrunc,
  kFmod,
  kRemainder,
  kLgamma,
  kFrac,
  kRand, // Needs more discussion: should we consider it stateful?
};

class Intrinsics : public CallNode<Intrinsics> {
 public:
  static ExprHandle make(IntrinsicsOp op_type, const ExprHandle& v1) {
    return ExprHandle(new Intrinsics(op_type, v1.node()));
  }

  static ExprHandle make(
      IntrinsicsOp op_type,
      const ExprHandle& v1,
      const ExprHandle& v2) {
    return ExprHandle(new Intrinsics(op_type, v1.node(), v2.node()));
  }

  static ExprHandle make(
      IntrinsicsOp op_type,
      const std::vector<ExprHandle>& params) {
    std::vector<const Expr*> params_nodes(params.size());
    for (size_t i = 0; i < params.size(); i++) {
      params_nodes[i] = params[i].node();
    }
    return ExprHandle(new Intrinsics(op_type, params_nodes));
  }

  static ExprHandle make(IntrinsicsOp op_type, Dtype dtype) {
    return ExprHandle(new Intrinsics(op_type, dtype));
  }

  IntrinsicsOp op_type() const {
    return op_type_;
  }

  std::string func_name() const override {
    switch (op_type()) {
      case kSin:
        return "sin";
      case kCos:
        return "cos";
      case kTan:
        return "tan";
      case kAsin:
        return "asin";
      case kAcos:
        return "acos";
      case kAtan:
        return "atan";
      case kAtan2:
        return "atan2";
      case kSinh:
        return "sinh";
      case kCosh:
        return "cosh";
      case kTanh:
        return "tanh";
      case kExp:
        return "exp";
      case kFabs:
        return "fabs";
      case kLog:
        return "log";
      case kLog2:
        return "log2";
      case kLog10:
        return "log10";
      case kLog1p:
        return "log1p";
      case kErf:
        return "erf";
      case kSqrt:
        return "sqrt";
      case kRsqrt:
        return "rsqrt";
      case kPow:
        return "pow";
      case kCeil:
        return "ceil";
      case kFloor:
        return "floor";
      case kRound:
        return "round";
      case kTrunc:
        return "trunc";
      case kRand:
        return "rand";
      case kFmod:
        return "fmod";
      case kRemainder:
        return "remainder";
      case kLgamma:
        return "lgamma";
      case kExpm1:
        return "expm1";
      case kErfc:
        return "erfc";
      case kFrac:
        return "frac";
      default:
        throw std::runtime_error(
            "invalid op_type: " + std::to_string(op_type()));
    }
  }
  using BaseClass = CallNode<Intrinsics>;

  Intrinsics(IntrinsicsOp op_type, Dtype dtype)
      : BaseClass(IntrinsicsDtype(op_type, dtype), kIntrinsics, {}),
        op_type_(op_type) {
    if (OpArgCount(op_type) != 0) {
      throw malformed_input();
    }
  }

  Intrinsics(IntrinsicsOp op_type, const Expr* v1)
      : BaseClass(IntrinsicsDtype(op_type, v1->dtype()), kIntrinsics, {v1}),
        op_type_(op_type) {
    if (OpArgCount(op_type) != 1) {
      throw malformed_input();
    }
  }

  Intrinsics(IntrinsicsOp op_type, const Expr* v1, const Expr* v2)
      : BaseClass(
            IntrinsicsDtype(op_type, v1->dtype(), v2->dtype()),
            kIntrinsics,
            {v1, v2}),
        op_type_(op_type) {
    if (OpArgCount(op_type) != 2) {
      throw malformed_input();
    }
  }

  Intrinsics(IntrinsicsOp op_type, const std::vector<const Expr*>& params)
      : BaseClass(IntrinsicsDtype(op_type, params), kIntrinsics, params),
        op_type_(op_type) {
    if (OpArgCount(op_type) != nparams()) {
      throw malformed_input();
    }
  }

  bool isPure() const {
    return op_type_ != kRand;
  }

 private:
  TORCH_API static int OpArgCount(IntrinsicsOp op_type);

  const Expr* DefaultMutator(
      const std::vector<const Expr*>& new_params) const override {
    return new Intrinsics(this->op_type(), new_params);
  }

  TORCH_API static Dtype IntrinsicsDtype(IntrinsicsOp op_type, Dtype dt1);
  TORCH_API static Dtype IntrinsicsDtype(
      IntrinsicsOp op_type,
      Dtype dt1,
      Dtype dt2);
  TORCH_API static Dtype IntrinsicsDtype(
      IntrinsicsOp op_type,
      const std::vector<const Expr*>& params);

  IntrinsicsOp op_type_;
};
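
// Illustrative: intrinsics are created through the arity-checked make()
// overloads, e.g.
//   ExprHandle s = Intrinsics::make(kSin, x);    // unary, OpArgCount == 1
//   ExprHandle p = Intrinsics::make(kPow, x, y); // binary, OpArgCount == 2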

class Polynomial;
class Term;

class FunctionCall;

TORCH_API std::vector<const Expr*> ExprHandleVectorToExprVector(
    const std::vector<ExprHandle>&);
TORCH_API std::vector<ExprHandle> ExprVectorToExprHandleVector(
    const std::vector<const Expr*>&);
TORCH_API std::vector<const Var*> VarHandleVectorToVarVector(
    const std::vector<VarHandle>&);
TORCH_API std::vector<VarHandle> VarVectorToVarHandleVector(
    const std::vector<const Var*>&);

} // namespace tensorexpr
} // namespace jit
} // namespace torch