[TensorExpr] Add a fuser pass based on tensor expressions. (#34226)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34226

LLVM and Cuda backends are added in subsequent PRs, so at this point the fuser is pretty useless, but it still can be tested and its logic is not going to change with addition of the codegens.

Differential Revision: D20251838

Test Plan: Imported from OSS

Pulled By: ZolotukhinM

fbshipit-source-id: 82b0d221fa89904ed526689d02a6c7676a8ce8de
This commit is contained in:
Mikhail Zolotukhin
2020-03-16 11:38:29 -07:00
committed by Facebook GitHub Bot
parent e31d462e92
commit 42b2c8c65d
12 changed files with 2745 additions and 35 deletions

View File

@ -481,6 +481,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
${TORCH_SRC_DIR}/csrc/jit/tensorexpr/ir_mutator.cpp
${TORCH_SRC_DIR}/csrc/jit/tensorexpr/ir_printer.cpp
${TORCH_SRC_DIR}/csrc/jit/tensorexpr/ir_visitor.cpp
${TORCH_SRC_DIR}/csrc/jit/tensorexpr/kernel.cpp
${TORCH_SRC_DIR}/csrc/jit/tensorexpr/mem_arena.cpp
${TORCH_SRC_DIR}/csrc/jit/tensorexpr/schedule.cpp
${TORCH_SRC_DIR}/csrc/jit/tensorexpr/tensor.cpp

View File

@ -524,7 +524,7 @@ def log_test_reason(file_type, filename, test, options):
print_to_stderr(
'Determination found {} file {} -- running {}'.format(
file_type,
filename,
filename,
test,
)
)
@ -604,7 +604,7 @@ def determine_target(test, touched_files, options):
if touched_module.startswith('test.'):
touched_module = touched_module.split('test.')[1]
if (
touched_module in dep_modules
touched_module in dep_modules
or touched_module == test.replace('/', '.')
):
log_test_reason(file_type, touched_file, test, options)
@ -636,13 +636,13 @@ def main():
if options.determine_from is not None and os.path.exists(options.determine_from):
with open(options.determine_from, 'r') as fh:
touched_files = [
os.path.normpath(name.strip()) for name in fh.read().split('\n')
os.path.normpath(name.strip()) for name in fh.read().split('\n')
if len(name.strip()) > 0
]
# HACK: Ensure the 'test' paths can be traversed by Modulefinder
sys.path.append('test')
selected_tests = [
test for test in selected_tests
test for test in selected_tests
if determine_target(test, touched_files, options)
]
sys.path.remove('test')

1047
test/test_tensorexpr.py Normal file

File diff suppressed because it is too large Load Diff

View File

@ -205,6 +205,7 @@ libtorch_sources = [
"torch/csrc/jit/tensorexpr/ir_mutator.cpp",
"torch/csrc/jit/tensorexpr/ir_printer.cpp",
"torch/csrc/jit/tensorexpr/ir_visitor.cpp",
"torch/csrc/jit/tensorexpr/kernel.cpp",
"torch/csrc/jit/tensorexpr/mem_arena.cpp",
"torch/csrc/jit/tensorexpr/schedule.cpp",
"torch/csrc/jit/tensorexpr/tensor.cpp",

View File

@ -1,17 +1,23 @@
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/autograd/record_function.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/runtime/operator_options.h>
#include <torch/csrc/jit/passes/pass_manager.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/pass_manager.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/jit/runtime/operator_options.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>
namespace torch {
namespace jit {
static bool texpr_fuser_enabled = true;
void setTensorExprFuserEnabled(bool val) {
texpr_fuser_enabled = val;
}
const Symbol& getTensorExprSymbol() {
static Symbol s = Symbol::fromQualString("tensorexpr::Group");
return s;
@ -36,9 +42,88 @@ value_list sortReverseTopological(
return result;
}
bool isSupported(Node* node) {
// TODO:
switch (node->kind()) {
case aten::add:
case aten::_cast_Float:
case aten::type_as:
case aten::sub:
case aten::mul:
case aten::div:
case aten::eq:
case aten::ne:
case aten::ge:
case aten::gt:
case aten::le:
case aten::lt:
case aten::min:
case aten::max:
case aten::pow:
case aten::clamp:
case aten::lerp:
case aten::log10:
case aten::log:
case aten::log2:
case aten::exp:
case aten::erf:
case aten::erfc:
case aten::fmod:
case aten::cos:
case aten::sin:
case aten::tan:
case aten::acos:
case aten::asin:
case aten::atan:
case aten::atan2:
case aten::cosh:
case aten::sinh:
case aten::tanh:
case aten::sqrt:
case aten::rsqrt:
case aten::abs:
case aten::floor:
case aten::ceil:
case aten::round:
case aten::trunc:
case aten::threshold:
case aten::remainder:
case prim::ConstantChunk:
case aten::cat:
case prim::ListConstruct:
case aten::sigmoid:
case aten::relu:
case aten::addcmul:
case aten::neg:
case aten::reciprocal:
case aten::expm1:
case aten::lgamma:
case aten::slice:
case aten::unsqueeze:
case aten::frac:
case aten::rand_like:
case aten::_sigmoid_backward:
case aten::_tanh_backward:
case aten::__and__:
case aten::__or__:
case aten::__xor__:
case aten::__lshift__:
case aten::__rshift__:
case aten::where:
return true;
default:
return false;
}
}
bool canHandle(Node* node, AliasDb& aliasDb) {
// TODO: actually support some ops
return false;
if (node->kind() == prim::Constant) {
return true;
}
if (node->kind() == prim::Loop) {
return false; // TODO
}
return isSupported(node);
}
#define REQ(cond) \
@ -65,6 +150,36 @@ bool canMerge(Node* consumer, Node* producer, AliasDb& aliasDb) {
// Alias checks
REQ(aliasDb.couldMoveAfterTopologically(consumer, producer));
// Ops that return aliases can only be folded if this is the only use.
if (producer->kind() == aten::slice || producer->kind() == aten::unsqueeze ||
producer->kind() == prim::ConstantChunk) {
for (auto& use : producer->output(0)->uses()) {
REQ(use.user == consumer);
}
}
if (!consumer->hasAttribute(attr::Subgraph) &&
consumer->kind() != getTensorExprSymbol()) {
// Don't initiate a fusion group from prim::ListConstruct
REQ(consumer->kind() != prim::ListConstruct);
REQ(consumer->kind() != aten::slice);
REQ(consumer->kind() != aten::unsqueeze);
REQ(consumer->kind() != prim::ConstantChunk);
// Don't initiate a fusion group just for a constant operand
REQ(producer->kind() != prim::Constant);
}
if (producer->kind() == aten::cat) {
REQ(producer->inputs()[0]->node()->kind() == prim::ListConstruct);
REQ(producer->inputs()[0]->uses().size() == 1);
REQ(producer->inputs()[1]->node()->kind() == prim::Constant);
} else if (consumer->kind() == aten::cat) {
REQ(consumer->inputs()[0]->node()->kind() == prim::ListConstruct);
REQ(consumer->inputs()[0]->uses().size() == 1);
REQ(consumer->inputs()[1]->node()->kind() == prim::Constant);
}
return true;
}
#undef REQ
@ -73,7 +188,10 @@ Node* getOrCreateTensorExprSubgraph(Node* n) {
if (n->hasAttribute(attr::Subgraph) && n->kind() == getTensorExprSymbol()) {
return n;
}
return SubgraphUtils::createSingletonSubgraph(n, getTensorExprSymbol());
auto te_group =
SubgraphUtils::createSingletonSubgraph(n, getTensorExprSymbol());
GRAPH_UPDATE("getOrCreateTensorExprSubgraph: ", *te_group);
return te_group;
}
c10::optional<Node*> tryMerge(
@ -82,9 +200,9 @@ c10::optional<Node*> tryMerge(
AliasDb& aliasDb) {
GRAPH_DEBUG(
"Trying producer ",
producer->kind().toQualString(),
getHeader(producer),
" and consumer ",
consumer->kind().toQualString(),
getHeader(consumer),
":\n");
if (!canMerge(consumer, producer, aliasDb)) {
@ -93,8 +211,24 @@ c10::optional<Node*> tryMerge(
consumer = getOrCreateTensorExprSubgraph(consumer);
aliasDb.moveAfterTopologicallyValid(consumer, producer);
SubgraphUtils::mergeNodeIntoSubgraph(producer, consumer);
if (producer->kind() == aten::cat) {
Node* listconstruct = producer->inputs()[0]->node();
aliasDb.moveAfterTopologicallyValid(consumer, producer);
GRAPH_UPDATE(
"Merging ", getHeader(producer), " into ", getHeader(consumer));
SubgraphUtils::mergeNodeIntoSubgraph(producer, consumer);
aliasDb.moveAfterTopologicallyValid(consumer, listconstruct);
GRAPH_UPDATE(
"Merging ", getHeader(listconstruct), " into ", getHeader(consumer));
SubgraphUtils::mergeNodeIntoSubgraph(listconstruct, consumer);
} else {
aliasDb.moveAfterTopologicallyValid(consumer, producer);
GRAPH_UPDATE(
"Merging ", getHeader(producer), " into ", getHeader(consumer));
SubgraphUtils::mergeNodeIntoSubgraph(producer, consumer);
}
return consumer;
}
@ -120,26 +254,10 @@ std::pair<graph_node_list::iterator, bool> scanNode(
return {++(++iter), false};
}
Operation createTensorExprOp(const Node* node) {
// TODO: actually compile the fusion group.
return [](Stack& stack) {
RECORD_FUNCTION("TensorExpr", std::vector<c10::IValue>());
return 0;
};
}
c10::AliasAnalysisKind getAliasAnalysisOption(AliasAnalysisKind k) {
return k;
}
RegisterOperators TensorExprOps({
torch::jit::Operator(
getTensorExprSymbol(),
createTensorExprOp,
getAliasAnalysisOption(AliasAnalysisKind::PURE_FUNCTION)),
});
void fuseTensorExprs(std::shared_ptr<Graph>& graph) {
if (!texpr_fuser_enabled) {
return;
}
GRAPH_DUMP("Before TExprFuser: ", graph);
// Get rid of dead code so that we don't waste effort fusing it.
@ -192,6 +310,23 @@ void fuseTensorExprs(std::shared_ptr<Graph>& graph) {
GRAPH_DUMP("After TExprFuser: ", graph);
}
Operation createTensorExprOp(const Node* node) {
auto kernel =
std::make_shared<tensorexpr::TensorExprKernel>(*node->g(attr::Subgraph));
return [kernel](Stack& stack) {
RECORD_FUNCTION("TensorExpr", std::vector<c10::IValue>());
kernel->run(stack);
return 0;
};
}
RegisterOperators TensorExprOps({
torch::jit::Operator(
getTensorExprSymbol(),
createTensorExprOp,
AliasAnalysisKind::PURE_FUNCTION),
});
void registerTensorExprFuser() {
static bool already_registered = false;
if (!already_registered) {

View File

@ -14,5 +14,7 @@ TORCH_API void fuseTensorExprs(std::shared_ptr<Graph>& graph);
// Register TensorExpressions-based fuser in custom passes.
TORCH_API void registerTensorExprFuser();
TORCH_API void setTensorExprFuserEnabled(bool val);
} // namespace jit
} // namespace torch

View File

@ -59,6 +59,7 @@
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/python/python_tree_views.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/jit/tensorexpr/execution_counter.h>
#include <c10/macros/Export.h>
#include <caffe2/serialize/inline_container.h>
@ -405,6 +406,15 @@ void initJITBindings(PyObject* module) {
}
return nullptr;
})
.def(
"_jit_get_trigger_value",
[](const std::string& trigger_name) {
using namespace torch::jit::tensorexpr;
ExecutionTrigger* trigger =
ExecutionTriggerList::GetInstance().FindByName(trigger_name);
return trigger->value();
})
.def("_jit_set_texpr_fuser_enabled", &setTensorExprFuserEnabled)
.def(
"_jit_fuser_get_fused_kernel_code",
[](Graph& g, std::vector<at::Tensor> inps) {

View File

@ -4,7 +4,9 @@ namespace torch {
namespace jit {
namespace tensorexpr {
RegisterCodeGen<SimpleIREvaluator> reg("simple_ir_eval");
DEFINE_TRIGGER(simple_ir_eval_executed);
RegisterCodeGen<SimpleIREvaluator> ir_eval_codegen_reg("simple_ir_eval");
} // namespace tensorexpr
} // namespace jit

View File

@ -7,6 +7,7 @@
#include <c10/util/Logging.h>
#include "torch/csrc/jit/tensorexpr/buffer.h"
#include "torch/csrc/jit/tensorexpr/codegen.h"
#include "torch/csrc/jit/tensorexpr/execution_counter.h"
#include "torch/csrc/jit/tensorexpr/function.h"
#include "torch/csrc/jit/tensorexpr/ir.h"
#include "torch/csrc/jit/tensorexpr/ir_printer.h"
@ -17,6 +18,8 @@ namespace torch {
namespace jit {
namespace tensorexpr {
DECLARE_TRIGGER(simple_ir_eval_executed);
class Value {
public:
Value() : dtype_(kInt) {
@ -106,6 +109,7 @@ class SimpleIREvaluator : public CodeGen, public IRVisitor {
eval_context_.clear();
buffer_mapping_.clear();
internal_buffers_.clear();
USE_TRIGGER(simple_ir_eval_executed);
}
void bind(const BufferArg& buf, const CallArg& data) {

View File

@ -0,0 +1,118 @@
#pragma once
#include "torch/csrc/WindowsTorchApiMacro.h"
#include <string>
#include <unordered_map>
namespace torch {
namespace jit {
namespace tensorexpr {
/*
ExecutionTrigger and ExecutionCounter builds instrumentation counters so
underlying functionalities can be checked.
In the code to be instrumented:
// worker.cpp
DEFINE_TRIGGER(useful_work_done); // this defines a trigger "useful_work_done"
void run() {
USE_TRIGGER(useful_work_done); // this triggers the underlying counter
// in "useful_work_done"
}
// in C++ client.cpp
DECLARE_TRIGGER(useful_work_done); // Optional: this declares a trigger that
// will be defined elsewhere
ExecutionCounter counter(useful_work_done); // This starts the counter from the
// underlying trigger.
... call run() ...
counter.elapsed_value(); // this returns the incremented value from the
// trigger since the creation of the counter
// in Python client.py
counter = ExecutionCounter("useful_work_done") // this starts the counter from
// the underlying trigger
... call C++ run() ...
counter.elapsed_value() // This returns the incremented value from the
// trigger since the creation of the counter.
*/
class ExecutionTrigger;
class ExecutionTriggerList {
public:
TORCH_API static ExecutionTriggerList& GetInstance() {
static ExecutionTriggerList instance;
return instance;
}
ExecutionTrigger* FindByName(const std::string& name) const {
auto iter = trigger_list_.find(name);
if (iter == trigger_list_.end()) {
throw std::runtime_error("Invalid trigger name: " + name);
}
return iter->second;
}
private:
friend class ExecutionTrigger;
ExecutionTriggerList() {}
ExecutionTriggerList(const ExecutionTriggerList&) = delete;
ExecutionTriggerList& operator=(const ExecutionTriggerList&) = delete;
void AddTrigger(const std::string& name, ExecutionTrigger* trigger) {
auto insert_ret = trigger_list_.insert(std::make_pair(name, trigger));
if (!insert_ret.second) {
throw std::runtime_error("Duplicated trigger name: " + name);
}
}
std::unordered_map<std::string, ExecutionTrigger*> trigger_list_;
};
class ExecutionTrigger {
public:
explicit ExecutionTrigger(const std::string& name) : name_(name) {
ExecutionTriggerList::GetInstance().AddTrigger(name, this);
}
int value() const {
return value_;
}
void trigger() {
value_++;
}
private:
ExecutionTrigger(const ExecutionTrigger&) = delete;
ExecutionTrigger& operator=(const ExecutionTrigger&) = delete;
int value_ = 0;
const std::string name_;
};
class ExecutionCounter {
public:
explicit ExecutionCounter(ExecutionTrigger& trigger) : trigger_(trigger) {
start_value_ = trigger_.value();
}
int elapsed_value() const {
return trigger_.value() - start_value_;
}
private:
ExecutionTrigger& trigger_;
int start_value_ = 0;
};
#define DEFINE_TRIGGER(name) ExecutionTrigger name(#name)
#define DECLARE_TRIGGER(name) TORCH_API extern ExecutionTrigger name
#define USE_TRIGGER(name) (name).trigger()
} // namespace tensorexpr
} // namespace jit
} // namespace torch

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,210 @@
#pragma once
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/tensorexpr/codegen.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
namespace torch {
namespace jit {
namespace tensorexpr {
template <typename T>
inline std::vector<int64_t> bufferSizes(const T& t) {
std::vector<int64_t> sizes;
for (int i = 0; i < t->function()->ndim(); i++) {
sizes.push_back(
dynamic_cast<const IntImm*>(t->function()->dim(i))->value());
}
return sizes;
}
template <typename T>
inline std::vector<ExprHandle> computeIndicesToBroadcast(
const std::vector<T>& output_axes,
const std::vector<ExprHandle>& input_sizes) {
TORCH_CHECK(
output_axes.size() >= input_sizes.size(),
"Cannot broadcast to a lower rank tensor");
std::vector<ExprHandle> bcast;
auto axis_it = output_axes.rbegin();
auto size_it = input_sizes.rbegin();
while (size_it != input_sizes.rend()) {
auto const& size = size_it->AsNode<IntImm>();
if (size && size->value() == 1) {
bcast.push_back(0);
} else {
bcast.push_back(*axis_it);
}
++axis_it;
++size_it;
}
std::reverse(bcast.begin(), bcast.end());
return bcast;
}
class TensorExprKernel {
public:
explicit TensorExprKernel(const Graph& subgraph);
void run(Stack& stack);
private:
enum BackendType {
kUninitialized,
kSimpleIREval,
};
ExprHandle constant(const torch::jit::Value* v);
template <typename T, typename T1>
ExprHandle broadcast(const T& t, const std::vector<T1>& axes) {
return t->call(computeIndicesToBroadcast(
axes, ExprVectorToExprHandleVector(t->function()->dims())));
}
template <typename T, typename T1>
ExprHandle chunk(
const T& t,
size_t chunk_idx,
size_t dim,
size_t chunks,
const std::vector<T1>& axes) {
auto sizes = bufferSizes(t);
size_t step = sizes[dim] / chunks;
std::vector<ExprHandle> indices;
for (size_t i = 0; i < axes.size(); ++i) {
if (i == dim) {
indices.push_back(axes[i] + IntImm::make(chunk_idx * step));
} else {
indices.push_back(axes[i]);
}
}
return t->call(indices);
}
std::vector<ExprHandle> valueShape(const torch::jit::Value* v);
void promoteInputs(std::vector<ExprHandle>& inputs);
ExprHandle demoteOutput(const ExprHandle& e, const torch::jit::Value* v);
template <typename T>
ExprHandle tensorOrConstant(
const torch::jit::Value* v,
const std::vector<T>& axes) {
auto ti = tensors_.find(v->unique());
if (ti != tensors_.end()) {
return broadcast(ti->second, axes);
}
return constant(v);
}
Tensor* ComputeOneOperand(
const std::string& name,
const torch::jit::Value* v,
const std::function<ExprHandle(const ExprHandle&)>& inner_expr);
Tensor* ComputeTwoOperand(
const std::string& name,
const torch::jit::Value* v,
const std::function<ExprHandle(const ExprHandle&, const ExprHandle&)>&
inner_expr);
Tensor* ComputeTwoOperandWithAlpha(
const std::string& name,
const torch::jit::Value* v,
const std::function<ExprHandle(const ExprHandle&, const ExprHandle&)>&
inner_expr);
Tensor* ComputeThreeOperand(
const std::string& name,
const torch::jit::Value* v,
const std::function<
ExprHandle(const ExprHandle&, const ExprHandle&, const ExprHandle&)>&
inner_expr);
Tensor* ComputeConditionWithTwoOperand(
const std::string& name,
const torch::jit::Value* v,
const std::function<
ExprHandle(const ExprHandle&, const ExprHandle&, const ExprHandle&)>&
inner_expr);
Tensor* ComputeFourOperand(
const std::string& name,
const torch::jit::Value* v,
const std::function<ExprHandle(
const ExprHandle&,
const ExprHandle&,
const ExprHandle&,
const ExprHandle&)>& inner_expr);
Tensor* ComputeValue(const torch::jit::Value* v);
void LowerToBackend(BackendType backend_type);
void PickAndCheckBackendType(const at::ArrayRef<IValue>& inputs);
void CodeGenRun(const std::vector<CodeGen::CallArg>& run_args);
void bindInput(const torch::jit::Value* input);
ExprHandle createInputIndexExpr(
const Buffer& buffer,
const std::vector<VarHandle>& axes,
const c10::VaryingShape& sizes,
const c10::VaryingStrides& strides,
const c10::VaryingStrides& contiguity,
const std::unordered_map<int64_t, VarHandle>& sizeVars);
private:
struct ShapeArg {
size_t idx;
VarHandle var;
ShapeArg(size_t i, VarHandle v) : idx(i), var(v) {}
};
struct KernelArg {
template <typename B>
KernelArg(B&& b) : bufferArg_(std::forward<B>(b)) {}
template <typename B, typename T>
KernelArg(B&& b, T&& sizes, T&& strides)
: bufferArg_(b),
sizeArgs_(std::forward<T>(sizes)),
strideArgs_(std::forward<T>(strides)) {}
const CodeGen::BufferArg& buffer() const {
return bufferArg_;
}
const std::vector<ShapeArg>& sizes() const {
return sizeArgs_;
}
const std::vector<ShapeArg>& strides() const {
return strideArgs_;
}
CodeGen::BufferArg bufferArg_;
std::vector<ShapeArg> sizeArgs_;
std::vector<ShapeArg> strideArgs_;
};
int64_t n_inputs_ = 0;
std::vector<KernelArg> kernelArgs_;
std::vector<Tensor*> tensor_outputs_;
std::unordered_map<int64_t, Tensor*> tensors_;
std::unordered_map<int64_t, VarHandle> scalars_;
std::unique_ptr<CodeGen> codegen_;
KernelArena kernel_arena_;
BackendType backend_type_ = BackendType::kUninitialized;
at::Device device_ = at::kCPU;
};
} // namespace tensorexpr
} // namespace jit
} // namespace torch