#include <torch/csrc/jit/tensorexpr/kernel.h>

#include <ATen/ExpandUtils.h>
#include <ATen/Parallel.h>
#include <ATen/TensorGeometry.h>
#include <c10/core/ScalarTypeToTypeMeta.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
#include <torch/csrc/jit/passes/mkldnn_rewrite.h>
#include <torch/csrc/jit/passes/symbolic_shape_runtime_fusion.h>
#include <torch/csrc/jit/tensorexpr/analysis.h>
#include <torch/csrc/jit/tensorexpr/expr.h>
#include <torch/csrc/jit/tensorexpr/graph_opt.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/loopnest_randomization.h>
#include <torch/csrc/jit/tensorexpr/operators/operators.h>

using namespace torch::jit;
using namespace torch::jit::tensorexpr;

namespace torch::jit::tensorexpr {
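// Appends a generic hint about how to disable the fuser
// (torch.jit.enable_fusion(False)) to the given error message.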
std::string buildErrorMessage(const std::string& s) {
  static const std::string generic_error_message =
      "This error occurred in the fuser. You can turn off the fuser with "
      "torch.jit.enable_fusion(False).";
  if (s.empty()) {
    return generic_error_message;
  }
  if (s.back() == '.') {
    return s + " " + generic_error_message;
  }
  return s + ". " + generic_error_message;
}

static int te_cuda_pointwise_loop_levels = -1;
static int te_cuda_pointwise_block_count = -1;
static int te_cuda_pointwise_block_size = -1;
static bool fallback_allowed = false;
static bool te_generate_block_code = false;
static bool te_must_use_llvm_on_cpu = true;
static bool cat_wo_conditionals = true; // NOLINT
static bool opt_conditionals = false; // NOLINT

bool setFallbackAllowed(bool value) {
  bool old_value = fallback_allowed;
  fallback_allowed = value;
  return old_value;
}
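// When PYTORCH_TENSOREXPR_FALLBACK is unset, the programmatic setting
// (setFallbackAllowed) is used; "0" disables fallback, any other value
// enables it.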
bool fallbackAllowed() {
  static const char* enable_c_str = std::getenv("PYTORCH_TENSOREXPR_FALLBACK");
  if (!enable_c_str) {
    return fallback_allowed;
  }
  if (std::string(enable_c_str) == "0") {
    return false;
  }
  return true;
}
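// Fallback is *enforced* (the fused kernel is never used) only when
// PYTORCH_TENSOREXPR_FALLBACK=2 and Block codegen is not in use; when the
// variable is unset the programmatic setting applies.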
static bool fallbackEnforced() {
  static const char* enable_c_str = std::getenv("PYTORCH_TENSOREXPR_FALLBACK");
  if (tensorexpr::getTEGenerateBlockCode()) {
    return false;
  }
  if (!enable_c_str) {
    return fallback_allowed;
  }
  if (std::string(enable_c_str) == "2") {
    return true;
  }
  return false;
}
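// Returns the seed for random loop-nest transformations taken from
// PYTORCH_TENSOREXPR_RANDOM_TRANSFORM_SEED, or 0 if the variable is unset.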
static int64_t randomTransformsRequested() {
  const char* enable_c_str =
      std::getenv("PYTORCH_TENSOREXPR_RANDOM_TRANSFORM_SEED");
  if (!enable_c_str) {
    return 0;
  }
  return std::stoi(std::string(enable_c_str));
}

#ifdef TORCH_ENABLE_LLVM
static bool dontUseLLVMFlag() {
  static const char* enable_c_str =
      std::getenv("PYTORCH_TENSOREXPR_DONT_USE_LLVM");
  if (!enable_c_str) {
    return false;
  }
  return std::string(enable_c_str) == "1";
}
#endif

int& getTECudaPointwiseLoopLevels() {
  return te_cuda_pointwise_loop_levels;
}

int& getTECudaPointwiseBlockCount() {
  return te_cuda_pointwise_block_count;
}

int& getTECudaPointwiseBlockSize() {
  return te_cuda_pointwise_block_size;
}

// TODO: Remove this global var
// Ideally Block code gen should be decided
// based on device type in tensor.
bool& getTEGenerateBlockCode() {
  return te_generate_block_code;
}

bool& getTEMustUseLLVMOnCPU() {
  return te_must_use_llvm_on_cpu;
}

bool& getCatWoConditionals() {
  return cat_wo_conditionals;
}

bool& getOptConditionals() {
  return opt_conditionals;
}
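// Returns the device shared by all tensor inputs, or nullopt if the inputs
// disagree or carry no device information.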
std::optional<at::Device> pickDeviceType(
    const at::ArrayRef<torch::jit::Value*>& inputs) {
  std::optional<at::Device> device = c10::nullopt;
  for (auto const& input : inputs) {
    auto tt = input->type()->cast<TensorType>();
    if (tt && tt->device()) {
      if (device && *device != *tt->device()) {
        return c10::nullopt;
      }
      device = *tt->device();
    }
  }
  return device;
}

static std::optional<at::Device> pickDeviceType(
    const std::shared_ptr<Graph>& graph) {
  std::optional<at::Device> device = c10::nullopt;
  for (auto const& node : graph->nodes()) {
    for (auto const& input : node->inputs()) {
      if (auto tt = input->type()->cast<TensorType>()) {
        if (auto inputDevice = tt->device()) {
          TORCH_INTERNAL_ASSERT(
              !device || *device == *inputDevice,
              buildErrorMessage(
                  "Different devices specified for inputs to the fuser."));
          device = inputDevice;
        }
      }
    }
  }
  for (auto const& input : graph->inputs()) {
    if (auto tt = input->type()->cast<TensorType>()) {
      if (auto inputDevice = tt->device()) {
        TORCH_INTERNAL_ASSERT(
            !device || *device == *inputDevice,
            buildErrorMessage(
                "Different devices specified for inputs to the fuser."));
        device = inputDevice;
      }
    }
  }
  if (!device) {
    // By default assume the device is CPU
    device = at::kCPU;
  }
  return device;
}

// If v is a Tensor with concretely-known sizes and dtype, return them, else
// nullopt.
static std::optional<TensorInfo> getTensorInfoJit(torch::jit::Value* v) {
  auto const& it = v->type()->cast<TensorType>();

  c10::ScalarType dtype = c10::ScalarType::Float;

  if (!it) {
    return c10::nullopt;
  }
  if (!it->isComplete()) {
    return c10::nullopt;
  }
  if (it->scalarType()) {
    // TODO: ideally we should be strict here and return nullopt if the dtype is
    // absent in the JIT IR. We're assuming a default Float dtype for now, until
    // dtype propagation is implemented.
    dtype = *it->scalarType();
  }
  auto concrete_sizes = it->sizes().concrete_sizes();
  if (!concrete_sizes) {
    return c10::nullopt;
  }
  return TensorInfo{*concrete_sizes, dtype};
}
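// Expands a scalar IValue into a {v, v} pair; int lists are returned as-is.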
static std::vector<int64_t> _pair_int(IValue v) {
  if (v.isIntList()) {
    return v.toIntVector();
  } else {
    return {v.toInt(), v.toInt()};
  }
}

bool isContiguous(const torch::jit::Value* v, at::MemoryFormat memory_format) {
  auto const& tt = v->type()->cast<TensorType>();
  if (!tt) {
    return false;
  }
  if (!tt->isComplete()) {
    return false;
  }
  auto const& sizes = tt->sizes().concrete_sizes();
  auto const& strides = tt->strides().concrete_sizes();
  if (!sizes || !strides) {
    return false;
  }

  // Check dimension size first
  int ndims = (*sizes).size();
  if ((memory_format == at::MemoryFormat::ChannelsLast && ndims != 4) ||
      (memory_format == at::MemoryFormat::ChannelsLast3d && ndims != 5)) {
    return false;
  }

  return *strides == TensorType::contiguousStridesOf(*sizes, memory_format);
}
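// Index of the `groups` argument in the node's input list; it differs
// between aten::conv2d and aten::_convolution.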
static size_t get_conv_groups_index(const torch::jit::Node* node) {
  switch (node->kind()) {
    case aten::conv2d:
      return 6;
    case aten::_convolution:
      return 8;
    default:
      TORCH_CHECK(
          false,
          "mkldnnPrepackedConvIsSupportedJit expects node kind to be conv2d or _convolution but got ",
          node->kind());
  }
}

// The fuser only supports conv2d with very specific properties:
// - Static shapes: 4-d input and filter, 1-d bias.
// - Constant strides/padding/dilation/groups
// - Equal padding and strides, dilation == 1.
// - Depthwise (groups == in_channels == out_channels)
// - 3x3 kernel
bool conv2dIsSupportedJit(const torch::jit::Node* node) {
  auto const& input = getTensorInfoJit(node->input(0));
  auto const& weight = getTensorInfoJit(node->input(1));
  auto const& bias = getTensorInfoJit(node->input(2));
  auto const& stride = toIValue(node->input(3));
  auto const& pad = toIValue(node->input(4));
  auto const& dilation = toIValue(node->input(5));
  size_t groups_index = get_conv_groups_index(node);
  auto const& groups = toIValue(node->input(groups_index));

  // Everything should be statically known.
  if (!input || !weight || !bias || !stride || !pad || !dilation || !groups) {
    GRAPH_DEBUG("some params aren't static");
    return false;
  }

  // All inputs should be contiguous so no transposition is required.
  if (!isContiguous(node->input(0)) || !isContiguous(node->input(1)) ||
      !isContiguous(node->input(2))) {
    GRAPH_DEBUG("conv2dIsSupported: some inputs are not contiguous");
    return false;
  }

  return conv2dIsSupported(
      *input,
      *weight,
      *bias,
      _pair_int(*stride),
      _pair_int(*pad),
      _pair_int(*dilation),
      groups->toInt());
}

bool mkldnnPrepackedConvIsSupportedJit(const torch::jit::Node* node) {
#if AT_MKLDNN_ENABLED()
  auto const& input = getTensorInfoJit(node->input(0));
  auto const& weight = getTensorInfoJit(node->input(1));
  auto const& stride = toIValue(node->input(3));
  auto const& pad = toIValue(node->input(4));
  auto const& dilation = toIValue(node->input(5));
  size_t groups_index = get_conv_groups_index(node);
  auto const& groups = toIValue(node->input(groups_index));

  // Everything should be statically known (bias could be NoneType =
  // prim::Constant()).
  if (!input || !weight || !stride || !pad || !dilation || !groups) {
    GRAPH_DEBUG("some params aren't static");
    return false;
  }

  // Weights and bias should be Constant when using mkldnn backend
  if (node->input(1)->node()->kind() != prim::Constant ||
      node->input(2)->node()->kind() != prim::Constant) {
    GRAPH_DEBUG(
        "mkldnnPrepackedConvIsSupported: weight or bias is not Constant");
    return false;
  }

  // Input and weight should be NHWC contiguous.
  if (!(isContiguous(node->input(0), at::MemoryFormat::ChannelsLast) &&
        isContiguous(node->input(1), at::MemoryFormat::ChannelsLast))) {
    GRAPH_DEBUG(
        "mkldnnPrepackedConvIsSupported: input or weight is not ChannelsLast contiguous");
    return false;
  }

  return mkldnnPrepackedConvIsSupported(
      *input,
      *weight,
      _pair_int(*stride),
      _pair_int(*pad),
      _pair_int(*dilation),
      groups->toInt());
#endif
  return false;
}
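// Returns true if an aten::_convolution node is a plain (non-transposed)
// 2-d convolution with statically known stride/padding/dilation.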
bool isConv2d(const Node* node) {
  if (node->kind() != aten::_convolution) {
    return false;
  }

  auto const& stride = toIValue(node->input(3));
  auto const& pad = toIValue(node->input(4));
  auto const& dilation = toIValue(node->input(5));
  auto const& transposed = toIValue(node->input(6));
  auto const& output_padding = toIValue(node->input(7));

  if (!stride || !pad || !dilation || !transposed || !output_padding) {
    GRAPH_DEBUG("some params aren't static");
    return false;
  }

  if (stride.value().toIntList().size() != 2 ||
      pad.value().toIntList().size() != 2 ||
      dilation.value().toIntList().size() != 2 ||
      output_padding.value().toIntList().size() != 2) {
    GRAPH_DEBUG("Conv not 2d");
    return false;
  }

  if (transposed.value().toBool()) {
    GRAPH_DEBUG("transposed Conv");
    return false;
  }
  return true;
}

// The fuser currently only supports matmul of 2D x 2D matrices
bool matmulIsSupported(const torch::jit::Node* node) {
  auto const& input0 = getTensorInfoJit(node->input(0));
  auto const& input1 = getTensorInfoJit(node->input(1));

  // Everything should be statically known.
  if (!input0 || !input1) {
    GRAPH_DEBUG("matmulIsSupported: Input shapes aren't static");
    return false;
  }

  // Proper ndim for tensor inputs.
  if (input0->dims.size() != 2 || input1->dims.size() != 2) {
    GRAPH_DEBUG("matmulIsSupported: Unsupported input sizes");
    return false;
  }

  // Inputs should be contiguous, or the TE will needlessly transpose them.
  if (!isContiguous(node->input(0)) || !isContiguous(node->input(1))) {
    GRAPH_DEBUG("matmulIsSupported: Input shapes are not contiguous");
    return false;
  }

  return true;
}

} // namespace torch::jit::tensorexpr
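// Scalar type stored in the buffer's dtype.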
static at::ScalarType tensorType(BufPtr b) {
  return static_cast<at::ScalarType>(b->dtype().scalar_type());
}

ExprHandle TensorExprKernel::constant(const torch::jit::Value* v) {
  if (v->node()->kind() == prim::Constant) {
    auto val = toIValue(v).value();
    if (val.isDouble()) {
      return DoubleImm::make(val.toDouble());
    } else if (val.isInt()) {
      return LongImm::make(val.toInt());
    } else if (val.isBool()) {
      return BoolImm::make(val.toBool());
    } else if (val.isNone()) {
      // This is just a placeholder so we don't throw. None-handling
      // is operator-specific and should be handled properly in
      // the operator-specific lowering code.
      return IntImm::make(0);
    } else {
      throw unsupported_dtype();
    }
  }

  if (!scalars_.count(v)) {
    throw malformed_input("no scalar in Constant");
  }

  return scalars_.at(v);
}
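// Converts a JIT Value into an ArgValue usable by NNC lowering functions:
// a scalar var, a buffer handle, a (possibly typed) list, or an immediate
// constant.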
ArgValue TensorExprKernel::toArg(const torch::jit::Value* v) const {
  auto vi = scalars_.find(v);
  if (vi != scalars_.end()) {
    return VarHandle(vi->second);
  }
  auto ti = bufs_.find(v);
  if (ti != bufs_.end()) {
    return BufHandle(ti->second);
  }
  if (v->node()->kind() == prim::ListConstruct) {
    std::vector<ArgValue> vec;
    for (auto el : v->node()->inputs()) {
      vec.push_back(toArg(el));
    }
    if (vec.empty()) {
      return BufList(); // Return arbitrarily typed vector
    } else if (std::get_if<BufHandle>(&vec[0])) {
      return convertVecArgValue<BufHandle>(vec);
    } else if (std::get_if<int64_t>(&vec[0])) {
      return convertVecArgValue<int64_t>(vec);
    }
    throw unsupported_dtype();
  }
  if (v->node()->kind() == prim::Constant) {
    auto val = toIValue(v).value();
    if (val.isDouble()) {
      return val.toDouble();
    } else if (val.isInt()) {
      return val.toInt();
    } else if (val.isBool()) {
      return val.toBool();
    } else if (val.isNone()) {
      // This is just a placeholder so we don't throw. None-handling
      // is operator-specific and should be handled properly in
      // the operator-specific lowering code.
      return ArgNone();
    } else if (val.isIntList()) {
      return val.toIntVector();
    } else if (val.isDoubleList()) {
      return val.toDoubleVector();
    } else if (val.isString()) {
      return val.toStringRef();
    } else {
      throw unsupported_dtype(val.type()->str());
    }
  }

  if (!scalars_.count(v)) {
    throw malformed_input("no scalar in Constant");
  }
  return scalars_.at(v);
}
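// Maps a symbolic shape dimension to a kLong variable, creating and caching
// a new one on first use; static dimensions become immediates.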
ExprHandle TensorExprKernel::getVarForShape(const c10::ShapeSymbol& ss) {
  if (ss.is_static()) {
    return LongImm::make(ss.static_size());
  }
  auto value = ss.value();
  auto it = shapeSymbolToVar_.find(value);
  if (it == shapeSymbolToVar_.end()) {
    VarHandle var("ss" + std::to_string(-value), kLong);
    shapeSymbolToVar_.emplace(value, var);
    return std::move(var);
  }
  return it->second;
}

std::vector<ExprHandle> TensorExprKernel::sizesFromSymbolicShape(
    const c10::SymbolicShape& shape) {
  std::vector<ExprHandle> dims;
  auto maybe_rank = shape.rank();
  TORCH_INTERNAL_ASSERT(maybe_rank);
  auto rank = *maybe_rank;
  for (const auto i : c10::irange(rank)) {
    dims.push_back(getVarForShape(shape[i]));
  }
  return dims;
}

std::vector<ExprHandle> TensorExprKernel::sizesForValue(
    const torch::jit::Value* v) {
  if (known_sizes_.count(v)) {
    return known_sizes_.at(v);
  }

  // If the shape is present in the type info, just extract it from here. No
  // need to infer it.
  if (v->type()->kind() == TypeKind::TensorType) {
    auto tt = v->type()->cast<TensorType>();
    return sizesFromSymbolicShape(tt->symbolic_sizes());
  }

  if (v->type()->isSubtypeOf(*FloatType::get()) ||
      v->type()->isSubtypeOf(*BoolType::get()) ||
      v->type()->isSubtypeOf(*IntType::get())) {
    return {};
  }
  if (v->type()->isSubtypeOf(*NoneType::get())) {
    return {};
  }
  GRAPH_DEBUG("Unknown sizes for the node: ", *v->node());
  GRAPH_DEBUG("Full fusion group graph:\n", *v->node()->owningGraph());
  std::string msg = std::string("Unhandled node kind (in sizesForValue): ") +
      v->node()->kind().toQualString();
  throw malformed_input(msg);
}

static std::optional<ScalarType> findDtypeForValue(const torch::jit::Value* v) {
  if (v->type()->kind() == TypeKind::TensorType) {
    auto tt = v->type()->cast<TensorType>();
    if (tt->scalarType()) {
      return static_cast<ScalarType>(*tt->scalarType());
    }
  }
  return tryScalarTypeFromJitType(*v->type());
}
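// If `v` is a constant 0-dim tensor, pushes its value as a scalar argument
// and returns true; otherwise leaves `args` untouched.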
static bool constZeroDimTensorAsScalarArg(
    const Value* v,
    std::vector<ArgValue>& args) {
  if (v->node()->kind() != prim::Constant || !v->type()->cast<TensorType>()) {
    return false;
  }

  const auto t = toIValue(v)->toTensor();
  if (!t.sizes().empty()) {
    return false;
  }

  c10::ScalarType dtype = c10::typeMetaToScalarType(t.dtype());
  switch (dtype) {
    case ScalarType::Float:
      args.emplace_back(t.item().toFloat());
      return true;
    case ScalarType::Long:
      args.emplace_back(t.item().toLong());
      return true;
    default:
      std::stringstream ss;
      ss << "Unsupported tensor dtype:" << dtype
         << " for converting constant 0-dim Tensor to scalar" << std::endl;
      throw unsupported_dtype(ss.str());
  }
}
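// Lowers a single JIT value to a Tensor (buf + stmt), dispatching to a custom
// lowering if one is registered for the op and to the standard lowering for
// the node's schema otherwise.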
Tensor TensorExprKernel::computeValue(const torch::jit::Value* v) {
  auto inputs = v->node()->inputs();
  auto op = v->node()->kind();

  if (op == aten::rand_like) {
    hasRandom_ = true;
  }

  auto outputType = findDtypeForValue(v);
  std::vector<ExprHandle> outputShape = sizesForValue(v);
  std::vector<ExprHandle> outputStrides = {};
  if (memory_layout_policy_ == MemoryLayoutPolicy::kChannelsLastNdContiguous) {
    outputStrides =
        c10::fmap<ExprHandle>(make_channels_last_strides(outputShape));
  } else {
    // Default
    outputStrides = c10::fmap<ExprHandle>(make_contiguous_strides(outputShape));
  }

  std::vector<ArgValue> argInputs;
  if (op == prim::ConstantChunk) {
    auto const& n = v->node();
    argInputs.emplace_back(toArg(inputs[0]));
    argInputs.emplace_back(static_cast<int64_t>(v->offset()));
    argInputs.emplace_back(n->i(attr::dim));
    argInputs.emplace_back(n->i(attr::chunks));
  } else if (op == aten::to) {
    argInputs.emplace_back(toArg(inputs[0]));
  } else if (op == aten::quantize_per_tensor) {
    argInputs.emplace_back(toArg(inputs[0]));
    if (!constZeroDimTensorAsScalarArg(inputs[1], argInputs)) {
      argInputs.emplace_back(toArg(inputs[1]));
    }
    if (!constZeroDimTensorAsScalarArg(inputs[2], argInputs)) {
      argInputs.emplace_back(toArg(inputs[2]));
    }
    argInputs.emplace_back(toArg(inputs[3]));
  } else if (op == aten::conv2d) {
    for (auto inp : inputs) {
      argInputs.emplace_back(toArg(inp));
    }
    // handle optional bias
    if (std::get_if<ArgNone>(&argInputs[2])) {
      Dtype dtype = outputType ? Dtype(*outputType) : kFloat;
      std::vector<ExprHandle> biasShape;
      biasShape.push_back(outputShape[1]);
      auto bias_tensor = at::zeros({outputShape[1].AsNode<LongImm>()->value()});
      unpacked_constant_tensors_.push_back(bias_tensor);
      BufPtr buf = alloc<Buf>(
          "conv2d_bias_opt_" + sanitizeName(v->debugName()),
          ExprHandleVectorToExprVector(biasShape),
          dtype);
      constants_.push_back({buf, bias_tensor.data_ptr()});
      argInputs[2] = BufHandle(buf);
    }
  } else {
    for (auto inp : inputs) {
      argInputs.emplace_back(toArg(inp));
    }
  }

  if (NNCLoweringFunction custom_lowering = getCustomLoweringFor(op)) {
    return custom_lowering(
        argInputs, outputShape, outputStrides, outputType, device_);
  }
  if (v->node()->maybeSchema()) {
    if (NNCLoweringFunction lowering =
            getStandardLoweringFor(c10::toString(v->node()->schema()))) {
      return lowering(
          argInputs, outputShape, outputStrides, outputType, device_);
    }
  }
  std::string msg = std::string("Unhandled node kind (in computeValue): ") +
      op.toQualString();
  if (v->node()->maybeSchema()) {
    msg += std::string("\nSchema: ") + c10::toString(v->node()->schema());
  }
  throw malformed_input(msg);
}
// True if all the loops in this vector have equal bounds.
static bool loopBoundsAllEqual(const std::vector<ForPtr>& loops) {
  if (loops.size() <= 1) {
    return true;
  }
  const auto& start = loops.front()->start();
  const auto& stop = loops.front()->stop();
  for (size_t i = 1; i < loops.size(); ++i) {
    const auto& curr_start = loops[i]->start();
    const auto& curr_stop = loops[i]->stop();
    if (!exprEquals(start, curr_start) || !exprEquals(stop, curr_stop)) {
      return false;
    }
  }
  return true;
}

// Recursively fuse all the loops with matching bounds in `st`. Stops fusing
// at any level containing non-loops or non-matching bounds. The restriction
// on matching bounds exists to avoid inserting conditionals on the loop
// indices where none would be needed, which would significantly complicate
// vectorization.
static void fuseAllLoops(StmtPtr st) {
  auto block = to<tensorexpr::Block>(st);
  if (block == nullptr) {
    return;
  }

  std::vector<std::vector<ForPtr>> all_outer_loops;
  std::vector<ForPtr> outer_loops;
  for (const auto& stmt : *block) {
    auto loop = to<For>(stmt);
    auto hasReduction = !NodeFinder<ReduceOp>::find(stmt).empty();
    if (!loop || hasReduction) {
      all_outer_loops.push_back(outer_loops);
      outer_loops.clear();
    } else {
      outer_loops.push_back(loop);
    }
  }
  all_outer_loops.push_back(outer_loops);

  for (const auto& outer_loops : all_outer_loops) {
    if (outer_loops.empty()) {
      continue;
    }

    if (!loopBoundsAllEqual(outer_loops)) {
      continue;
    }

    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    ForPtr fusedLoop;
    if (!LoopNest::fuseLoops(outer_loops, &fusedLoop)) {
      continue;
    }

    fuseAllLoops(fusedLoop->body());
  }
}

// Compute the trip count of a loop if it is a constant.
static std::optional<int64_t> tripCount(ForPtr loop) {
  auto tc = IRSimplifier::simplify(
      cast<int64_t>(ExprHandle(loop->stop()) - ExprHandle(loop->start())));
  if (auto val = to<LongImm>(tc.node())) {
    return val->value();
  }
  return c10::nullopt;
}

// Prune innermost loops until the total number of iterations satisfies a
// minimum grain size.
static void pruneByGrainSize(std::vector<ForPtr>& loops) {
  constexpr int64_t minGrainSize = 32768;
  int64_t grainSize = 1;
  for (int64_t i = loops.size(); i > 0; i--) {
    auto tc = tripCount(loops[i - 1]);
    if (!tc) {
      break;
    }
    grainSize *= *tc;
    if (grainSize < minGrainSize) {
      loops.pop_back();
    }
  }
}

// Retain enough outermost loops to fill the number of threads.
static void pruneByThreadCount(std::vector<ForPtr>& loops) {
  int64_t trips = 1;
  auto threads = at::get_num_threads();
  auto it = loops.begin();
  for (; it != loops.end(); it++) {
    if (trips >= threads) {
      break;
    }
    auto tc = tripCount(*it);
    if (!tc) {
      break;
    }
    trips *= *tc;
  }
  loops.erase(it, loops.end());
}

// Flatten and parallelize outer loops, subject to a minimum number of elements
// in the inner loop, and a maximum level of thread-level parallelism in the
// outer loops.
template <typename Bufs>
static void parallelizeOuterLoops(LoopNest& l, Bufs&& bufs) {
  for (auto const& buf : bufs) {
    auto loops = l.getLoopStmtsFor(buf);
    pruneByGrainSize(loops);
    pruneByThreadCount(loops);

    // There are no loops to parallelize; give up.
    if (loops.size() == 0) {
      continue;
    }
    // The loop nest contains a reduction; give up.
    auto reductions = NodeFinder<ReduceOp>::find(loops[0]);
    if (reductions.size() > 0) {
      continue;
    }
    // The loop nest has loop carried dependences; give up.
    if (LoopNest::hasLoopCarriedDependence(loops[0])) {
      continue;
    }
    // Try to flatten the outer loops and parallelize them if successful.
    ForPtr flattened = nullptr;
    if (loops.size() == 1) {
      flattened = loops[0];
    } else {
      LoopNest::flatten(loops, &flattened);
    }
    if (flattened) {
      flattened->set_parallel();
    }
  }
}
StmtPtr TensorExprKernel::transformLoops(BackendType backendType, StmtPtr st) {
  torch::jit::tensorexpr::LoopNest l(st, bufOutputs_);
  LoopNest::sanitizeNames(l.root_stmt());
  GRAPH_DEBUG("Original Stmt:\n", std::to_string(l.root_stmt()), "\n");
  int64_t random_tr_seed = randomTransformsRequested();
  if (random_tr_seed) {
    if (random_tr_seed == -1)
      random_tr_seed = std::time(nullptr);
    loopnestRandomization(random_tr_seed, l);
    GRAPH_DEBUG(
        "After random transform:\n", std::to_string(l.root_stmt()), "\n");
  }

  bool hasReduction = !NodeFinder<ReduceOp>::find(l.root_stmt()).empty();

  // For Block codegen we create a map of tensor dims before
  // inlining. Like GPU codegen we need to inline, but the order
  // in which this analysis is run matters.
  auto block_analysis = std::make_unique<CreateBufferMap>();
  if (backendType == kBlockCodeGen) {
    // Run Block analysis to get multi dim buffer info
    auto root_stmt = l.root_stmt();
    root_stmt->accept(block_analysis.get());
  }
  l.simplify();
  GRAPH_DEBUG("after simplify", *l.root_stmt());

  // Inlining output & intermediate buffers can duplicate computation.
  // Duplicating work can slow down the program if it's not ameliorated in some
  // way, but we've empirically found that:
  // - On CPU, LLVM's CSE does a good job as long as you horizontally fuse
  //   output loops.
  // - On GPU, there's enough compute to hide the extra work, and inlining
  //   avoids synchronizing between kernels.
  l.inlineIntermediateBufs(/*allow_duplicated_work=*/true);
  GRAPH_DEBUG("after inline", *l.root_stmt());

  // Optimizing conditionals needs to be performed after inlining because
  // inlining wouldn't work once the loops are split. Also, it has to be
  // performed before loop fusion because loop fusion introduces cases where
  // multiple conditionals are in the same loop and this optimization does not
  // handle such cases yet.
  if (getOptConditionals()) {
    l.optimizeConditionals();
    GRAPH_DEBUG("after optimizing conditionals: ", *l.root_stmt());
  }

  // Fuse loops "horizontally". This pass allows us to combine loops that
  // write to different output buffers, as long as they have the same bounds.
  if (backendType == kLLVMCodeGen) {
    fuseAllLoops(l.root_stmt());
    GRAPH_DEBUG("after fuse", *l.root_stmt());
    parallelizeOuterLoops(l, bufsToBeParallelized_);
    GRAPH_DEBUG("after parallelize", *l.root_stmt());
  }

  if (backendType == kCudaCodeGen) {
    for (const auto& buf : bufOutputs_) {
      std::vector<ForPtr> loops = l.getLoopStmtsFor(buf);
      if (loops.empty()) {
        // This happens when Buf is 0-dim
        continue;
      }
      ForPtr flattened = nullptr;
      LoopNest::flatten(loops, &flattened);
      assert(flattened);

      int loopLevels = getTECudaPointwiseLoopLevels();
      const int kDefaultLoopLevels = 2;
      loopLevels = (loopLevels > 0) ? loopLevels : kDefaultLoopLevels;
      int blockCount = getTECudaPointwiseBlockCount();
      int blockSize = getTECudaPointwiseBlockSize();

      if (loopLevels == 2) {
        // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
        ForPtr inner;
        const int kDefaultBlockSize = 512;
        if (blockSize < 0) {
          blockSize = kDefaultBlockSize;
        }
        LoopNest::splitWithMask(flattened, blockSize, &inner);
        flattened->set_gpu_block_index(0);
        inner->set_gpu_thread_index(0);
      } else if (loopLevels == 3) {
        // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
        ForPtr inner;
        // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
        ForPtr inner1;
        // TODO: change the number of microprocessors
        const int kDefaultBlockCount = 1280;
        const int kDefaultBlockSize = 256;
        blockCount = (blockCount > 0) ? blockCount : kDefaultBlockCount;
        blockSize = (blockSize > 0) ? blockSize : kDefaultBlockSize;
        LoopNest::splitWithMask(flattened, blockCount * blockSize, &inner);
        LoopNest::splitWithMask(inner, blockSize, &inner1);
        inner->set_gpu_block_index(0);
        inner1->set_gpu_thread_index(0);
      } else {
        throw std::runtime_error(
            "Invalid loop-level: " + std::to_string(loopLevels));
      }
    }
  }

  if (backendType == kBlockCodeGen) {
    for (const auto& buf : bufOutputs_) {
      const int default_fp16_blocksize = 16;
      const int default_uint8_blocksize = 32;
      int blockSize = default_fp16_blocksize;
      // We only handle looplevels == 2 for now
      if (buf->dtype().scalar_type() == ScalarType::Byte) {
        blockSize = default_uint8_blocksize;
      }
      std::vector<ForPtr> loops = l.getLoopStmtsFor(buf);
      TORCH_INTERNAL_ASSERT(
          !loops.empty(),
          buildErrorMessage(
              "No loops found for the buffer " + buf->name_hint() +
              " in the fuser."));
      ForPtr flattened = nullptr;
      LoopNest::flatten(loops, &flattened);
      assert(flattened);

      ForPtr inner = nullptr;
      LoopNest::splitWithMask(flattened, blockSize, &inner);
      flattened->set_gpu_block_index(0);
      inner->set_gpu_thread_index(0);
      flattened->set_buffer_map(block_analysis->getBufferMap());
    }
  }

  if (pre_alloc_) {
    auto interm_bufs = l.getIntermediateBufs();
    preAllocIntermediateBufs(interm_bufs);
  }

  l.prepareForCodegen();

  GRAPH_DEBUG("after prepareForCodegen", *l.root_stmt());
  l.simplify();
  GRAPH_DEBUG("after simplification", *l.root_stmt());

  if (backendType == kLLVMCodeGen && !hasReduction) {
    l.vectorizeInnerLoops();
    GRAPH_DEBUG("after vectorization", *l.root_stmt());
  }

  StmtPtr stmt = l.root_stmt();
  // Arithmetic Simplification.
  stmt = IRSimplifier::simplify(stmt);
  GRAPH_DEBUG("Final Stmt:\n", std::to_string(stmt), "\n");
  return stmt;
}

std::string TensorExprKernel::getCodeGenName(BackendType backendType) {
  switch (backendType) {
    case kCudaCodeGen:
      return "cuda_codegen";
    case kLLVMCodeGen:
      return "llvm_codegen";
    case kSimpleIREval:
      return "simple_ir_eval";
    case kBlockCodeGen:
      return "block_codegen";
    default:
      throw std::runtime_error(
          "invalid backend type: " +
          std::to_string(static_cast<int>(backendType)));
  }
}
template <typename T>
static bool isValidPrimProperty(const std::optional<T>& a, T b) {
  return !a.has_value() || *a == b;
}

TensorExprKernel::BackendType TensorExprKernel::inferBackendTypeFromDevice(
    at::Device device) {
  BackendType backendType = BackendType::kUninitialized;
  if (device.type() == at::kCUDA) {
    backendType = kCudaCodeGen;
  } else if (device.type() == at::kCPU && getTEGenerateBlockCode()) {
    backendType = kBlockCodeGen;
  } else if (device.type() == at::kCPU) {
#ifdef TORCH_ENABLE_LLVM
    backendType = dontUseLLVMFlag() ? kSimpleIREval : kLLVMCodeGen;
#else
    backendType = kSimpleIREval;
#endif
    if (getTEMustUseLLVMOnCPU() && backendType == kSimpleIREval) {
      throw std::runtime_error("LLVM Backend not found");
    }
  } else {
    throw std::runtime_error("Invalid device type");
  }
  return backendType;
}

// We use the debug names when printing CUDA code, so characters that are not
// valid in a variable identifier need to be stripped from them.
void TensorExprKernel::genInputDebugNames() {
  std::unordered_map<std::string, const torch::jit::Value*> name_to_value;
  std::unordered_set<std::string> name_set;
  std::unordered_map<const torch::jit::Value*, std::string> value_to_name;
  for (const torch::jit::Value* input : graph_->inputs()) {
    std::string sanitized_name = sanitizeName(input->debugName());
    // we could get fancier here, but name conflict is extremely unlikely
    while (name_set.count(sanitized_name)) {
      sanitized_name.append("_");
    }
    value_to_name[input] = sanitized_name;
    name_set.insert(sanitized_name);
  }
  input_name_map_ = std::move(value_to_name);
}

template <typename T>
static std::vector<ExprHandle> toExprHandles(const std::vector<T>& sizes) {
  std::vector<ExprHandle> dims;
  dims.reserve(sizes.size());
  for (auto const& size : sizes) {
    dims.emplace_back(size);
  }
  return dims;
}

ExprHandle TensorExprKernel::getStrideArg(
    size_t tensor_input_index,
    size_t stride_index) {
  auto it = strideArgToVar_.find(
      std::pair<size_t, size_t>(tensor_input_index, stride_index));
  if (it == strideArgToVar_.end()) {
    VarHandle var(
        "stride_arg" + std::to_string(tensor_input_index) + "_" +
            std::to_string(stride_index),
        kLong);
    strideArgToVar_[std::pair<size_t, size_t>(
        tensor_input_index, stride_index)] = var;
    return std::move(var);
  }
  return it->second;
}

std::vector<torch::jit::StrideInput>& TensorExprKernel::getSymbolicStrideDesc(
    const torch::jit::Value* value) {
  TORCH_INTERNAL_ASSERT(symbolic_strides_.count(value));
  return symbolic_strides_[value];
}
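// Builds the stride expressions for a tensor input, either from concrete
// strides in the type or from its symbolic stride descriptors.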
std::vector<ExprHandle> TensorExprKernel::getInputStrides(
    const torch::jit::Value* input,
    const std::vector<ExprHandle>& inputTensorDims) {
  std::vector<ExprHandle> inputTensorStrides;
  if (input->isCompleteTensor()) {
    auto const strides =
        input->type()->expect<TensorType>()->strides().concrete_sizes();
    std::vector<ExprHandle> inputTensorStrides;
    for (size_t stride : *strides) {
      inputTensorStrides.push_back(LongImm::make(stride));
    }
    return inputTensorStrides;
  }

  size_t rank = inputTensorDims.size();
  std::vector<StrideInput>& stride_input = getSymbolicStrideDesc(input);
  if (stride_input.size() == 1 &&
      (stride_input[0] == StrideInput::TENSOR_CONT_CHANNELS_LAST ||
       stride_input[0] == StrideInput::TENSOR_CONT)) {
    auto strides = stride_input[0] == StrideInput::TENSOR_CONT
        ? make_contiguous_strides(inputTensorDims)
        : make_channels_last_strides(inputTensorDims);
    return fmap(strides, [&](ExprPtr stride) { return ExprHandle(stride); });
  }

  inputTensorStrides.resize(rank);
  std::vector<bool> stride_set;
  for (size_t i = 0; i < rank; ++i) {
    stride_set.push_back(false);
  }
  // first, generate non-dependent values
  size_t generated_strides = 0;
  for (const auto i : c10::irange(rank)) {
    if (stride_input[i] == torch::jit::StrideInput::S_ONE) {
      inputTensorStrides[i] = LongImm::make(1);
      stride_set[i] = true;
      generated_strides++;
    } else if (stride_input[i] == torch::jit::StrideInput::S_AS_ARG) {
      size_t input_index = input->offset();
      inputTensorStrides[i] = getStrideArg(input_index, i);
      stride_set[i] = true;
      generated_strides++;
    }
  }
  // Contiguous and Transposed Contiguous depend on adjacent values
  while (generated_strides != rank) {
    for (int i = static_cast<int>(rank) - 1; i >= 0; i--) {
      if (stride_input[i] == torch::jit::StrideInput::S_CONT &&
          stride_set[i + 1]) {
        inputTensorStrides[i] =
            inputTensorStrides[i + 1] * inputTensorDims[i + 1];

        stride_set[i] = true;
        generated_strides++;
      }
    }
    for (int i = 0; i < static_cast<int>(rank); i++) {
      if (stride_input[i] == torch::jit::StrideInput::S_TRAN_CONT &&
          stride_set[i - 1]) {
        inputTensorStrides[i] =
            inputTensorStrides[i - 1] * inputTensorDims[i - 1];
        stride_set[i] = true;
        generated_strides++;
      }
    }
  }
  return inputTensorStrides;
}
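// Binds a graph input to a kernel argument: contiguous non-output tensors are
// bound directly, other tensors are copied into a contiguous buffer, and
// scalar inputs become loop variables.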
Tensor TensorExprKernel::bindInput(const torch::jit::Value* input) {
  auto const& t = input->type();
  auto const& outputs = input->owningGraph()->outputs();
  std::unordered_set<const Value*> outputs_set(outputs.begin(), outputs.end());

  auto is_concrete_cont = [](const torch::jit::Value* input,
                             const MemoryLayoutPolicy& mem_layout_policy) {
    if (input->isCompleteTensor()) {
      auto mem_layout = (mem_layout_policy == MemoryLayoutPolicy::kContiguous)
          ? at::MemoryFormat::Contiguous
          : at::MemoryFormat::ChannelsLast;
      return isContiguous(input, mem_layout);
    } else {
      return false;
    }
  };

  auto is_symbolic_cont = [](std::vector<torch::jit::StrideInput> desc,
                             const MemoryLayoutPolicy& mem_layout_policy) {
    if (desc.size() == 1) {
      auto mem_layout = (mem_layout_policy == MemoryLayoutPolicy::kContiguous)
          ? torch::jit::StrideInput::TENSOR_CONT
          : torch::jit::StrideInput::TENSOR_CONT_CHANNELS_LAST;
      return desc[0] == mem_layout;
    } else {
      return false;
    }
  };

  Tensor result(nullptr, nullptr);
  switch (t->kind()) {
    case TypeKind::TensorType: {
      auto tt = input->type()->cast<TensorType>();
      bool contiguous_concrete_tensor =
          is_concrete_cont(input, memory_layout_policy_);
      bool contiguous_symbolic_tensor = false;
      if (has_symbolic_shapes_) {
        auto desc = getSymbolicStrideDesc(input);
        contiguous_symbolic_tensor =
            is_symbolic_cont(desc, memory_layout_policy_);
      }

      // Get input size and strides
      auto size_handles = sizesFromSymbolicShape(tt->symbolic_sizes());
      auto inputTensorStrides = getInputStrides(input, size_handles);

      // We don't need to copy the input if:
      //  1) it is not an output AND
      //  2) it is contiguous
      bool contiguous =
          contiguous_concrete_tensor || contiguous_symbolic_tensor;
      if (!outputs_set.count(input) && contiguous) {
        BufHandle inBuffer(
            "t" + input_name_map_[input],
            sizesFromSymbolicShape(tt->symbolic_sizes()),
            inputTensorStrides,
            ToDtype(static_cast<ScalarType>(*tt->scalarType())));
        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
            inBuffer.node()->is_contiguous() ||
            inBuffer.node()->is_channels_last_1d_contiguous() ||
            inBuffer.node()->is_contiguous(at::MemoryFormat::ChannelsLast) ||
            inBuffer.node()->is_contiguous(at::MemoryFormat::ChannelsLast3d));
        bufs_.emplace(input, inBuffer.node());
        bufferArgs_.emplace_back(inBuffer);
        break;
      }

      // If the input isn't contiguous or is an output,
      // write the strided input into a contiguous buffer that is
      // then used in all further compute.
      ExprHandle flat_size = 1;
      for (size_t i = 0; i < size_handles.size(); ++i) {
        auto size = size_handles[i];
        if (size.AsNode<LongImm>() && immediateAs<int64_t>(size.node()) == 0) {
          flat_size = 0;
          break;
        }
        flat_size = flat_size + (size - 1) * inputTensorStrides[i];
      }
      flat_size = IRSimplifier::simplify(flat_size);
      BufHandle inBuffer(
          "t" + input_name_map_[input],
          {flat_size},
          ToDtype(static_cast<ScalarType>(*tt->scalarType())));

      result = Compute(
          "input" + std::to_string(bufs_.size() + 1),
          size_handles,
          [&](const std::vector<VarHandle>& axes) {
            ExprHandle idx = 0;
            for (size_t i = 0; i < axes.size(); i++) {
              idx = idx + axes[i] * inputTensorStrides[i];
            }
            return inBuffer.load(idx);
          });
      bufs_.emplace(input, result.buf());
      bufferArgs_.emplace_back(inBuffer);
      break;
    }
    case TypeKind::FloatType: {
      VarHandle v("v" + input_name_map_[input], kDouble);
      bufferArgs_.emplace_back(v);
      scalars_.emplace(input, v);
      break;
    }
    case TypeKind::BoolType: {
      VarHandle v("v" + input_name_map_[input], kBool);
      bufferArgs_.emplace_back(v);
      scalars_.emplace(input, v);
      break;
    }
    case TypeKind::IntType: {
      VarHandle v("v" + input_name_map_[input], kLong);
      bufferArgs_.emplace_back(v);
      scalars_.emplace(input, v);
      break;
    }
    default: {
      throw unsupported_dtype(t->repr_str());
      break;
    }
  }
  return result;
}

NNCLoweringFunction TensorExprKernel::getCustomLoweringFor(
    c10::Symbol op) const {
  if (custom_lowerings_.count(op))
    return custom_lowerings_.at(op);
  return nullptr;
}
template <typename T>
std::vector<size_t> reverse_sort_indices(const std::vector<T>& v) {
  // initialize original index locations
  std::vector<size_t> idx(v.size());
  iota(idx.begin(), idx.end(), 0);

  std::sort(idx.begin(), idx.end(), [&v](size_t i1, size_t i2) {
    return v[i1] > v[i2];
  });
  return idx;
}

static bool denseAndNonOverlapping(
    at::ArrayRef<int64_t> sizes,
    at::ArrayRef<int64_t> strides) {
  return (strides == at::infer_dense_strides(sizes, strides));
}

Tensor TensorExprKernel::convertSymbolicOutputToCorrectStrides(
    const std::vector<ExprHandle>& sizes,
    const std::vector<size_t>& sorted_stride_indices_descending,
    const std::vector<ExprPtr>& strides,
    BufPtr& buf) {
  // We need to convert the output tensor so that its values are laid out
  // such that, when viewed through the output strides, the values are correct.
  // A contiguous Tensor of size (2, 3) with values 0-5 is laid out as:
  //   [0] [1] [2] [3] [4] [5]
  // The same valued tensor with strides (1, 2) would be laid out like:
  //   [0] [3] [1] [4] [2] [5]
  // When we are doing the re-ordering of values into the output tensor,
  // we are iterating per-element of the input, and we are fixed
  // in indexing in to the output tensor at [i, j] = val.
  // `val` we want here is equal to the indices for the output
  // tensor that would have given the same position as the output.
  // The position is equal to the sum of stride[i] * index[i],
  // and we can calculate the equivalent indices in the
  // output tensor strides by iteratively computing the index of
  // the biggest stride:
  //   absolute = ...
  //   for stride in strides_from_largest_to_smallest:
  //       cur_idx = absolute // stride
  //       absolute = absolute % stride
  std::vector<ExprPtr> default_strides = make_contiguous_strides(sizes);
  auto zero = LongImm::make(0);
  return Compute(
      "output_1", sizes, [&](const std::vector<VarHandle>& axes_input) {
        std::vector<ExprHandle> axes(axes_input.begin(), axes_input.end());
        auto absolute_position = ExprHandle(immLike(axes[0], 0));
        for (size_t i = 0; i < axes.size(); ++i) {
          ExprHandle stride(default_strides[i]);
          ExprHandle axis = axes[i];
          absolute_position = absolute_position + (stride * axis);
        }
        std::vector<ExprHandle> new_axes(
            sorted_stride_indices_descending.size());
        for (size_t stride_index : sorted_stride_indices_descending) {
          const auto& stride = strides[stride_index];
          auto index = absolute_position / ExprHandle(stride);
          // XXX: in symbolic output ordering we do not have the arbitrary
          // ordering of strides that the usual output ordering has, just
          // channels last, so even in the presence of size == 1
          // we produce correct output here
          absolute_position = absolute_position % ExprHandle(stride);
          new_axes[stride_index] = index;
        }
        return BufHandle(buf).load(new_axes);
      });
}

Tensor TensorExprKernel::convertSymbolicOutputToCorrectStrides(
    torch::jit::Value* v) {
  const TensorTypePtr& tt = v->type()->expect<TensorType>();
  TORCH_INTERNAL_ASSERT(
      bufs_.count(v),
      buildErrorMessage(
          "Output tensor has no corresponding bufs in the fuser."));
  BufPtr buf = bufs_.at(v);
  TORCH_INTERNAL_ASSERT(buf != nullptr);
  TORCH_INTERNAL_ASSERT(tt != nullptr);
  TORCH_INTERNAL_ASSERT(tt->symbolic_sizes().rank() != c10::nullopt);

  auto stride_desc = getSymbolicStrideDesc(v);
  TORCH_INTERNAL_ASSERT(stride_desc.size() == 1);
  auto memory_format = (stride_desc[0] == torch::jit::StrideInput::TENSOR_CONT)
      ? at::MemoryFormat::Contiguous
      : at::MemoryFormat::ChannelsLast;
  // output is contiguous with specified memory format, no work to do
  if (buf->is_contiguous(memory_format)) {
    return Tensor(buf, nullptr);
  }

  TORCH_INTERNAL_ASSERT(
      stride_desc[0] == torch::jit::StrideInput::TENSOR_CONT_CHANNELS_LAST);
  auto sizes = sizesFromSymbolicShape(tt->symbolic_sizes());
  auto strides = make_channels_last_strides(sizes);
  // For a tensor with dimensions N C H W, the channels-last
  // memory layout is N H W C,
  // so the order largest to smallest will be N, H, W, C
  std::vector<size_t> sorted_stride_indices = {0, 2, 3, 1};
  auto zero = LongImm::make(0);
  std::vector<ExprPtr> default_strides = make_contiguous_strides(sizes);
  // See explanation in convertOutputToCorrectStrides
  return convertSymbolicOutputToCorrectStrides(
      sizes, sorted_stride_indices, strides, buf);
}

Tensor TensorExprKernel::convertStaticShapeOutputToCorrectStrides(
    torch::jit::Value* v) {
  const TensorTypePtr& tt = v->type()->expect<TensorType>();
  TORCH_INTERNAL_ASSERT(
      bufs_.count(v),
      buildErrorMessage(
          "Output tensor has no corresponding bufs in the fuser."));
  BufPtr buf = bufs_.at(v);

  // No shape info is present in the graph
  if (!tt->sizes().concrete_sizes()) {
    std::string msg =
        std::string("Shapes for output '%") + v->debugName() + "' are unknown";
    throw malformed_input(msg);
  }

  TORCH_INTERNAL_ASSERT(
      tt->sizes().concrete_sizes(),
      buildErrorMessage("Output shapes are unknown."));
  auto sizes = *tt->sizes().concrete_sizes();
  at::MemoryFormat memory_format =
      (memory_layout_policy_ == MemoryLayoutPolicy::kContiguous)
      ? c10::MemoryFormat::Contiguous
      : c10::MemoryFormat::ChannelsLast;
  std::vector<int64_t> default_strides =
      TensorType::contiguousStridesOf(sizes, memory_format);
  if (!tt->strides().concrete_sizes()) {
    return Tensor(buf, nullptr);
  }
  TORCH_INTERNAL_ASSERT(
      tt->strides().concrete_sizes(),
      buildErrorMessage("Output strides are unknown."));
  const std::vector<int64_t> strides = *tt->strides().concrete_sizes();
  // All Tensors in NNC are laid out in the default, contiguous layout.
  // If the output is also default contiguous we don't need to do anything.
  if (strides == default_strides) {
    return Tensor(buf, nullptr);
  }
  // If the tensor is not dense or overlaps, we have
  // no way of matching the profiled striding
  if (!denseAndNonOverlapping(sizes, strides)) {
    return Tensor(buf, nullptr);
  }

  auto dims = sizesForValue(v);
  auto zero = LongImm::make(0);
  std::vector<size_t> sorted_stride_indices = reverse_sort_indices(strides);

  // TODO: call into `convertOutputToCorrectStrides`. Currently this causes a
  // bug in IRSimplifier to occur. See explanation in
  // `convertOutputToCorrectStrides`
  return Compute(
      "output_1", dims, [&](const std::vector<VarHandle>& axes_input) {
        std::vector<ExprHandle> axes(axes_input.begin(), axes_input.end());
        auto absolute_position = ExprHandle(immLike(axes[0], 0));
        for (size_t i = 0; i < axes.size(); ++i) {
          absolute_position = absolute_position +
              (ExprHandle(immLike(axes[i], default_strides[i])) * axes[i]);
        }

        std::vector<ExprHandle> new_axes(sorted_stride_indices.size());
        for (size_t stride_index : sorted_stride_indices) {
          auto size = sizes[stride_index];
          auto index = zero;
          if (size != 1) {
            auto stride = strides[stride_index];
            index = absolute_position /
                ExprHandle(immLike(absolute_position, stride));
            absolute_position = absolute_position %
                ExprHandle(immLike(absolute_position, stride));
          }
          new_axes[stride_index] = index;
        }
        return BufHandle(buf).load(new_axes);
      });
}
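// Binds a constant JIT value: custom-class and tensor constants get a backing
// buffer; scalar constants are turned into immediates elsewhere.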
void TensorExprKernel::bindConstant(const torch::jit::Value* v) {
  auto val = toIValue(v).value();
  if (torch::isCustomClass(val)) {
    auto name_hint = "const_" + sanitizeName(v->debugName());
    auto dtype = Dtype(ScalarType::Float);
    std::vector<ExprPtr> dims;
    BufPtr buf = alloc<Buf>(name_hint, dims, dtype);
    auto dataPtr = val.toObjectRef().getSlot(0).toCapsule().get();
    // NOLINTNEXTLINE
    constants_.push_back({buf, dataPtr, const_cast<Node*>(v->node())});
    bufs_[v] = buf;
    return;
  }
  if (!v->type()->cast<TensorType>()) {
    // Only Tensor constants need to be bound, scalar constants will be turned
    // into immediates in TE IR
    return;
  }
  auto const_tensor = toIValue(v)->toTensor();
  auto scalar_type = c10::typeMetaToScalarType(const_tensor.options().dtype());
  auto sizes = const_tensor.sizes();
  std::vector<ExprHandle> te_sizes;
  te_sizes.reserve(sizes.size());
  for (auto s : sizes) {
    te_sizes.emplace_back(s);
  }
  BufPtr buf = alloc<Buf>(
      "const_" + sanitizeName(v->debugName()),
      ExprHandleVectorToExprVector(te_sizes),
      ToDtype(scalar_type));

  if (!const_tensor.is_contiguous()) {
    const_tensor = const_tensor.clone().contiguous();
    unpacked_constant_tensors_.push_back(const_tensor);
  }

  constants_.push_back({buf, const_tensor.data_ptr()});
  bufs_[v] = buf;
}
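// Pre-allocates intermediate buffers whose shapes are statically known and
// returns the buffers that still need to be allocated at run time.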
std::vector<BufPtr> TensorExprKernel::preAllocIntermediateBufs(
    const std::vector<BufPtr>& interm_bufs) {
  std::vector<BufPtr> remaining_interm_bufs;
  for (const auto& buf : interm_bufs) {
    // Check if buf shape is static and compute its size if static.
    bool is_static = true;
    size_t size =
        elementSize(buf->dtype().scalar_type()) * buf->dtype().lanes();
    for (auto& d : buf->dims()) {
      if (!d->isConstant()) {
        is_static = false;
        break;
      }
      size = size * (*intValue(d));
    }
    // Only allocate memory for static bufs.
    if (!is_static) {
      remaining_interm_bufs.push_back(buf);
      continue;
    }
    auto bp = (void*)malloc(size);
    if (!bp) {
      remaining_interm_bufs.push_back(buf);
      continue;
    }
    constants_.push_back({buf, bp});
  }
  return remaining_interm_bufs;
}
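// Binds all graph inputs (tensors, scalars, symbolic shape and stride
// arguments) and returns the block collecting their statements.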
BlockPtr TensorExprKernel::bindAllInputs() {
|
|
std::vector<CodeGen::BufferArg> symbolic_shape_args;
|
|
std::vector<CodeGen::BufferArg> symbolic_stride_args;
|
|
|
|
auto symbolic_shape_inputs_start_pos =
|
|
nInputs_ - symbolic_shape_inputs_.size();
|
|
if (has_symbolic_shapes_) {
|
|
// The graph is supposed to have input params that represent the symbolic
|
|
// dims at the end of the list of inputs. The number of such symbolic input
|
|
// params is defined by the size of the `symbolic_shape_inputs_` vector.
|
|
//
|
|
// TODO: Check if the tensors with symbolic shapes are contiguous.
|
|
TORCH_CHECK(
|
|
nInputs_ > static_cast<int64_t>(symbolic_shape_inputs_.size()),
|
|
"Symbolic dims not provided as inputs to the graph");
|
|
|
|
// First, process the symbolic input params and create a new variable for
|
|
// each of them.
|
|
// NOTE: This has to be done before processing the tensor inputs, because
|
|
// their symbolic sizes needs to be associated with these variables we
|
|
// create for the symbolic input params.
|
|
symbolic_shape_args.reserve(symbolic_shape_inputs_.size());
|
|
|
|
for (size_t i = symbolic_shape_inputs_start_pos;
|
|
i < static_cast<size_t>(nInputs_);
|
|
++i) {
|
|
auto input = graph_->inputs()[i];
|
|
if (input->type()->kind() != TypeKind::IntType) {
|
|
throw std::runtime_error(
|
|
"Expected integer type input to graph for symbolic dims.");
|
|
}
|
|
VarHandle v("v" + input_name_map_[input], kLong);
|
|
symbolic_shape_args.emplace_back(v);
|
|
scalars_.emplace(input, v);
|
|
shapeSymbolInputPos_[scalars_[input].node()] = i;
|
|
}
|
|
// For every shape symbol, store a map to the corresponding var.
|
|
for (size_t i = 0; i < symbolic_shape_inputs_.size(); ++i) {
|
|
shapeSymbolToVar_[symbolic_shape_inputs_[i]] =
|
|
scalars_[graph_->inputs()[symbolic_shape_inputs_start_pos + i]];
|
|
}
|
|
|
|
// Next, process symbolic input params and create an argument for symbolic
|
|
for (size_t i = 0; i < symbolic_shape_inputs_start_pos; ++i) {
|
|
auto input = graph_->inputs()[i];
|
|
auto tt = input->type()->cast<TensorType>();
|
|
if (!tt) {
|
|
continue;
|
|
}
|
|
auto symbolic_stride = getSymbolicStrideDesc(input);
|
|
for (size_t j = 0; j < symbolic_stride.size(); ++j) {
|
|
if (symbolic_stride[j] == torch::jit::StrideInput::S_AS_ARG) {
|
|
VarHandle v("v" + input_name_map_[input], kLong);
|
|
symbolic_stride_args.emplace_back(v);
|
|
strideArgToVar_[{i, j}] = v;
|
|
input_stride_args_.emplace_back(i, j);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Block to collect the Stmts corresponding to all tensors.
|
|
auto block = alloc<Block>(std::vector<StmtPtr>({}));
|
|
|
|
// Process the inputs before the symbolic input params.
|
|
for (const auto i : c10::irange(symbolic_shape_inputs_start_pos)) {
|
|
auto input = graph_->inputs()[i];
|
|
Tensor t = bindInput(input);
|
|
if (t.stmt()) {
|
|
block->append_stmt(t.stmt());
|
|
}
|
|
}
|
|
// Now, add all the variables corresponding to the symbolic input params.
|
|
bufferArgs_.insert(
|
|
bufferArgs_.end(),
|
|
symbolic_shape_args.begin(),
|
|
symbolic_shape_args.end());
|
|
|
|
// Now, add all the variables corresponding to symbolic stride inputs
|
|
bufferArgs_.insert(
|
|
bufferArgs_.end(),
|
|
symbolic_stride_args.begin(),
|
|
symbolic_stride_args.end());
|
|
|
|
return block;
|
|
}
|
|
|
|
void TensorExprKernel::deduceMemoryLayoutPolicy() {
  // If the tensor is channels-last contiguous, the preferred memory layout
  // propagation policy is to use channels-last. Otherwise, the preferred
  // policy is to use contiguous.
  auto _prefer_symbolic_mem =
      [](const torch::jit::Value* val,
         const std::vector<torch::jit::StrideInput>& stride_desc_vec) {
        TORCH_INTERNAL_ASSERT(!stride_desc_vec.empty());
        // Has symbolic stride information
        auto cur_stride_desc = stride_desc_vec[0];
        return (cur_stride_desc ==
                torch::jit::StrideInput::TENSOR_CONT_CHANNELS_LAST)
            ? MemoryLayoutPolicy::kChannelsLastNdContiguous
            : MemoryLayoutPolicy::kContiguous;
      };

  auto _prefer_static_mem = [](const torch::jit::Value* val) {
    // Complete shape and stride info must be present in the graph.
    TORCH_INTERNAL_ASSERT(
        val->isCompleteTensor(),
        buildErrorMessage(val->debugName() + " is not a complete tensor."));
    const auto& tt = val->type()->expect<TensorType>();
    const auto sizes = *tt->sizes().concrete_sizes();
    const auto strides = *tt->strides().concrete_sizes();
    return (c10::is_channels_last_strides_2d(sizes, strides))
        ? MemoryLayoutPolicy::kChannelsLastNdContiguous
        : MemoryLayoutPolicy::kContiguous;
  };

  // Collect the tensors among the graph inputs and outputs to
  // deduce the memory layout propagation policy.
  auto _is_tensor = [](const jit::Value* el) {
    return el->type()->kind() == TypeKind::TensorType;
  };
  std::vector<torch::jit::Value*> graph_io_tensors;
  std::copy_if(
      graph_->inputs().begin(),
      graph_->inputs().end(),
      std::back_inserter(graph_io_tensors),
      _is_tensor);
  std::copy_if(
      graph_->outputs().begin(),
      graph_->outputs().end(),
      std::back_inserter(graph_io_tensors),
      _is_tensor);
  // An empty range would trivially pass the "all channels-last" check below,
  // but in that case we prefer to keep the original memory layout propagation
  // policy. So we require the range to be non-empty.
  auto prefer_channels_last = (!graph_io_tensors.empty());
  for (auto el : graph_io_tensors) {
    auto is_complete = el->isCompleteTensor();
    auto is_symbolic = symbolic_strides_.count(el);

    auto preferred_mem_layout = is_complete
        ? _prefer_static_mem(el)
        : (is_symbolic ? _prefer_symbolic_mem(el, symbolic_strides_[el])
                       : MemoryLayoutPolicy::kContiguous);
    if (preferred_mem_layout != MemoryLayoutPolicy::kChannelsLastNdContiguous) {
      prefer_channels_last = false;
      break;
    }
  }

  // If the memory layout of all the inputs and outputs is channels-last
  // contiguous, the propagated memory layout should be channels-last.
  // Otherwise, the propagated memory layout stays contiguous, which matches
  // the current behavior.
  memory_layout_policy_ = prefer_channels_last
      ? MemoryLayoutPolicy::kChannelsLastNdContiguous
      : MemoryLayoutPolicy::kContiguous;
}

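// optimizeOwningGraph() applies the device-aware graph rewrites (conv +
// elementwise fusion, concat optimization) before lowering, and then remaps
// any symbolic stride descriptors that were recorded for output Values that
// the rewrites replaced.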
void TensorExprKernel::optimizeOwningGraph() {
  GRAPH_DUMP("TensorExprKernel graph (Before graph optimization):", graph_);

  // Graph manipulation may replace the output Values, so store the original
  // outputs to synchronize the symbolic stride information afterwards.
  auto _original_graph_outputs = graph_->outputs().vec();

  // Get the graph device information first. The graph optimization
  // might be device specific.
  device_ = *pickDeviceType(graph_);

  // Determine the propagated memory layout.
  deduceMemoryLayoutPolicy();

  // Fuse Conv with an elementwise op.
  graph_rewrite_helper::replaceConvolutionWithAtenConv(graph_);
  FuseConvWithEltwise(graph_);

  // Optimize the concatenation.
  OptimizeCat(graph_);

  // Synchronize the symbolic stride information.
  auto graph_outputs = graph_->outputs();
  TORCH_INTERNAL_ASSERT(graph_outputs.size() == _original_graph_outputs.size());
  for (int i : c10::irange(graph_outputs.size())) {
    auto el_orig = _original_graph_outputs.at(i);
    auto el_new = graph_outputs.at(i);
    if (symbolic_strides_.count(el_orig) && (el_orig != el_new)) {
      symbolic_strides_[el_new] = symbolic_strides_[el_orig];
      symbolic_strides_.erase(el_orig);
    }
  }

  GRAPH_DUMP("TensorExprKernel graph (After graph optimization):", graph_);
}

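// compile() drives the lowering pipeline: bind graph inputs to buffers, lower
// each node to a tensor expression, give every graph output a correctly
// strided buffer, run the loop-level transformations, and hand the resulting
// statement to the backend code generator.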
void TensorExprKernel::compile() {
  GRAPH_DUMP("TensorExprKernel graph:", graph_);

  has_symbolic_shapes_ = !symbolic_shape_inputs_.empty();
  nInputs_ = graph_->inputs().size();
  nOutputs_ = graph_->outputs().size();
  genInputDebugNames();

  // Bind inputs to buffers.
  auto block = bindAllInputs();

  // Bind nodes to tensor compute expressions.
  for (auto const& n : graph_->nodes()) {
    if (n->kind() == prim::ListConstruct) {
      continue;
    } else if (n->kind() == prim::Constant) {
      bindConstant(n->output());
      continue;
    } else {
      for (auto const& output : n->outputs()) {
        if (output->hasUses()) {
          Tensor t = computeValue(output);

          // If there are for-loops before an ExternalCall, e.g.
          //   stmt1: for:
          //   stmt2:   for:
          //   stmt3: ExternalCall
          // those for-loops would not be parallelized. So we mark the buf
          // args of the ExternalCall as to-be-parallelized to make sure the
          // preceding loops can still be parallelized.
          if (to<ExternalCall>(t.stmt())) {
            auto _external_call = to<ExternalCall>(t.stmt());
            for (const auto& _buf : _external_call->buf_args()) {
              bufsToBeParallelized_.insert(_buf);
            }
          }

          if (output->type()->cast<TensorType>()) {
            // Value is a tensor.
            if (t.buf()) {
              bufs_.emplace(output, t.buf());
            }
            block->append_stmt(t.stmt());
          } else {
            // Value is a scalar.
            //
            // We represent scalar computations in TE with a pair of
            // statements:
            //   Let val = <compute_expression>
            //   Store(buf_for_scalar[0], val)
            //
            // Subsequent computations will use val when they refer to the
            // given value, and the buffer will be used if we need to return
            // the computed value as an output of the kernel. If this is not
            // an output, the store will be removed later by DCE.
            //
            // NB: NNC's lowering functions return Tensor, which is a pair
            // <Buf, Stmt>, but here we also need Var. How can we obtain all
            // of Var, Buf, and Stmt?
            // We use the following trick: the lowering function creates the
            // Let-stmt and a "fake" buffer, whose only purpose is to hold
            // the Var. Then outside the lowering function (namely, right
            // here) we generate the store and the actual buffer.
            VarPtr v = t.buf()->base_handle();
            scalars_[output] = VarHandle(v);
            block->append_stmt(t.stmt());
            std::vector<ExprPtr> dims;
            BufHandle buf(
                "scalar_" + sanitizeName(output->debugName()), {}, v->dtype());
            StmtPtr store = Store::make(buf, {}, ExprHandle(v));
            block->append_stmt(store);
            bufs_.emplace(output, buf.node());
          }
        }
      }
    }
    if (hasRandom_ && hasBroadcast_) {
      throw std::runtime_error(
          "Cannot support broadcast and random within one kernel");
    }
  }

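  // The output-handling loop below has two paths: with symbolic shapes the
  // output strides are kept as StrideInput descriptors and resolved at run
  // time, while with static shapes the profiled sizes/strides are baked in
  // here. In both cases a statement copying the result into the "properly
  // strided" layout may be appended (e.g. an output profiled with transposed
  // strides is written out from NNC's contiguous intermediate buffer).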
  // Move output operands from `bufs_` to `bufOutputs_`
  for (auto i : c10::irange(graph_->outputs().size())) {
    auto& output = graph_->outputs().at(i);
    if (!bufs_.count(output)) {
      throw malformed_input("cannot find output Tensor");
    }
    if (!output->type()->cast<TensorType>()) {
      // Scalar outputs are represented as 0-dim buffers.
      bufOutputs_.insert(bufs_.at(output));
      bufsToBeParallelized_.insert(bufs_.at(output));
      bufferArgs_.emplace_back(BufHandle(bufs_.at(output)));
      tensorOutputTensorOptions_.emplace_back(
          c10::TensorOptions(tensorType(bufs_.at(output))).device(device_));
      tensorOutputSizes_.emplace_back();
      tensorOutputStrides_.emplace_back();
      isOutputScalar_.push_back(true);
      bufs_.erase(output);
      continue;
    }

    const auto& tt = output->type()->expect<TensorType>();
    if (has_symbolic_shapes_) {
      auto sizes = sizesFromSymbolicShape(tt->symbolic_sizes());
      tensorOutputSymbolicSizes_.push_back(sizes);
      TORCH_INTERNAL_ASSERT(symbolic_strides_.count(output));
      auto stride_desc_vec = symbolic_strides_[output];
      TORCH_INTERNAL_ASSERT(stride_desc_vec.size() == 1);
      auto stride_desc = stride_desc_vec[0];
      tensorOutputStrideDesc_.push_back(stride_desc);
      Tensor properly_strided_output =
          convertSymbolicOutputToCorrectStrides(output);
      if (properly_strided_output.stmt()) {
        block->append_stmt(properly_strided_output.stmt());
      }
      bufs_[output] = properly_strided_output.buf();
    } else {
      // The "strided" tensor will be incorrect if used in NNC,
      // since NNC views it as contiguous. Only convert it to the right
      // strides at the end of the kernel (if already contiguous it's a no-op)
      Tensor properly_strided_output =
          convertStaticShapeOutputToCorrectStrides(output);
      if (properly_strided_output.stmt()) {
        block->append_stmt(properly_strided_output.stmt());
      }
      // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
      bufs_[output] = properly_strided_output.buf();
      auto sizes = *tt->sizes().concrete_sizes();
      tensorOutputSizes_.push_back(sizes);
      auto strides = tt->strides().concrete_sizes();

      // If the tensor is not dense or overlaps, we have
      // no way of matching the profiled striding
      if (strides && denseAndNonOverlapping(sizes, *strides)) {
        tensorOutputStrides_.push_back(*strides);
      } else {
        tensorOutputStrides_.push_back(TensorType::contiguousStridesOf(sizes));
      }
    }

    bufOutputs_.insert(bufs_.at(output));
    bufsToBeParallelized_.insert(bufs_.at(output));
    bufferArgs_.emplace_back(BufHandle(bufs_.at(output)));
    tensorOutputTensorOptions_.emplace_back(
        c10::TensorOptions(tensorType(bufs_.at(output))).device(device_));
    isOutputScalar_.push_back(false);
    bufs_.erase(output);
  }

  BackendType backendType = inferBackendTypeFromDevice(device_);
  stmt_ = transformLoops(backendType, block);

  for (const auto& c : constants_) {
    bufferArgs_.emplace_back(BufHandle(c.buf));
  }

  if (has_symbolic_shapes_) {
    tensorOutputSizes_.resize(bufOutputs_.size());
    tensorOutputStrides_.resize(bufOutputs_.size());
  }

  // Generate code.
  codegen_ = CreateCodeGen(
      getCodeGenName(backendType),
      stmt_,
      bufferArgs_,
      device_,
      kernel_func_name_);
}

void TensorExprKernel::recompile() {
  codegen_ = CreateCodeGen(
      "llvm_codegen", stmt_, bufferArgs_, device_, kernel_func_name_);
}

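// A minimal usage sketch (illustrative only; it assumes the convenience
// constructor overload that takes just the subgraph, as used in the NNC
// tests):
//   TensorExprKernel kernel(graph);
//   torch::jit::Stack stack{at::rand({4, 8}), at::rand({4, 8})};
//   kernel.run(stack);  // pops the inputs, pushes the outputs
//   at::Tensor result = stack[0].toTensor();
// Construction runs optimizeOwningGraph() and, unless fallback is requested,
// compile(); see the constructor below.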
TensorExprKernel::TensorExprKernel(
    const std::shared_ptr<Graph>& subgraph,
    const std::string& kernel_func_name,
    std::unordered_map<c10::Symbol, NNCLoweringFunction> custom_lowerings,
    std::vector<int64_t> symbolic_shape_inputs,
    bool pre_alloc /*= false*/,
    std::unordered_map<
        const torch::jit::Value*,
        std::vector<torch::jit::StrideInput>> symbolic_strides)
    : graph_(subgraph),
      code_(subgraph, ""),
      symbolic_shape_inputs_(std::move(symbolic_shape_inputs)),
      custom_lowerings_(std::move(custom_lowerings)),
      pre_alloc_(pre_alloc),
      kernel_func_name_(kernel_func_name),
      symbolic_strides_(std::move(symbolic_strides)) {
  optimizeOwningGraph();

  allow_fallback_ = fallbackAllowed();

  if (!allow_fallback_) {
    compile();
    return;
  }

  use_fallback_ = fallbackEnforced();
  if (use_fallback_) {
    return;
  }

  try {
    compile();
  } catch (...) {
    use_fallback_ = true;
  }
}

void TensorExprKernel::run(Stack& stack) const {
  if (!use_fallback_ && !allow_fallback_) {
    runKernel(stack);
  } else if (!use_fallback_ && allow_fallback_) {
    try {
      runKernel(stack);
    } catch (...) {
      fallback(stack);
    }
  } else {
    fallback(stack);
  }
}

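// getStaticOutputSizesAndStrides() resolves the symbolic output sizes using
// the integer inputs of the current call and then derives strides from the
// recorded StrideInput descriptor. For example (illustrative sizes): resolved
// sizes [2, 3, 4, 5] give contiguous strides [60, 20, 5, 1] for TENSOR_CONT
// and channels-last strides [60, 1, 15, 3] for TENSOR_CONT_CHANNELS_LAST.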
void TensorExprKernel::getStaticOutputSizesAndStrides(
    const at::ArrayRef<IValue>& inputs,
    std::vector<std::vector<int64_t>>* sizes,
    std::vector<std::vector<int64_t>>* strides) const {
  TORCH_INTERNAL_ASSERT(has_symbolic_shapes_);
  // If there are symbolic shapes, then the output tensor size wouldn't have
  // been computed at compile time. That has to be done here by using the
  // symbolic shape input params passed in to this call.
  TORCH_INTERNAL_ASSERT(
      tensorOutputSymbolicSizes_.size() == bufOutputs_.size());

  TORCH_INTERNAL_ASSERT(sizes);
  TORCH_INTERNAL_ASSERT(strides);
  *sizes = tensorOutputSizes_;
  *strides = tensorOutputStrides_;
  auto& static_sizes = *sizes;
  auto& static_strides = *strides;
  for (size_t i = 0, e = bufOutputs_.size(); i < e; ++i) {
    static_sizes[i].clear();
    for (auto t : tensorOutputSymbolicSizes_[i]) {
      if (t.AsNode<LongImm>()) {
        static_sizes[i].emplace_back(immediateAs<int64_t>(t.node()));
      } else {
        auto input_pos = shapeSymbolInputPos_.at(t.node());
        TORCH_INTERNAL_ASSERT(input_pos < inputs.size());
        TORCH_INTERNAL_ASSERT(inputs[input_pos].isInt());
        static_sizes[i].emplace_back(inputs[input_pos].toInt());
      }
    }

    if (tensorOutputStrideDesc_[i] == torch::jit::StrideInput::TENSOR_CONT) {
      static_strides[i] = TensorType::contiguousStridesOf(static_sizes[i]);

    } else if (
        tensorOutputStrideDesc_[i] ==
        torch::jit::StrideInput::TENSOR_CONT_CHANNELS_LAST) {
      static_strides[i] = at::get_channels_last_strides_2d(static_sizes[i]);

    } else {
      std::string output_desc = toString(tensorOutputStrideDesc_[i]);
      TORCH_INTERNAL_ASSERT(
          false, "Expected contiguous or channels last, got ", output_desc);
    }
  }
}

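// prepareRunArgs() assembles the CallArg vector in the same order in which
// compile() laid out bufferArgs_: the graph inputs first, then (for symbolic
// shapes) one stride value per recorded (input, dim) pair, then one freshly
// allocated tensor per output, and finally the constant buffers.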
std::vector<CodeGen::CallArg> TensorExprKernel::prepareRunArgs(
    const at::ArrayRef<IValue>& inputs,
    std::vector<at::Tensor>& outputs) const {
  // TODO: preallocate `runArgs` during compilation and fill in values where
  // possible (e.g. for constant tensors)
  std::vector<CodeGen::CallArg> runArgs;
  runArgs.reserve(
      inputs.size() + input_stride_args_.size() + bufOutputs_.size());

  for (auto& input : inputs) {
    if (input.isInt()) {
      runArgs.emplace_back(input.toInt());
    } else if (input.isBool()) {
      runArgs.emplace_back(input.toBool());
    } else if (input.isDouble()) {
      runArgs.emplace_back(input.toDouble());
    } else if (input.isTensor()) {
      runArgs.emplace_back(input.toTensor().data_ptr());
    }
  }

  if (has_symbolic_shapes_) {
    std::vector<std::vector<int64_t>> static_sizes;
    std::vector<std::vector<int64_t>> static_strides;
    getStaticOutputSizesAndStrides(inputs, &static_sizes, &static_strides);

    // add stride args
    for (const auto& input_stride_arg : input_stride_args_) {
      runArgs.emplace_back(
          inputs[input_stride_arg.first].toTensor().strides().at(
              input_stride_arg.second));
    }

    for (size_t i = 0, e = bufOutputs_.size(); i < e; ++i) {
      auto const& opts = tensorOutputTensorOptions_[i];
      outputs.emplace_back(codegen_->empty_strided(
          static_sizes[i],
          static_strides[i],
          opts.dtype,
          opts.layout,
          opts.device,
          opts.pinned_memory));
      runArgs.emplace_back(outputs.back().data_ptr());
    }
  } else {
    for (size_t i = 0, e = bufOutputs_.size(); i < e; ++i) {
      auto const& opts = tensorOutputTensorOptions_[i];
      outputs.emplace_back(codegen_->empty_strided(
          tensorOutputSizes_[i],
          tensorOutputStrides_[i],
          opts.dtype,
          opts.layout,
          opts.device,
          opts.pinned_memory));
      runArgs.emplace_back(outputs.back().data_ptr());
    }
  }

  for (const auto& c : constants_) {
    runArgs.emplace_back(c.ptr);
  }

  return runArgs;
}

StmtPtr TensorExprKernel::getCodeGenStmt() {
  return codegen_->stmt();
}

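// runKernel() consumes the inputs from the interpreter stack and pushes the
// outputs back in their place; scalar outputs are pushed as plain IValues
// (via Tensor::item()) rather than as 0-dim tensors.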
void TensorExprKernel::runKernel(Stack& stack) const {
  // Set up arguments (inputs, then outputs) for kernel call.
  auto inputs = last(stack, nInputs_);
  std::vector<at::Tensor> outputs;

  std::vector<CodeGen::CallArg> runArgs = prepareRunArgs(inputs, outputs);

  // Call the kernel.
  codegen_->call(runArgs);

  // Update the stack.
  drop(stack, nInputs_);

  int64_t idx = 0;
  for (auto& o : outputs) {
    if (isOutputScalar_[idx++]) {
      // Scalar outputs are returned as 0-dim tensors; we need to extract the
      // scalar value from them.
      push_one(stack, o.item());
    } else {
      push_one(stack, std::move(o));
    }
  }
}

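// runFast() is the raw-pointer entry point: `inputs` and `outputs` are data
// pointers in bufferArgs_ order, with the constant buffers appended here.
// Note that no symbolic stride arguments are added on this path, so it
// presumably applies only to statically shaped kernels.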
void TensorExprKernel::runFast(
    const std::vector<void*>& inputs,
    const std::vector<void*>& outputs) const {
  std::vector<void*> args(inputs);
  args.reserve(inputs.size() + outputs.size() + constants_.size());
  args.insert(args.end(), outputs.begin(), outputs.end());

  // TODO: we can consider preallocating and pre-filling the args vector.
  for (const auto& c : constants_) {
    args.push_back(c.ptr);
  }

  // Call the kernel.
  codegen_->call_raw(args);
}

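// runWithAllocatedOutputs() expects the caller to have pushed the output
// tensors first and the inputs on top of them. For example, with two inputs
// and one output the relevant stack tail is [out0, in0, in1] (in1 on top);
// only the inputs are dropped at the end, leaving out0 in place to hold the
// result.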
void TensorExprKernel::runWithAllocatedOutputs(Stack& stack) const {
  TORCH_INTERNAL_ASSERT(
      device_ == at::kCPU,
      "Pre-allocated output tensors are supported only on CPUs.");
  std::vector<void*> args;
  args.reserve(nInputs_ + nOutputs_ + constants_.size());

  // stack has inputs on the top and outputs right below them.
  auto stack_ivals = last(stack, nOutputs_ + nInputs_);
  auto stack_outputs = stack_ivals.slice(0, nOutputs_);
  auto stack_inputs = stack_ivals.slice(nOutputs_);

  std::vector<int64_t> int_inputs(nInputs_);
  for (auto i : c10::irange(nInputs_)) {
    auto inp = stack_inputs[i];
    if (inp.isInt()) {
      int_inputs[i] = inp.toInt();
      args.emplace_back(&int_inputs[i]);
    } else if (inp.isTensor()) {
      args.emplace_back(inp.toTensor().data_ptr());
    } else {
      TORCH_INTERNAL_ASSERT(
          false, "Unhandled input type while calling TensorExprKernel");
    }
  }

  std::vector<int64_t> stride_values(input_stride_args_.size());
  if (has_symbolic_shapes_) {
    std::vector<std::vector<int64_t>> static_sizes;
    std::vector<std::vector<int64_t>> static_strides;
    getStaticOutputSizesAndStrides(
        stack_inputs, &static_sizes, &static_strides);

    // add stride args
    for (auto idx : c10::irange(input_stride_args_.size())) {
      const auto& input_stride_arg = input_stride_args_[idx];
      stride_values[idx] =
          stack_inputs[input_stride_arg.first].toTensor().strides().at(
              input_stride_arg.second);
      args.emplace_back(&stride_values[idx]);
    }

    TORCH_INTERNAL_ASSERT(
        nOutputs_ == static_cast<int64_t>(bufOutputs_.size()));
    for (size_t i = 0, e = bufOutputs_.size(); i < e; ++i) {
      auto& out = stack_outputs[i].toTensor();
      // This has only been tested on CPUs.
      // TODO: Test on GPUs.
      out.resize_(static_sizes[i]);
      args.emplace_back(out.data_ptr());
    }
  } else {
    for (auto i : c10::irange(nOutputs_)) {
      args.emplace_back(stack_outputs[i].toTensor().data_ptr());
    }
  }

  for (const auto& c : constants_) {
    args.emplace_back(c.ptr);
  }

  // Call the kernel.
  codegen_->call_raw(args);

  // Remove the inputs from the stack. The outputs are already below the inputs
  // in the stack.
  drop(stack, nInputs_);
}