[TensorExpr] Add a fuser pass based on tensor expressions. (#34226)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/34226

The LLVM and CUDA backends are added in subsequent PRs, so at this point the fuser is pretty useless on its own, but it can already be tested, and its logic is not going to change with the addition of the codegens.

Differential Revision: D20251838
Test Plan: Imported from OSS
Pulled By: ZolotukhinM
fbshipit-source-id: 82b0d221fa89904ed526689d02a6c7676a8ce8de
Commit 42b2c8c65d (parent e31d462e92), committed by Facebook GitHub Bot.
caffe2/CMakeLists.txt:
@@ -481,6 +481,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
     ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/ir_mutator.cpp
     ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/ir_printer.cpp
     ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/ir_visitor.cpp
+    ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/kernel.cpp
     ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/mem_arena.cpp
     ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/schedule.cpp
     ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/tensor.cpp
test/test_tensorexpr.py (new file, 1047 lines): diff suppressed because it is too large.
tools/build_variables.bzl:
@@ -205,6 +205,7 @@ libtorch_sources = [
     "torch/csrc/jit/tensorexpr/ir_mutator.cpp",
     "torch/csrc/jit/tensorexpr/ir_printer.cpp",
     "torch/csrc/jit/tensorexpr/ir_visitor.cpp",
+    "torch/csrc/jit/tensorexpr/kernel.cpp",
     "torch/csrc/jit/tensorexpr/mem_arena.cpp",
     "torch/csrc/jit/tensorexpr/schedule.cpp",
     "torch/csrc/jit/tensorexpr/tensor.cpp",
torch/csrc/jit/passes/tensorexpr_fuser.cpp:
@@ -1,17 +1,23 @@
 #include <torch/csrc/jit/passes/tensorexpr_fuser.h>
 #include <torch/csrc/autograd/record_function.h>
-#include <torch/csrc/jit/runtime/custom_operator.h>
-#include <torch/csrc/jit/jit_log.h>
-#include <torch/csrc/jit/runtime/operator_options.h>
-#include <torch/csrc/jit/passes/pass_manager.h>
 #include <torch/csrc/jit/ir/alias_analysis.h>
+#include <torch/csrc/jit/jit_log.h>
 #include <torch/csrc/jit/passes/common_subexpression_elimination.h>
 #include <torch/csrc/jit/passes/dead_code_elimination.h>
+#include <torch/csrc/jit/passes/pass_manager.h>
 #include <torch/csrc/jit/passes/utils/subgraph_utils.h>
+#include <torch/csrc/jit/runtime/custom_operator.h>
+#include <torch/csrc/jit/runtime/operator_options.h>
+#include <torch/csrc/jit/tensorexpr/kernel.h>

 namespace torch {
 namespace jit {

+static bool texpr_fuser_enabled = true;
+void setTensorExprFuserEnabled(bool val) {
+  texpr_fuser_enabled = val;
+}
+
 const Symbol& getTensorExprSymbol() {
   static Symbol s = Symbol::fromQualString("tensorexpr::Group");
   return s;
@@ -36,9 +42,88 @@ value_list sortReverseTopological(
   return result;
 }

+bool isSupported(Node* node) {
+  // TODO:
+  switch (node->kind()) {
+    case aten::add:
+    case aten::_cast_Float:
+    case aten::type_as:
+    case aten::sub:
+    case aten::mul:
+    case aten::div:
+    case aten::eq:
+    case aten::ne:
+    case aten::ge:
+    case aten::gt:
+    case aten::le:
+    case aten::lt:
+    case aten::min:
+    case aten::max:
+    case aten::pow:
+    case aten::clamp:
+    case aten::lerp:
+    case aten::log10:
+    case aten::log:
+    case aten::log2:
+    case aten::exp:
+    case aten::erf:
+    case aten::erfc:
+    case aten::fmod:
+    case aten::cos:
+    case aten::sin:
+    case aten::tan:
+    case aten::acos:
+    case aten::asin:
+    case aten::atan:
+    case aten::atan2:
+    case aten::cosh:
+    case aten::sinh:
+    case aten::tanh:
+    case aten::sqrt:
+    case aten::rsqrt:
+    case aten::abs:
+    case aten::floor:
+    case aten::ceil:
+    case aten::round:
+    case aten::trunc:
+    case aten::threshold:
+    case aten::remainder:
+    case prim::ConstantChunk:
+    case aten::cat:
+    case prim::ListConstruct:
+    case aten::sigmoid:
+    case aten::relu:
+    case aten::addcmul:
+    case aten::neg:
+    case aten::reciprocal:
+    case aten::expm1:
+    case aten::lgamma:
+    case aten::slice:
+    case aten::unsqueeze:
+    case aten::frac:
+    case aten::rand_like:
+    case aten::_sigmoid_backward:
+    case aten::_tanh_backward:
+    case aten::__and__:
+    case aten::__or__:
+    case aten::__xor__:
+    case aten::__lshift__:
+    case aten::__rshift__:
+    case aten::where:
+      return true;
+    default:
+      return false;
+  }
+}
+
 bool canHandle(Node* node, AliasDb& aliasDb) {
-  // TODO: actually support some ops
-  return false;
+  if (node->kind() == prim::Constant) {
+    return true;
+  }
+  if (node->kind() == prim::Loop) {
+    return false; // TODO
+  }
+  return isSupported(node);
 }

 #define REQ(cond) \
@@ -65,6 +150,36 @@ bool canMerge(Node* consumer, Node* producer, AliasDb& aliasDb) {
   // Alias checks
   REQ(aliasDb.couldMoveAfterTopologically(consumer, producer));

+  // Ops that return aliases can only be folded if this is the only use.
+  if (producer->kind() == aten::slice || producer->kind() == aten::unsqueeze ||
+      producer->kind() == prim::ConstantChunk) {
+    for (auto& use : producer->output(0)->uses()) {
+      REQ(use.user == consumer);
+    }
+  }
+
+  if (!consumer->hasAttribute(attr::Subgraph) &&
+      consumer->kind() != getTensorExprSymbol()) {
+    // Don't initiate a fusion group from prim::ListConstruct
+    REQ(consumer->kind() != prim::ListConstruct);
+    REQ(consumer->kind() != aten::slice);
+    REQ(consumer->kind() != aten::unsqueeze);
+    REQ(consumer->kind() != prim::ConstantChunk);
+
+    // Don't initiate a fusion group just for a constant operand
+    REQ(producer->kind() != prim::Constant);
+  }
+
+  if (producer->kind() == aten::cat) {
+    REQ(producer->inputs()[0]->node()->kind() == prim::ListConstruct);
+    REQ(producer->inputs()[0]->uses().size() == 1);
+    REQ(producer->inputs()[1]->node()->kind() == prim::Constant);
+  } else if (consumer->kind() == aten::cat) {
+    REQ(consumer->inputs()[0]->node()->kind() == prim::ListConstruct);
+    REQ(consumer->inputs()[0]->uses().size() == 1);
+    REQ(consumer->inputs()[1]->node()->kind() == prim::Constant);
+  }
+
   return true;
 }
 #undef REQ
@@ -73,7 +188,10 @@ Node* getOrCreateTensorExprSubgraph(Node* n) {
   if (n->hasAttribute(attr::Subgraph) && n->kind() == getTensorExprSymbol()) {
     return n;
   }
-  return SubgraphUtils::createSingletonSubgraph(n, getTensorExprSymbol());
+  auto te_group =
+      SubgraphUtils::createSingletonSubgraph(n, getTensorExprSymbol());
+  GRAPH_UPDATE("getOrCreateTensorExprSubgraph: ", *te_group);
+  return te_group;
 }

 c10::optional<Node*> tryMerge(
@@ -82,9 +200,9 @@
     AliasDb& aliasDb) {
   GRAPH_DEBUG(
       "Trying producer ",
-      producer->kind().toQualString(),
+      getHeader(producer),
       " and consumer ",
-      consumer->kind().toQualString(),
+      getHeader(consumer),
       ":\n");

   if (!canMerge(consumer, producer, aliasDb)) {
@@ -93,8 +211,24 @@

   consumer = getOrCreateTensorExprSubgraph(consumer);

-  aliasDb.moveAfterTopologicallyValid(consumer, producer);
-  SubgraphUtils::mergeNodeIntoSubgraph(producer, consumer);
+  if (producer->kind() == aten::cat) {
+    Node* listconstruct = producer->inputs()[0]->node();
+
+    aliasDb.moveAfterTopologicallyValid(consumer, producer);
+    GRAPH_UPDATE(
+        "Merging ", getHeader(producer), " into ", getHeader(consumer));
+    SubgraphUtils::mergeNodeIntoSubgraph(producer, consumer);
+
+    aliasDb.moveAfterTopologicallyValid(consumer, listconstruct);
+    GRAPH_UPDATE(
+        "Merging ", getHeader(listconstruct), " into ", getHeader(consumer));
+    SubgraphUtils::mergeNodeIntoSubgraph(listconstruct, consumer);
+  } else {
+    aliasDb.moveAfterTopologicallyValid(consumer, producer);
+    GRAPH_UPDATE(
+        "Merging ", getHeader(producer), " into ", getHeader(consumer));
+    SubgraphUtils::mergeNodeIntoSubgraph(producer, consumer);
+  }

   return consumer;
 }
@@ -120,26 +254,10 @@ std::pair<graph_node_list::iterator, bool> scanNode(
   return {++(++iter), false};
 }

-Operation createTensorExprOp(const Node* node) {
-  // TODO: actually compile the fusion group.
-  return [](Stack& stack) {
-    RECORD_FUNCTION("TensorExpr", std::vector<c10::IValue>());
-    return 0;
-  };
-}
-
-c10::AliasAnalysisKind getAliasAnalysisOption(AliasAnalysisKind k) {
-  return k;
-}
-
-RegisterOperators TensorExprOps({
-    torch::jit::Operator(
-        getTensorExprSymbol(),
-        createTensorExprOp,
-        getAliasAnalysisOption(AliasAnalysisKind::PURE_FUNCTION)),
-});
-
 void fuseTensorExprs(std::shared_ptr<Graph>& graph) {
+  if (!texpr_fuser_enabled) {
+    return;
+  }
   GRAPH_DUMP("Before TExprFuser: ", graph);

   // Get rid of dead code so that we don't waste effort fusing it.
@@ -192,6 +310,23 @@ void fuseTensorExprs(std::shared_ptr<Graph>& graph) {
   GRAPH_DUMP("After TExprFuser: ", graph);
 }

+Operation createTensorExprOp(const Node* node) {
+  auto kernel =
+      std::make_shared<tensorexpr::TensorExprKernel>(*node->g(attr::Subgraph));
+  return [kernel](Stack& stack) {
+    RECORD_FUNCTION("TensorExpr", std::vector<c10::IValue>());
+    kernel->run(stack);
+    return 0;
+  };
+}
+
+RegisterOperators TensorExprOps({
+    torch::jit::Operator(
+        getTensorExprSymbol(),
+        createTensorExprOp,
+        AliasAnalysisKind::PURE_FUNCTION),
+});
+
 void registerTensorExprFuser() {
   static bool already_registered = false;
   if (!already_registered) {
torch/csrc/jit/passes/tensorexpr_fuser.h:
@@ -14,5 +14,7 @@ TORCH_API void fuseTensorExprs(std::shared_ptr<Graph>& graph);
 // Register TensorExpressions-based fuser in custom passes.
 TORCH_API void registerTensorExprFuser();

+TORCH_API void setTensorExprFuserEnabled(bool val);
+
 } // namespace jit
 } // namespace torch
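The header above is the whole public surface of the new pass. As an illustration (editor's sketch, not part of this commit), it could be driven directly from C++ like this:

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>

// Sketch: run the tensor-expression fuser over a JIT graph built elsewhere.
void fuseWithTensorExprs(std::shared_ptr<torch::jit::Graph>& graph) {
  // The pass is on by default; the call is shown for completeness.
  torch::jit::setTensorExprFuserEnabled(true);
  // Rewrites fusible regions of the graph into tensorexpr::Group nodes,
  // which the TensorExprOps operator registered in the pass then executes.
  torch::jit::fuseTensorExprs(graph);
}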
torch/csrc/jit/python/init.cpp:
@@ -59,6 +59,7 @@
 #include <torch/csrc/jit/api/module.h>
 #include <torch/csrc/jit/python/python_tree_views.h>
 #include <torch/csrc/jit/frontend/tracer.h>
+#include <torch/csrc/jit/tensorexpr/execution_counter.h>

 #include <c10/macros/Export.h>
 #include <caffe2/serialize/inline_container.h>
@@ -405,6 +406,15 @@ void initJITBindings(PyObject* module) {
           }
           return nullptr;
         })
+      .def(
+          "_jit_get_trigger_value",
+          [](const std::string& trigger_name) {
+            using namespace torch::jit::tensorexpr;
+            ExecutionTrigger* trigger =
+                ExecutionTriggerList::GetInstance().FindByName(trigger_name);
+            return trigger->value();
+          })
+      .def("_jit_set_texpr_fuser_enabled", &setTensorExprFuserEnabled)
       .def(
           "_jit_fuser_get_fused_kernel_code",
           [](Graph& g, std::vector<at::Tensor> inps) {
torch/csrc/jit/tensorexpr/eval.cpp:
@@ -4,7 +4,9 @@ namespace torch {
 namespace jit {
 namespace tensorexpr {

-RegisterCodeGen<SimpleIREvaluator> reg("simple_ir_eval");
+DEFINE_TRIGGER(simple_ir_eval_executed);
+
+RegisterCodeGen<SimpleIREvaluator> ir_eval_codegen_reg("simple_ir_eval");

 } // namespace tensorexpr
 } // namespace jit
torch/csrc/jit/tensorexpr/eval.h:
@@ -7,6 +7,7 @@
 #include <c10/util/Logging.h>
 #include "torch/csrc/jit/tensorexpr/buffer.h"
 #include "torch/csrc/jit/tensorexpr/codegen.h"
+#include "torch/csrc/jit/tensorexpr/execution_counter.h"
 #include "torch/csrc/jit/tensorexpr/function.h"
 #include "torch/csrc/jit/tensorexpr/ir.h"
 #include "torch/csrc/jit/tensorexpr/ir_printer.h"
@@ -17,6 +18,8 @@ namespace torch {
 namespace jit {
 namespace tensorexpr {

+DECLARE_TRIGGER(simple_ir_eval_executed);
+
 class Value {
  public:
   Value() : dtype_(kInt) {
@@ -106,6 +109,7 @@ class SimpleIREvaluator : public CodeGen, public IRVisitor {
     eval_context_.clear();
     buffer_mapping_.clear();
     internal_buffers_.clear();
+    USE_TRIGGER(simple_ir_eval_executed);
   }

   void bind(const BufferArg& buf, const CallArg& data) {
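With USE_TRIGGER wired into SimpleIREvaluator's run path, tests can count how many times the interpreter actually executed a fused kernel. A minimal C++ sketch (editor's example; the Python tests reach the same counter through the new _jit_get_trigger_value binding):

#include <torch/csrc/jit/tensorexpr/eval.h>

#include <functional>

// Returns how many times SimpleIREvaluator ran while the workload executed.
int countIREvalRuns(const std::function<void()>& run_workload) {
  // simple_ir_eval_executed is the trigger declared in eval.h above.
  torch::jit::tensorexpr::ExecutionCounter counter(
      torch::jit::tensorexpr::simple_ir_eval_executed);
  run_workload(); // e.g. run a graph containing a tensorexpr::Group node
  return counter.elapsed_value();
}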
torch/csrc/jit/tensorexpr/execution_counter.h (new file, 118 lines):
@@ -0,0 +1,118 @@
#pragma once

#include "torch/csrc/WindowsTorchApiMacro.h"

#include <string>
#include <unordered_map>

namespace torch {
namespace jit {
namespace tensorexpr {

/*
ExecutionTrigger and ExecutionCounter build instrumentation counters so that
underlying functionality can be checked.

In the code to be instrumented:

// worker.cpp
DEFINE_TRIGGER(useful_work_done); // this defines a trigger "useful_work_done"
void run() {
  USE_TRIGGER(useful_work_done); // this triggers the underlying counter
                                 // in "useful_work_done"
}

// in C++ client.cpp
DECLARE_TRIGGER(useful_work_done); // Optional: this declares a trigger that
                                   // will be defined elsewhere
ExecutionCounter counter(useful_work_done); // This starts the counter from the
                                            // underlying trigger.
... call run() ...
counter.elapsed_value(); // this returns the incremented value from the
                         // trigger since the creation of the counter

// in Python client.py
counter = ExecutionCounter("useful_work_done") // this starts the counter from
                                               // the underlying trigger
... call C++ run() ...
counter.elapsed_value() // This returns the incremented value from the
                        // trigger since the creation of the counter.
*/

class ExecutionTrigger;
class ExecutionTriggerList {
 public:
  TORCH_API static ExecutionTriggerList& GetInstance() {
    static ExecutionTriggerList instance;
    return instance;
  }

  ExecutionTrigger* FindByName(const std::string& name) const {
    auto iter = trigger_list_.find(name);
    if (iter == trigger_list_.end()) {
      throw std::runtime_error("Invalid trigger name: " + name);
    }
    return iter->second;
  }

 private:
  friend class ExecutionTrigger;

  ExecutionTriggerList() {}
  ExecutionTriggerList(const ExecutionTriggerList&) = delete;
  ExecutionTriggerList& operator=(const ExecutionTriggerList&) = delete;

  void AddTrigger(const std::string& name, ExecutionTrigger* trigger) {
    auto insert_ret = trigger_list_.insert(std::make_pair(name, trigger));
    if (!insert_ret.second) {
      throw std::runtime_error("Duplicated trigger name: " + name);
    }
  }

  std::unordered_map<std::string, ExecutionTrigger*> trigger_list_;
};

class ExecutionTrigger {
 public:
  explicit ExecutionTrigger(const std::string& name) : name_(name) {
    ExecutionTriggerList::GetInstance().AddTrigger(name, this);
  }

  int value() const {
    return value_;
  }

  void trigger() {
    value_++;
  }

 private:
  ExecutionTrigger(const ExecutionTrigger&) = delete;
  ExecutionTrigger& operator=(const ExecutionTrigger&) = delete;
  int value_ = 0;
  const std::string name_;
};

class ExecutionCounter {
 public:
  explicit ExecutionCounter(ExecutionTrigger& trigger) : trigger_(trigger) {
    start_value_ = trigger_.value();
  }

  int elapsed_value() const {
    return trigger_.value() - start_value_;
  }

 private:
  ExecutionTrigger& trigger_;
  int start_value_ = 0;
};

#define DEFINE_TRIGGER(name) ExecutionTrigger name(#name)
#define DECLARE_TRIGGER(name) TORCH_API extern ExecutionTrigger name
#define USE_TRIGGER(name) (name).trigger()

} // namespace tensorexpr
} // namespace jit
} // namespace torch
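Putting the pieces above together, a self-contained usage sketch (editor's example; the trigger name useful_work_done is the hypothetical one from the header's own comment):

#include "torch/csrc/jit/tensorexpr/execution_counter.h"

using namespace torch::jit::tensorexpr;

DEFINE_TRIGGER(useful_work_done); // registers the trigger globally as "useful_work_done"

void run() {
  USE_TRIGGER(useful_work_done); // bumps the trigger's value on every call
}

int observed_runs() {
  ExecutionCounter counter(useful_work_done); // snapshots the current value
  run();
  run();
  // Bindings can also look the trigger up by name:
  // ExecutionTriggerList::GetInstance().FindByName("useful_work_done")->value();
  return counter.elapsed_value(); // == 2
}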
torch/csrc/jit/tensorexpr/kernel.cpp (new file, 1180 lines): diff suppressed because it is too large.
torch/csrc/jit/tensorexpr/kernel.h (new file, 210 lines):
@@ -0,0 +1,210 @@
#pragma once

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/tensorexpr/codegen.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>

namespace torch {
namespace jit {
namespace tensorexpr {

template <typename T>
inline std::vector<int64_t> bufferSizes(const T& t) {
  std::vector<int64_t> sizes;
  for (int i = 0; i < t->function()->ndim(); i++) {
    sizes.push_back(
        dynamic_cast<const IntImm*>(t->function()->dim(i))->value());
  }
  return sizes;
}

template <typename T>
inline std::vector<ExprHandle> computeIndicesToBroadcast(
    const std::vector<T>& output_axes,
    const std::vector<ExprHandle>& input_sizes) {
  TORCH_CHECK(
      output_axes.size() >= input_sizes.size(),
      "Cannot broadcast to a lower rank tensor");
  std::vector<ExprHandle> bcast;
  auto axis_it = output_axes.rbegin();
  auto size_it = input_sizes.rbegin();
  while (size_it != input_sizes.rend()) {
    auto const& size = size_it->AsNode<IntImm>();
    if (size && size->value() == 1) {
      bcast.push_back(0);
    } else {
      bcast.push_back(*axis_it);
    }
    ++axis_it;
    ++size_it;
  }
  std::reverse(bcast.begin(), bcast.end());
  return bcast;
}
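// Worked example (editor's note, not in the original header): for output
// axes [i, j, k] and input sizes [N, 1], the loop walks both lists from the
// innermost dimension outward, pinning the size-1 dimension to index 0 and
// dropping the extra leading output axis, yielding indices [j, 0] --
// numpy-style broadcasting.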
class TensorExprKernel {
 public:
  explicit TensorExprKernel(const Graph& subgraph);

  void run(Stack& stack);

 private:
  enum BackendType {
    kUninitialized,
    kSimpleIREval,
  };

  ExprHandle constant(const torch::jit::Value* v);

  template <typename T, typename T1>
  ExprHandle broadcast(const T& t, const std::vector<T1>& axes) {
    return t->call(computeIndicesToBroadcast(
        axes, ExprVectorToExprHandleVector(t->function()->dims())));
  }

  template <typename T, typename T1>
  ExprHandle chunk(
      const T& t,
      size_t chunk_idx,
      size_t dim,
      size_t chunks,
      const std::vector<T1>& axes) {
    auto sizes = bufferSizes(t);
    size_t step = sizes[dim] / chunks;

    std::vector<ExprHandle> indices;
    for (size_t i = 0; i < axes.size(); ++i) {
      if (i == dim) {
        indices.push_back(axes[i] + IntImm::make(chunk_idx * step));
      } else {
        indices.push_back(axes[i]);
      }
    }

    return t->call(indices);
  }
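  // Worked example (editor's note, not in the original header): for a tensor
  // of sizes [8, 6] split along dim = 1 into chunks = 2 pieces, step is 3,
  // so chunk_idx = 1 reads the source at indices [i, j + 3]: each chunk is a
  // shifted window into the original buffer.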
  std::vector<ExprHandle> valueShape(const torch::jit::Value* v);

  void promoteInputs(std::vector<ExprHandle>& inputs);

  ExprHandle demoteOutput(const ExprHandle& e, const torch::jit::Value* v);

  template <typename T>
  ExprHandle tensorOrConstant(
      const torch::jit::Value* v,
      const std::vector<T>& axes) {
    auto ti = tensors_.find(v->unique());
    if (ti != tensors_.end()) {
      return broadcast(ti->second, axes);
    }
    return constant(v);
  }

  Tensor* ComputeOneOperand(
      const std::string& name,
      const torch::jit::Value* v,
      const std::function<ExprHandle(const ExprHandle&)>& inner_expr);

  Tensor* ComputeTwoOperand(
      const std::string& name,
      const torch::jit::Value* v,
      const std::function<ExprHandle(const ExprHandle&, const ExprHandle&)>&
          inner_expr);

  Tensor* ComputeTwoOperandWithAlpha(
      const std::string& name,
      const torch::jit::Value* v,
      const std::function<ExprHandle(const ExprHandle&, const ExprHandle&)>&
          inner_expr);

  Tensor* ComputeThreeOperand(
      const std::string& name,
      const torch::jit::Value* v,
      const std::function<
          ExprHandle(const ExprHandle&, const ExprHandle&, const ExprHandle&)>&
          inner_expr);

  Tensor* ComputeConditionWithTwoOperand(
      const std::string& name,
      const torch::jit::Value* v,
      const std::function<
          ExprHandle(const ExprHandle&, const ExprHandle&, const ExprHandle&)>&
          inner_expr);

  Tensor* ComputeFourOperand(
      const std::string& name,
      const torch::jit::Value* v,
      const std::function<ExprHandle(
          const ExprHandle&,
          const ExprHandle&,
          const ExprHandle&,
          const ExprHandle&)>& inner_expr);

  Tensor* ComputeValue(const torch::jit::Value* v);

  void LowerToBackend(BackendType backend_type);

  void PickAndCheckBackendType(const at::ArrayRef<IValue>& inputs);

  void CodeGenRun(const std::vector<CodeGen::CallArg>& run_args);

  void bindInput(const torch::jit::Value* input);

  ExprHandle createInputIndexExpr(
      const Buffer& buffer,
      const std::vector<VarHandle>& axes,
      const c10::VaryingShape& sizes,
      const c10::VaryingStrides& strides,
      const c10::VaryingStrides& contiguity,
      const std::unordered_map<int64_t, VarHandle>& sizeVars);

 private:
  struct ShapeArg {
    size_t idx;
    VarHandle var;

    ShapeArg(size_t i, VarHandle v) : idx(i), var(v) {}
  };

  struct KernelArg {
    template <typename B>
    KernelArg(B&& b) : bufferArg_(std::forward<B>(b)) {}

    template <typename B, typename T>
    KernelArg(B&& b, T&& sizes, T&& strides)
        : bufferArg_(b),
          sizeArgs_(std::forward<T>(sizes)),
          strideArgs_(std::forward<T>(strides)) {}

    const CodeGen::BufferArg& buffer() const {
      return bufferArg_;
    }

    const std::vector<ShapeArg>& sizes() const {
      return sizeArgs_;
    }

    const std::vector<ShapeArg>& strides() const {
      return strideArgs_;
    }

    CodeGen::BufferArg bufferArg_;
    std::vector<ShapeArg> sizeArgs_;
    std::vector<ShapeArg> strideArgs_;
  };

  int64_t n_inputs_ = 0;
  std::vector<KernelArg> kernelArgs_;
  std::vector<Tensor*> tensor_outputs_;
  std::unordered_map<int64_t, Tensor*> tensors_;
  std::unordered_map<int64_t, VarHandle> scalars_;
  std::unique_ptr<CodeGen> codegen_;
  KernelArena kernel_arena_;
  BackendType backend_type_ = BackendType::kUninitialized;
  at::Device device_ = at::kCPU;
};

} // namespace tensorexpr
} // namespace jit
} // namespace torch
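To make the runtime path concrete: createTensorExprOp in the fuser pass builds one TensorExprKernel per tensorexpr::Group node and calls run() on it with the interpreter's stack. A minimal sketch of that contract (editor's example; graph construction elided):

#include <torch/csrc/jit/tensorexpr/kernel.h>

// `subgraph` is assumed to be the Graph attached to a fused
// tensorexpr::Group node, i.e. node->g(attr::Subgraph) in the pass above.
void runFusedSubgraph(const torch::jit::Graph& subgraph, torch::jit::Stack& stack) {
  torch::jit::tensorexpr::TensorExprKernel kernel(subgraph); // lower the subgraph
  kernel.run(stack); // pops inputs from the stack and pushes outputs back
}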