// torch/csrc/jit/codegen/onednn/graph_helper.cpp
#include <torch/csrc/jit/codegen/onednn/LlgaTensorImpl.h>
#include <torch/csrc/jit/codegen/onednn/graph_helper.h>
#include <ATen/core/functional.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
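
// Builds a oneDNN Graph (LLGA) representation of a TorchScript graph:
// supported aten ops are mapped to oneDNN Graph ops, unsupported ones become
// Wildcard ops, and the resulting partitions drive fusion-group formation.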
namespace torch::jit::fuser::onednn {
using opkind = dnnl::graph::op::kind;
static void fixConvOptionalBias(Node* node) {
if (!node->namedInput("bias")->mustNotBeNone()) {
// Replace non-existent optional bias with const None
auto g = node->owningGraph();
auto n = g->createNone();
auto v = n->insertBefore(node)->output();
node->replaceInput(2, v);
}
}
static std::optional<size_t> getDimensions(Value* v) {
if (v->type()->isSubtypeOf(TensorType::get())) {
return v->type()->cast<TensorType>()->sizes().size();
} else {
return std::nullopt;
}
}
// PyTorch ops that can't otherwise be mapped to oneDNN Graph ops are mapped
// as Wildcards instead. They simplify the integration with PyTorch: every op
// can be passed to the oneDNN Graph library in the add_op call, with no need
// to check beforehand whether the op is supported by oneDNN Graph or not.
// oneDNN Graph ops separated by wildcards don't end up in the same partition.
static Operator makeWildcardOp(Node* node) {
auto o = Operator(node, opkind::Wildcard);
// wildcard op contains only topology info
for (size_t i = 0; i < node->inputs().size(); i++) {
o.setInput(i);
}
for (size_t i = 0; i < node->outputs().size(); i++) {
o.setOutput(i);
}
return o;
}
// If we don't meet a certain condition to map a PyTorch op to a oneDNN Graph
// op, then we create a wildcard op corresponding to that PyTorch op instead.
#define REQUIRE(cond) \
if (!(cond)) { \
GRAPH_DEBUG("Unsupported condition " #cond "\n"); \
return makeWildcardOp(node); \
}
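// For example, the aten::_convolution mapping below uses
//   REQUIRE(!transposed);
// to fall back to a wildcard op for transposed convolutions.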
Operator LlgaGraphHelper::makeEltwiseOp(Node* node, opkind kind) {
return Operator(node, kind).setInput(0).setOutput(dnnl_graph_, 0);
}
Operator LlgaGraphHelper::makeBinaryOp(Node* node, opkind kind) {
REQUIRE(
node->input(0)->type()->isSubtypeOf(TensorType::get()) &&
node->input(1)->type()->isSubtypeOf(TensorType::get()))
return Operator(node, kind).setInput(0, 1).setOutput(dnnl_graph_, 0);
}
// Map a PyTorch op to its corresponding oneDNN Graph op.
// If mapping isn't possible, then create a wildcard op instead.
// The mapping is done as per oneDNN Graph op schema defined in
// third_party/ideep/mkl-dnn/src/interface/op_def.hpp.
Operator LlgaGraphHelper::createOperator(Node* node) {
auto nodeKind = node->kind();
// We use an if-else chain instead of a switch statement because we will soon
// be adding custom ops with function schemas. We'd have to use
// Symbol::fromQualString for those anyway, and this choice is acceptable
// since this code is not on the hot path.
if (nodeKind == Symbol::fromQualString("aten::conv2d")) {
fixConvOptionalBias(node);
return Operator(node, opkind::Convolution)
.setInput(0, 1, 2)
.setOutput(dnnl_graph_, 0)
.setAttr(dnnl::graph::op::attr::strides, Operator::Ints, 3)
.setAttr(dnnl::graph::op::attr::pads_begin, Operator::Ints, 4)
.setAttr(dnnl::graph::op::attr::pads_end, Operator::Ints, 4)
.setAttr(dnnl::graph::op::attr::dilations, Operator::Ints, 5)
.setAttr(dnnl::graph::op::attr::groups, Operator::Int, 6)
.setAttr(dnnl::graph::op::attr::weights_format, std::string("OIX"))
.setAttr(dnnl::graph::op::attr::data_format, std::string("NCX"));
} else if (
(nodeKind == Symbol::fromQualString("aten::_convolution")) ||
(nodeKind == Symbol::fromQualString("aten::convolution"))) {
bool transposed = toIValue(node->namedInput("transposed"))->toBool();
REQUIRE(!transposed);
return Operator(node, opkind::Convolution)
.setInput(0, 1, 2)
.setOutput(dnnl_graph_, 0)
.setAttr(dnnl::graph::op::attr::strides, Operator::Ints, 3)
.setAttr(dnnl::graph::op::attr::pads_begin, Operator::Ints, 4)
.setAttr(dnnl::graph::op::attr::pads_end, Operator::Ints, 4)
.setAttr(dnnl::graph::op::attr::dilations, Operator::Ints, 5)
.setAttr(dnnl::graph::op::attr::groups, Operator::Int, 8)
.setAttr(dnnl::graph::op::attr::weights_format, std::string("OIX"))
.setAttr(dnnl::graph::op::attr::data_format, std::string("NCX"));
} else if (nodeKind == Symbol::fromQualString("aten::batch_norm")) {
auto training = toIValue(node->namedInput("training"));
REQUIRE(training.has_value()); // cannot get training status in script mode
if (!training->toBool()) {
return Operator(node, opkind::BatchNormInference)
.setInput(0, 1, 2, 3, 4)
.setOutput(dnnl_graph_, 0)
.setAttr(dnnl::graph::op::attr::epsilon, Operator::Float, 7)
.setAttr(dnnl::graph::op::attr::data_format, std::string("NCX"));
}
} else if (nodeKind == Symbol::fromQualString("aten::layer_norm")) {
auto normalized_shape = toIValue(node->namedInput("normalized_shape"));
REQUIRE(normalized_shape->toIntList().size() == 1);
return Operator(node, opkind::LayerNorm)
.setInput(0, 2, 3)
.setOutput(dnnl_graph_, 0)
.setAttr(dnnl::graph::op::attr::epsilon, Operator::Float, 4)
.setAttr(dnnl::graph::op::attr::keep_stats, false);
} else if (nodeKind == Symbol::fromQualString("aten::addmm")) {
auto alpha = toIValue(node->namedInput("alpha"));
auto beta = toIValue(node->namedInput("beta"));
if (alpha.has_value() && beta.has_value()) {
if ((alpha->toDouble() == 1.0) && (beta->toDouble() == 1.0)) {
return Operator(node, opkind::MatMul)
.setInput(1, 2, 0)
.setOutput(dnnl_graph_, 0);
} else if ((alpha->toDouble() == 1.0) && (beta->toDouble() == 0.0)) {
return Operator(node, opkind::MatMul)
.setInput(1, 2)
.setOutput(dnnl_graph_, 0);
}
}
} else if (nodeKind == Symbol::fromQualString("aten::add"))
return makeBinaryOp(node, opkind::Add);
else if (nodeKind == Symbol::fromQualString("aten::mul"))
return makeBinaryOp(node, opkind::Multiply);
else if (nodeKind == Symbol::fromQualString("aten::div"))
return makeBinaryOp(node, opkind::Divide);
else if (nodeKind == Symbol::fromQualString("aten::tanh"))
return makeEltwiseOp(node, opkind::Tanh);
else if (nodeKind == Symbol::fromQualString("aten::relu"))
return makeEltwiseOp(node, opkind::ReLU);
else if (nodeKind == Symbol::fromQualString("aten::elu"))
return makeEltwiseOp(node, opkind::Elu)
.setAttr(dnnl::graph::op::attr::alpha, Operator::Float, 1);
else if (nodeKind == Symbol::fromQualString("aten::sigmoid"))
return makeEltwiseOp(node, opkind::Sigmoid);
else if (nodeKind == Symbol::fromQualString("aten::gelu"))
return makeEltwiseOp(node, opkind::GELU);
else if (nodeKind == Symbol::fromQualString("aten::round"))
return makeEltwiseOp(node, opkind::Round);
else if (nodeKind == Symbol::fromQualString("aten::exp"))
return makeEltwiseOp(node, opkind::Exp);
else if (nodeKind == Symbol::fromQualString("aten::sqrt"))
return makeEltwiseOp(node, opkind::Sqrt);
else if (nodeKind == Symbol::fromQualString("aten::abs"))
return makeEltwiseOp(node, opkind::Abs);
else if (nodeKind == Symbol::fromQualString("aten::square"))
return makeEltwiseOp(node, opkind::Square);
else if (nodeKind == Symbol::fromQualString("aten::clamp")) {
// The PyTorch API already ensures that min & max are not both None,
// but we check it here nevertheless.
auto clamp_min = toIValue(node->input(1));
auto clamp_max = toIValue(node->input(2));
REQUIRE(!(clamp_max->isNone() && clamp_min->isNone()));
auto clamp_min_value = (clamp_min->isNone())
? -std::numeric_limits<float>::infinity()
: Operator::ScalarToFloat(node, 1);
auto clamp_max_value = (clamp_max->isNone())
? std::numeric_limits<float>::infinity()
: Operator::ScalarToFloat(node, 2);
return makeEltwiseOp(node, opkind::Clamp)
.setAttr(dnnl::graph::op::attr::min, clamp_min_value)
.setAttr(dnnl::graph::op::attr::max, clamp_max_value);
} else if (nodeKind == Symbol::fromQualString("aten::hardtanh")) {
return makeEltwiseOp(node, opkind::Clamp)
.setAttr(dnnl::graph::op::attr::min, Operator::ScalarToFloat, 1)
.setAttr(dnnl::graph::op::attr::max, Operator::ScalarToFloat, 2);
} else if (nodeKind == Symbol::fromQualString("aten::hardswish"))
return makeEltwiseOp(node, opkind::HardSwish);
else if (nodeKind == Symbol::fromQualString("aten::log"))
return makeEltwiseOp(node, opkind::Log);
else if (nodeKind == Symbol::fromQualString("aten::leaky_relu")) {
return makeEltwiseOp(node, opkind::LeakyReLU)
.setAttr(dnnl::graph::op::attr::alpha, Operator::Float, 1);
} else if (nodeKind == Symbol::fromQualString("aten::relu6")) {
return makeEltwiseOp(node, opkind::Clamp)
.setAttr(dnnl::graph::op::attr::min, 0.f)
.setAttr(dnnl::graph::op::attr::max, 6.f);
} else if (
(nodeKind == Symbol::fromQualString("aten::softmax")) ||
(nodeKind == Symbol::fromQualString("aten::_softmax"))) {
auto axis = toIValue(node->namedInput("dim"))->toInt();
return Operator(node, opkind::SoftMax)
.setInput(0)
.setOutput(dnnl_graph_, 0)
.setAttr(dnnl::graph::op::attr::axis, axis);
} else if (nodeKind == Symbol::fromQualString("aten::_log_softmax")) {
auto axis = toIValue(node->namedInput("dim"))->toInt();
return Operator(node, opkind::LogSoftmax)
.setInput(0)
.setOutput(dnnl_graph_, 0)
.setAttr(dnnl::graph::op::attr::axis, axis);
} else if (nodeKind == Symbol::fromQualString("aten::cat")) {
auto o = Operator(node, opkind::Concat);
REQUIRE(node->namedInput("tensors")->node()->kind() == prim::ListConstruct);
REQUIRE(node->namedInput("tensors")->uses().size() == 1);
REQUIRE(node->namedInput("dim")->node()->kind() == prim::Constant);
// aten::cat needs special handling since it takes a Tensor[] as input.
// We set the inputs of ListConstruct as the inputs of cat.
//
// PyTorch IR:                            LLGA sees:
//    %a    %b    %c        %dim             %a    %b    %c
//     \     |    /           |               \     |    /
//  prim::ListConstruct  prim::Constant  llga::Concat[axis=%dim]
//         \      /
//        aten::cat
auto listConstruct = node->input(0)->node();
for (auto input : listConstruct->inputs())
o.setInputValue(input);
return o.setOutput(dnnl_graph_, 0)
.setAttr(dnnl::graph::op::attr::axis, Operator::Int, 1);
} else if (
(nodeKind == Symbol::fromQualString("aten::max_pool2d")) ||
(nodeKind == Symbol::fromQualString("aten::max_pool2d_with_indices"))) {
// Currently, LLGA lacks support for creating an indices mask.
// Once it's supported, max_pool2d_with_indices should be mapped differently.
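// aten::max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode):
// the positional attribute indices below follow this schema.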
REQUIRE(node->namedInput("kernel_size")->node()->kind() == prim::Constant);
auto rounding_type =
toIValue(node->namedInput("ceil_mode"))->toBool() ? "ceil" : "floor";
return Operator(node, opkind::MaxPool)
.setInput(0)
.setOutput(dnnl_graph_, 0)
.setAttr(dnnl::graph::op::attr::kernel, Operator::Ints, 1)
.setAttr(dnnl::graph::op::attr::strides, Operator::Ints, 2)
.setAttr(dnnl::graph::op::attr::pads_begin, Operator::Ints, 3)
.setAttr(dnnl::graph::op::attr::pads_end, Operator::Ints, 3)
.setAttr(dnnl::graph::op::attr::dilations, Operator::Ints, 4)
.setAttr(
dnnl::graph::op::attr::rounding_type, std::string(rounding_type))
.setAttr(dnnl::graph::op::attr::data_format, std::string("NCX"));
} else if (nodeKind == Symbol::fromQualString("aten::avg_pool2d")) {
// TODO: do we need to add checks for all Constants?
REQUIRE(node->namedInput("kernel_size")->node()->kind() == prim::Constant);
auto rounding_type =
toIValue(node->namedInput("ceil_mode"))->toBool() ? "ceil" : "floor";
auto divisor_override = toIValue(node->namedInput("divisor_override"));
REQUIRE(divisor_override->isNone());
return Operator(node, opkind::AvgPool)
.setInput(0)
.setOutput(dnnl_graph_, 0)
.setAttr(dnnl::graph::op::attr::kernel, Operator::Ints, 1)
.setAttr(dnnl::graph::op::attr::strides, Operator::Ints, 2)
.setAttr(dnnl::graph::op::attr::pads_begin, Operator::Ints, 3)
.setAttr(dnnl::graph::op::attr::pads_end, Operator::Ints, 3)
.setAttr(dnnl::graph::op::attr::exclude_pad, !Operator::Bool(node, 5))
.setAttr(
dnnl::graph::op::attr::rounding_type, std::string(rounding_type))
.setAttr(dnnl::graph::op::attr::data_format, std::string("NCX"));
} else if (nodeKind == Symbol::fromQualString("aten::matmul")) {
auto dim0 = getDimensions(node->namedInput("self")).value_or(-1);
auto dim1 = getDimensions(node->namedInput("other")).value_or(-1);
// TODO: support all shape combinations
REQUIRE(
(dim0 == 2 && dim1 == 2) || (dim0 == 4 && dim1 == 4) ||
(dim0 == 3 && dim1 == 2));
return Operator(node, opkind::MatMul)
.setInput(0, 1)
.setOutput(dnnl_graph_, 0);
} else if (nodeKind == Symbol::fromQualString("aten::mm")) {
return Operator(node, opkind::MatMul)
.setInput(0, 1)
.setOutput(dnnl_graph_, 0);
} else if (nodeKind == Symbol::fromQualString("aten::bmm")) {
return Operator(node, opkind::MatMul)
.setInput(0, 1)
.setOutput(dnnl_graph_, 0);
} else if (nodeKind == Symbol::fromQualString("aten::linear")) {
return Operator(node, opkind::MatMul)
.setInput(0, 1, 2)
.setOutput(dnnl_graph_, 0)
.setAttr(dnnl::graph::op::attr::transpose_b, true);
} else if (nodeKind == Symbol::fromQualString("aten::permute")) {
REQUIRE(aliasDb_->hasInputWriters(node) == false);
return Operator(node, opkind::StaticTranspose)
.setInput(0)
.setOutput(dnnl_graph_, 0)
.setAttr(
dnnl::graph::op::attr::order,
toIValue(node->namedInput("dims"))->toIntVector());
} else if (nodeKind == Symbol::fromQualString("aten::contiguous")) {
// Contiguous should only be mapped to oneDNN Graph if the destination
// memory-layout is different than the source memory-format
// Strides would be different, but shape would be same
auto typeOfInput = node->input(0)->type()->expect<TensorType>();
auto typeOfOutput = node->output(0)->type()->expect<TensorType>();
auto inputStrides = typeOfInput->strides().concrete_sizes();
auto outputStrides = typeOfOutput->strides().concrete_sizes();
REQUIRE(inputStrides != outputStrides);
return Operator(node, opkind::Reorder)
.setInput(0)
.setOutput(dnnl_graph_, 0);
}
GRAPH_DEBUG("Making ", nodeKind.toQualString(), " a wildcard");
return makeWildcardOp(node);
}
static DeviceType inferDeviceFromValue(Value* v) {
auto tt = v->type()->cast<TensorType>();
if (!tt) {
return at::kCPU;
}
auto device = tt->device();
if (!device) {
return at::kCPU;
}
return device->type();
}
static DeviceType inferDevice(const std::shared_ptr<Graph>& graph) {
auto dt = inferDeviceFromValue(graph->inputs()[0]);
TORCH_CHECK(
std::all_of(
graph->inputs().begin(),
graph->inputs().end(),
[dt](Value* v) { return inferDeviceFromValue(v) == dt; }),
"All inputs must have the same deive type");
return dt;
}
static dnnl::engine::kind getLlgaEngineKind(DeviceType type) {
switch (type) {
case DeviceType::CPU:
return dnnl::engine::kind::cpu;
default:
TORCH_CHECK(false, "Unsupported device type ", type);
}
}
static void mayAddListConstructIntoConcatPartition(
Node* n,
OpPartitionMap& opToOwningPartition) {
// Since prim::ListConstruct is not visible to LLGA,
// it will not be in any partition returned from the partitioning results.
// We need to rewrite opToOwningPartition to make the prim::ListConstruct
// 'virtually' belong to the same partition as the aten::cat, so that
// prim::ListConstruct can be fused into the fusion group by the graph fuser.
// We emphasize 'virtually' because get_num_ops() for cat's partition would
// still return 1.
if (n->kind() == aten::cat && opToOwningPartition.has(n)) {
auto listConstruct = n->namedInput("tensors")->node();
auto partitionId = opToOwningPartition.get(n);
opToOwningPartition.add(listConstruct, partitionId);
}
}
// Verify that input tensors are compatible with oneDNN Graph.
// Scalars would be converted to 1-D tensors later anyway,
// but they shouldn't be complex-double.
// If this check fails, the op is converted to a wildcard.
static bool checkInputCompatibility(Node* node) {
auto allInputs = node->inputs();
for (auto input : allInputs) {
c10::IValue inputIValue = toIValue(input);
if (inputIValue.isTensor()) {
const at::Tensor& tensor = inputIValue.toTensor();
if (tensor.device() != at::kCPU) {
return false;
}
auto dtype = tensor.scalar_type();
if ((dtype != at::ScalarType::BFloat16) &&
(dtype != at::ScalarType::Float) && (dtype != at::ScalarType::Long)) {
// We've allowed Long dtype here although oneDNN Graph does not support
// Long dtype because oneDNN Graph will end up not handling the op that
// has an input with Long dtype, so it'd be handled by PyTorch.
return false;
}
} else if (inputIValue.isScalar()) {
if (inputIValue.isComplexDouble()) {
return false;
}
} else if (input->type()->isSubtypeOf(TensorType::get())) {
auto input_typeptr = input->type()->cast<TensorType>();
if (input_typeptr->scalarType().has_value()) {
at::ScalarType dtype = input_typeptr->scalarType().value();
if ((dtype != at::ScalarType::Float) &&
(dtype != at::ScalarType::BFloat16)) {
return false;
}
}
}
}
return true;
}
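
// Builds the LLGA graph from the top-level block, finalizes it, queries the
// library for fusion partitions, and records which partition owns each op.
// A minimal usage sketch (assuming a TorchScript graph `g`):
//   LlgaGraphHelper helper(g, dnnl::graph::partition::policy::fusion);
//   auto partitions = helper.getPartitions();
// where each returned partition is a candidate fusion group.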
LlgaGraphHelper::LlgaGraphHelper(
const std::shared_ptr<Graph>& graph,
dnnl::graph::partition::policy policy) {
auto deviceType = inferDevice(graph);
auto engineKind = getLlgaEngineKind(deviceType);
dnnl_graph_ = std::make_unique<dnnl::graph::graph>(engineKind);
aliasDb_ = std::make_unique<torch::jit::AliasDb>(graph);
GRAPH_DEBUG("Constructing LLGA graph");
// TODO: select nodes in top-level block for now
for (auto* node : graph->block()->nodes()) {
auto kindOfNode = node->kind();
GRAPH_DEBUG("Trying to add ", kindOfNode.toQualString());
if (checkInputCompatibility(node)) {
auto op = createOperator(node);
dnnl_graph_->add_op(op.llgaOp());
GRAPH_DEBUG(" Added node ", kindOfNode.toQualString());
} else {
GRAPH_DEBUG("Incompatible inputs for ", kindOfNode.toQualString());
dnnl_graph_->add_op(makeWildcardOp(node).llgaOp());
}
for (Value* input : node->inputs()) {
tensorIdToValue_.emplace(input->unique(), input);
}
}
dnnl_graph_->finalize();
GRAPH_DEBUG("Get Partitions");
std::vector<dnnl::graph::partition> partitions =
dnnl_graph_->get_partitions(policy);
// exclude unsupported Wildcard partitions
for (auto& partition : partitions) {
if (partition.is_supported()) {
partitions_.push_back(partition);
}
}
GRAPH_DEBUG(" Got #partitions: ", partitions_.size());
for (size_t partId = 0; partId < partitions_.size(); partId++) {
for (auto opId : partitions_[partId].get_ops()) {
opToOwningPartition_.add(opId, partId);
}
}
// Scanning the graph again for post processing
for (auto* node : graph->block()->nodes()) {
mayAddListConstructIntoConcatPartition(node, opToOwningPartition_);
}
}
bool LlgaGraphHelper::isLlgaSubgraph(const Node* node) {
return node->hasAttribute(attr::Subgraph) &&
node->kind() == prim::oneDNNFusionGroup;
}
bool LlgaGraphHelper::shouldMerge(Node* toMerge, Node* subgraph) {
TORCH_CHECK(
isLlgaSubgraph(subgraph),
"The consumer node does not contain a subgraph");
if (!shouldConsiderForMerge(toMerge)) {
return false;
}
return opToOwningPartition_.get(toMerge) ==
opToOwningPartition_.get(subgraph);
}
// Except for conv & GEMMs, which should always be handled by oneDNN Graph,
// only use single-op partitions for ops unsupported by NNC, or ops
// that oneDNN executes faster. prim::ListConstruct is an exception, since
// we simply want to fuse it with cat.
static bool isBetterSuitedForLLGA(NodeKind kindOfOp) {
return (
(kindOfOp == aten::layer_norm) || (kindOfOp == aten::avg_pool2d) ||
(kindOfOp == aten::matmul) || (kindOfOp == aten::max_pool2d) ||
(kindOfOp == aten::conv2d) || (kindOfOp == aten::_convolution) ||
(kindOfOp == aten::mm) || (kindOfOp == aten::linear) ||
(kindOfOp == aten::cat) || (kindOfOp == prim::ListConstruct));
}
bool LlgaGraphHelper::checkForSingleOpPartition(Node* node) {
if (opToOwningPartition_.has(node)) {
auto partitionId = opToOwningPartition_.get(node);
if (partitions_[partitionId].get_ops_num() == 1) {
auto kindOfNode = node->kind();
return isBetterSuitedForLLGA(kindOfNode);
} else {
// multi-op partition
return true;
}
} else {
// this op isn't present in any partition
return false;
}
}
bool LlgaGraphHelper::shouldConsiderForMerge(Node* node) {
// if we're already in the process of merging
if (isLlgaSubgraph(node)) {
return true;
}
return checkForSingleOpPartition(node);
}
Node* LlgaGraphHelper::createSingletonSubgraph(Node* n, AliasDb& aliasDb) {
auto partitionId = opToOwningPartition_.get(n);
GRAPH_DEBUG(
"Creating FusionGroup_", partitionId, " for ", n->kind().toQualString());
auto group = SubgraphUtils::createSingletonSubgraphAndUpdateAliasing(
n, prim::oneDNNFusionGroup, aliasDb);
opToOwningPartition_.add(group, partitionId);
return group;
}
void LlgaGraphHelper::mergeNodeIntoSubgraph(
Node* toMerge,
Node* subgraphNode,
AliasDb& aliasDb) {
if (isLlgaSubgraph(toMerge)) {
GRAPH_DEBUG(
"Merging ",
toMerge->kind().toQualString(),
"_",
opToOwningPartition_.get(toMerge),
" into ",
subgraphNode->kind().toQualString(),
"_",
opToOwningPartition_.get(subgraphNode));
} else {
GRAPH_DEBUG(
"Merging ",
toMerge->kind().toQualString(),
" into ",
subgraphNode->kind().toQualString(),
"_",
opToOwningPartition_.get(subgraphNode));
}
SubgraphUtils::mergeNodeIntoSubgraphAndUpdateAliasing(
toMerge, subgraphNode, aliasDb);
}
void LlgaGraphHelper::unmergeIfAnyNodeIsMissing(Node* subgraphNode) {
TORCH_CHECK(isLlgaSubgraph(subgraphNode), "Cannot unmerge a non-LLGA node");
auto partitionId = opToOwningPartition_.get(subgraphNode);
auto expectOpNum = partitions_[partitionId].get_ops_num();
auto actualOpNum = countSupportedOps(subgraphNode->g(attr::Subgraph));
if (expectOpNum != actualOpNum) {
GRAPH_DEBUG(
"Unmerging FusionGroup_",
partitionId,
". Expected ",
expectOpNum,
" ops, but got ",
actualOpNum,
" ops.");
SubgraphUtils::unmergeSubgraph(subgraphNode);
}
}
size_t LlgaGraphHelper::countSupportedOps(
const std::shared_ptr<Graph>& graph) const {
// TODO: count nodes in top-level block for now
size_t cnt = 0;
for (auto* node : graph->block()->nodes()) {
auto nodeKind = node->kind();
if ((nodeKind != prim::Constant) && (nodeKind != prim::ListConstruct)) {
cnt++;
}
}
return cnt;
}
std::vector<dnnl::graph::partition> LlgaGraphHelper::getPartitions() const {
return partitions_;
}
std::map<size_t, Value*> LlgaGraphHelper::getTensorIdToValue() const {
return tensorIdToValue_;
}
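
// LlgaNodeWrapper provides checked access to the output_layouts attribute of
// an LLGA fusion node, marking which outputs should use LLGA's opaque layout.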
LlgaNodeWrapper::LlgaNodeWrapper(const Node* node)
: n(const_cast<Node*>(node)) { // NOLINT
TORCH_CHECK(
LlgaGraphHelper::isLlgaSubgraph(n), "Cannot wrap a non-LLGA fusion node");
}
void LlgaNodeWrapper::setOpaqueLayout(size_t offset) {
const auto num_output = n->is(attr::output_layouts).size();
TORCH_CHECK(
offset < num_output,
"Out of range. (Invalid index ",
offset,
" for attr::output_layouts with size ",
num_output,
")");
auto& layouts =
const_cast<std::vector<int64_t>&>(n->is(attr::output_layouts)); // NOLINT
layouts.at(offset) = OPAQUE_LAYOUT;
}
bool LlgaNodeWrapper::useOpaqueLayout(size_t offset) const {
const auto num_output = n->is(attr::output_layouts).size();
TORCH_CHECK(
offset < num_output,
"Out of range. (Invalid index ",
offset,
" for attr::output_layouts with size ",
num_output,
")");
return n->is(attr::output_layouts)[offset] == OPAQUE_LAYOUT;
}
} // namespace torch::jit::fuser::onednn