From 2b79bab029ec387618f6b9724a4027b81fbda33c Mon Sep 17 00:00:00 2001 From: Jie Date: Wed, 4 Mar 2020 20:21:24 -0800 Subject: [PATCH] [CUDA_FUSER] Fork CUDA fuser (#33527) Summary: Separating CUDA fuser from CPU fuser. 1. New node in IR - prim::CudaFusionGroup: This enables the cuda fuser to co-exist along side the old fuser. Allows us to incrementally build and expand cuda fuser. 2. copied FuseGraph optimization passes to CudaFuserGraph: We will re-factor & reuse Chunk/Concat in the old fuser logic, which is handled in the optimization pass at this moment. Unfortunately many code in the pass is tightly binded with the legacy fuser, which makes code sharing difficult. The CudaFusionGraph will support only a subset of operations comparing to legacy fuser (CUDA only). It is registered as a custom pass post fusion via ```torch._C._jit_register_cuda_fuser()``` To have it in effect, you should also turn off fusion on GPU via ```torch._C._jit_override_can_fuse_on_gpu(False)``` 3. We don't have codegen in this PR yet (WIP). Currently we just fall back to the old fuser. Pull Request resolved: https://github.com/pytorch/pytorch/pull/33527 Differential Revision: D20171598 Pulled By: ZolotukhinM fbshipit-source-id: 9a3c0f06f46da7eaa80ae7551c04869f5b03ef71 --- aten/src/ATen/core/interned_strings.h | 1 + caffe2/CMakeLists.txt | 1 + tools/build_variables.bzl | 1 + torch/csrc/jit/ir/alias_analysis.cpp | 1 + torch/csrc/jit/ir/ir.cpp | 1 + torch/csrc/jit/passes/cuda_graph_fuser.cpp | 1231 +++++++++++++++++ torch/csrc/jit/passes/cuda_graph_fuser.h | 17 + .../jit/passes/inline_autodiff_subgraphs.cpp | 3 +- torch/csrc/jit/python/init.cpp | 3 + torch/csrc/jit/python/python_ir.cpp | 3 + torch/csrc/jit/runtime/operator.cpp | 2 + torch/csrc/jit/runtime/register_prim_ops.cpp | 11 + 12 files changed, 1274 insertions(+), 1 deletion(-) create mode 100644 torch/csrc/jit/passes/cuda_graph_fuser.cpp create mode 100644 torch/csrc/jit/passes/cuda_graph_fuser.h diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h index 68f2e79a4b5b..8f48640b4ed9 100644 --- a/aten/src/ATen/core/interned_strings.h +++ b/aten/src/ATen/core/interned_strings.h @@ -33,6 +33,7 @@ namespace c10 { _(prim, Eval) \ _(prim, Expand) /* onnx */ \ _(prim, FusionGroup) \ + _(prim, CudaFusionGroup) \ _(prim, DifferentiableGraph) \ _(prim, If) \ _(prim, Jump) /* debug */ \ diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index c82a541858e4..ee3481045763 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -411,6 +411,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) ${TORCH_SRC_DIR}/csrc/jit/passes/fixup_trace_scope_blocks.cpp ${TORCH_SRC_DIR}/csrc/jit/passes/inline_fork_wait.cpp ${TORCH_SRC_DIR}/csrc/jit/passes/graph_fuser.cpp + ${TORCH_SRC_DIR}/csrc/jit/passes/cuda_graph_fuser.cpp ${TORCH_SRC_DIR}/csrc/jit/passes/guard_elimination.cpp ${TORCH_SRC_DIR}/csrc/jit/passes/inplace_check.cpp ${TORCH_SRC_DIR}/csrc/jit/passes/liveness.cpp diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl index 976b11e6e01f..9805f5215964 100644 --- a/tools/build_variables.bzl +++ b/tools/build_variables.bzl @@ -121,6 +121,7 @@ libtorch_sources = [ "torch/csrc/jit/passes/erase_number_types.cpp", "torch/csrc/jit/passes/fixup_trace_scope_blocks.cpp", "torch/csrc/jit/passes/graph_fuser.cpp", + "torch/csrc/jit/passes/cuda_graph_fuser.cpp", "torch/csrc/jit/passes/guard_elimination.cpp", "torch/csrc/jit/passes/inline_autodiff_subgraphs.cpp", "torch/csrc/jit/passes/inliner.cpp", diff --git a/torch/csrc/jit/ir/alias_analysis.cpp b/torch/csrc/jit/ir/alias_analysis.cpp index 5aaaa0bb95bd..8900371bcbf9 100644 --- a/torch/csrc/jit/ir/alias_analysis.cpp +++ b/torch/csrc/jit/ir/alias_analysis.cpp @@ -326,6 +326,7 @@ void AliasDb::analyzeImpl(Node* node) { case prim::Loop: return analyzeLoop(node); case prim::FusionGroup: + case prim::CudaFusionGroup: case prim::DifferentiableGraph: return analyzeSubgraph(node); case prim::fork: diff --git a/torch/csrc/jit/ir/ir.cpp b/torch/csrc/jit/ir/ir.cpp index 1540604e1749..90814dd157c3 100644 --- a/torch/csrc/jit/ir/ir.cpp +++ b/torch/csrc/jit/ir/ir.cpp @@ -430,6 +430,7 @@ void Node::lint() const { // longer. break; case prim::FusionGroup: + case prim::CudaFusionGroup: checkSameDevice(this); // TODO: Typecheck the parameters g(attr::Subgraph)->lint(); diff --git a/torch/csrc/jit/passes/cuda_graph_fuser.cpp b/torch/csrc/jit/passes/cuda_graph_fuser.cpp new file mode 100644 index 000000000000..57acb9dcf837 --- /dev/null +++ b/torch/csrc/jit/passes/cuda_graph_fuser.cpp @@ -0,0 +1,1231 @@ +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace torch { +namespace jit { + +namespace { + +// What is a simple mappable operator? It: +// - Has a single tensor output +// - Output and all tensor inputs have the same shape +// - Output and all tensor inputs have the same scalar type +// or all tensor inputs have the same scalar type and +// output is identified in PropagateInputShapes +// - Output and all tensor inputs should be on the same device +// - Produces dense non-overlapping outputs +// Some of these restrictions may be relaxable, but you should +// carefully read the code first, as we rely on these assumptions. +bool isSimpleMap(Node* node) { + static OperatorSet simple_mappable{{ + "aten::_cast_Float(Tensor self, bool non_blocking) -> Tensor", + + "aten::abs(Tensor self) -> Tensor", + "aten::acos(Tensor self) -> Tensor", + "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor", + "aten::asin(Tensor self) -> Tensor", + "aten::atan(Tensor self) -> Tensor", + "aten::atan2(Tensor self, Tensor other) -> Tensor", + "aten::ceil(Tensor self) -> Tensor", + "aten::clamp(Tensor self, Scalar? min, Scalar? max) -> Tensor", + "aten::cos(Tensor self) -> Tensor", + "aten::cosh(Tensor self) -> Tensor", + "aten::div(Tensor self, Tensor other) -> Tensor", + "aten::exp(Tensor self) -> Tensor", + "aten::expm1(Tensor self) -> Tensor", + "aten::erf(Tensor self) -> Tensor", + "aten::erfc(Tensor self) -> Tensor", + "aten::floor(Tensor self) -> Tensor", + "aten::fmod(Tensor self, Tensor other) -> Tensor", + "aten::frac(Tensor self) -> Tensor", + "aten::lgamma(Tensor self) -> Tensor", + "aten::log(Tensor self) -> Tensor", + "aten::log10(Tensor self) -> Tensor", + "aten::log1p(Tensor self) -> Tensor", + "aten::log2(Tensor self) -> Tensor", + "aten::lerp(Tensor self, Tensor end, Scalar weight) -> Tensor", + "aten::lerp(Tensor self, Tensor end, Tensor weight) -> Tensor", + "aten::max(Tensor self, Tensor other) -> Tensor", + "aten::min(Tensor self, Tensor other) -> Tensor", + "aten::mul(Tensor self, Tensor other) -> Tensor", + "aten::neg(Tensor self) -> Tensor", + "aten::pow(Tensor self, Tensor exponent) -> Tensor", + "aten::pow(Tensor self, Scalar exponent) -> Tensor", + "aten::pow(Scalar self, Tensor exponent) -> Tensor", + "aten::reciprocal(Tensor self) -> Tensor", + "aten::relu(Tensor self) -> Tensor", + "aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor", + "aten::remainder(Tensor self, Tensor other) -> Tensor", + "aten::round(Tensor self) -> Tensor", + "aten::rsqrt(Tensor self) -> Tensor", + "aten::sigmoid(Tensor self) -> Tensor", + "aten::sin(Tensor self) -> Tensor", + "aten::sinh(Tensor self) -> Tensor", + "aten::sqrt(Tensor self) -> Tensor", + "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor", + "aten::tan(Tensor self) -> Tensor", + "aten::rand_like(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor", + "aten::tanh(Tensor self) -> Tensor", + "aten::trunc(Tensor self) -> Tensor", + "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor", + "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor", + "aten::mul(Tensor self, Scalar other) -> Tensor", + "aten::div(Tensor self, Scalar other) -> Tensor", + + "aten::eq(Tensor self, Tensor other) -> Tensor", + "aten::eq(Tensor self, Scalar other) -> Tensor", + "aten::ne(Tensor self, Tensor other) -> Tensor", + "aten::ne(Tensor self, Scalar other) -> Tensor", + "aten::ge(Tensor self, Tensor other) -> Tensor", + "aten::ge(Tensor self, Scalar other) -> Tensor", + "aten::gt(Tensor self, Tensor other) -> Tensor", + "aten::gt(Tensor self, Scalar other) -> Tensor", + "aten::le(Tensor self, Tensor other) -> Tensor", + "aten::le(Tensor self, Scalar other) -> Tensor", + "aten::lt(Tensor self, Tensor other) -> Tensor", + "aten::lt(Tensor self, Scalar other) -> Tensor", + + "aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor", + "aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor", + + "aten::type_as(Tensor self, Tensor other) -> Tensor", + }}; + if (!node->isMemberOf(simple_mappable)) { + return false; + } + for (Value* input : node->inputs()) { + if (input->type()->isSubtypeOf(TensorType::get()) || + input->type()->isSubtypeOf(FloatType::get())) { + continue; + } + if (input->node()->kind() != prim::Constant) { + return false; + } + } + return true; +} + +Value* broadcastSizes(at::ArrayRef sizes) { + AT_ASSERT(!sizes.empty()); + Graph* graph = sizes[0]->owningGraph(); + Node* broadcast_n = + graph->insertNode(graph->create(prim::BroadcastSizes, sizes)); + broadcast_n->output()->setType(ListType::ofInts()); + return broadcast_n->output(); +} + +struct GraphFuser { + using FusionCallback = std::function; + + Block* block_; + std::unique_ptr aliasDb_; + std::shared_ptr graph_; + FusionCallback callback_ = [&](Node* n) { return isFusableDefault(n); }; + Symbol kind_ = prim::CudaFusionGroup; + + // nvrtc has a limit on the number of arguments allowed in a CUDA kernel. + // The specific limit is a function of constant memory size, amount available + // to pass arguments, and some implementation dependence. Select a safe + // limit here. + // This limit is also applied to other devices in the fuser by default. + // Change with setInputArgLimit + size_t subgraph_arg_limit_ = 128; + + GraphFuser(Block* block, std::shared_ptr graph) + : block_(block), graph_(std::move(graph)) {} + + // Custom passes require kind to specified + GraphFuser( + Block* block, + std::shared_ptr graph, + FusionCallback callback, + Symbol kind) + : block_(block), + graph_(std::move(graph)), + callback_(callback), + kind_(kind) {} + + void setInputArgLimit(size_t limit) { + subgraph_arg_limit_ = limit; + } + + value_list tensorInputs(Node* node) { + return filter(node->inputs(), [](Value* v) { + return v->type()->isSubtypeOf(TensorType::get()); + }); + } + + bool isFusable(Node* node) { + return callback_(node); + } + + bool isFusableDevice(Value *v) { + if (!v->type()->isSubtypeOf(TensorType::get())) { + return true; + } + auto device = v->type()->expect()->device(); + if (!device) { + return true; + } + if ((*device).is_cpu()) { + return false; + } else if ((*device).is_cuda()) { + // We want to introduce minimal knobs exposed to user; + // turn on/off CUDA fuser through custom pass registration. + return true; + } + throw std::runtime_error("Unknown device"); + } + + + // Default fusability check - used when the user doesn't pass in + // a callback. + bool isFusableDefault(Node* node) { + bool fusableDevice = true; + for (const auto& output : node->outputs()) { + if (output->uses().size() > 0) { + fusableDevice &= isFusableDevice(output); + } + } + return fusableDevice && isFusableMap(node); + } + + bool isFusableMap(Node* node) { + // We don't want to bother with cross-block node movements, as they + // are not necessarily correct. + if (node->owningBlock() != block_) + return false; + return node->kind() == prim::CudaFusionGroup || isSimpleMap(node); + } + + bool isFusableCatNode(Node* node) { + if (node->kind() != aten::cat) + return false; + if (!node->is_constant(attr::dim)) + return false; + + auto tensors_node = node->namedInput(attr::tensors)->node(); + if ((tensors_node->inputs().size() + node->outputs().size()) > + subgraph_arg_limit_) { + return false; + } + if (tensors_node->kind() != prim::ListConstruct) + return false; + // NB: Note that technically other uses of the list aren't a big problem for + // us. It would be enough to place the prim::FusedConcat before the + // prim::ListConstruct, and allUsersAreThisConsumerOrOccurAfterIt would + // still be satisfied. However, I don't expect this to be necessary any time + // soon, and so we're simply assuming that we don't have to deal with it. + if (tensors_node->output()->uses().size() > 1) + return false; + return true; + } + + bool calculatesSize(Node* node) { + return node->matches("aten::size(Tensor self) -> int[]"); + } + + bool allUsersAreThisConsumerOrCalcSizes(Node* consumer, Value* producer) { + auto defining_node = producer->node(); + for (auto o : defining_node->outputs()) { + for (auto u : o->uses()) { + if (u.user != consumer && !calculatesSize(u.user)) + return false; + } + } + return true; + } + + Graph& getSubgraph(Node* n) { + AT_ASSERT(n->kind() == kind_); + return *n->g(attr::Subgraph); + } + + void mergeCudaFusionGroups(Node* consumer_group, Node* producer_group) { + // Now we have two fusion groups! + // Revert the fusion - place all inner nodes of producer back in the outer + // graph. + std::vector temporary_nodes; + auto producer_subgraph = &getSubgraph(producer_group); + + // Initialize a map of inner graph values to outer graph values + std::unordered_map inner_to_outer; + auto inner_inputs = producer_subgraph->inputs(); + auto outer_inputs = producer_group->inputs(); + for (size_t i = 0; i < inner_inputs.size(); ++i) { + inner_to_outer[inner_inputs[i]] = outer_inputs[i]; + } + + // Clone all nodes + for (auto inner : producer_subgraph->nodes()) { + Node* outer = block_->owningGraph()->createClone( + inner, [&](Value* k) -> Value* { return inner_to_outer.at(k); }); + outer->insertBefore(producer_group); + temporary_nodes.emplace_back(outer); + auto inner_outputs = inner->outputs(); + auto outer_outputs = outer->outputs(); + for (size_t i = 0; i < inner_outputs.size(); ++i) + inner_to_outer[inner_outputs[i]] = outer_outputs[i]; + } + + // Replace uses of producer_group outputs and destroy the producer + auto subgraph_outputs = producer_subgraph->outputs(); + for (size_t i = 0; i < subgraph_outputs.size(); ++i) { + auto outer_output = inner_to_outer.at(subgraph_outputs[i]); + producer_group->outputs()[i]->replaceAllUsesWith(outer_output); + } + producer_group->destroy(); + producer_group = + nullptr; // Just to get a clear error in case someone uses it + + // Inline the temporary nodes into the first group + auto consumer_subgraph = &getSubgraph(consumer_group); + for (auto it = temporary_nodes.rbegin(); it != temporary_nodes.rend(); + ++it) { + Node* node = *it; + Node* merged = mergeNodeIntoGroup(consumer_group, node); + // If any of the outputs are still used then we need to add them + auto outputs = node->outputs(); + for (size_t i = 0; i < outputs.size(); ++i) { + auto output = outputs[i]; + if (output->uses().size() == 0) + continue; + consumer_subgraph->registerOutput(merged->outputs()[i]); + auto new_output = consumer_group->addOutput(); + output->replaceAllUsesWith(new_output); + new_output->setType(output->type()); + } + node->destroy(); + } + } + + // insert a producer node into a consuming fusion group. + // DOES NOT WORK if n is a consumer of an output of the fusion group + // returns the node _inside_ the group that represents the node + Node* mergeNodeIntoGroup(Node* group, Node* n) { + AT_ASSERT(n->kind() != kind_); + auto& subgraph = getSubgraph(group); + // map from nodes in the surrounding graph to parameters in the fusion + // group's subgraph that correspond to them + std::unordered_map inputs_map; + size_t i = 0; + size_t tensor_insert_idx = 0; + AT_ASSERT(group->inputs().size() == subgraph.inputs().size()); + for (auto input : group->inputs()) { + inputs_map[input] = subgraph.inputs()[i++]; + if (input->type()->isSubtypeOf(TensorType::get())) + tensor_insert_idx = i; + } + // add n's inputs to the fusion group's input list if we don't already have + // them + // we insert tensors first because the fuser assumes that to be the case + // (as a legacy from tensors only) + WithInsertPoint guard(*subgraph.nodes().begin()); + for (auto input : n->inputs()) { + if (inputs_map.count(input) == 0) { + if (input->type()->isSubtypeOf(TensorType::get())) { + auto in_group = subgraph.insertInput(tensor_insert_idx); + in_group->setType(input->type()); + inputs_map[input] = in_group; + group->insertInput(tensor_insert_idx, input); + tensor_insert_idx++; + } else if ( + (input->type()->isSubtypeOf(FloatType::get()) && + input->node()->kind() != prim::Constant) || + (n->kind() == aten::_grad_sum_to_size && + input->type()->isSubtypeOf(ListType::ofInts()))) { + auto in_group = subgraph.addInput(); + in_group->setType(input->type()); + inputs_map[input] = in_group; + group->addInput(input); + } else { + // We don't support passing in scalars as arguments to fused kernels, + // so we generally don't allow fusing tensor-scalar operations unless + // the scalar is constant. In those cases we inline the constants + // directly in the body of the fused group. + AT_ASSERT(input->node()->kind() == prim::Constant); + Node* in_const = + subgraph.createClone(input->node(), [](Value*) -> Value* { + throw std::runtime_error("unexpected input"); + }); + subgraph.insertNode(in_const); + inputs_map[input] = in_const->output(); + } + } + } + // copy n into the graph, remapping its inputs to internal nodes + Node* in_graph = subgraph.createClone( + n, [&](Value* k) -> Value* { return inputs_map[k]; }); + // if n's outputs are already inputs to the fusion group, + // we need to remove them because n is now inside the fusion group. + // + // i.e., + // x = f(w); group(x, y, z) becomes group(w, y, z). + // x, y, z = f(w); group(x, y, z) becomes group(w). + // + // remapping nodes that used the input to the newly-merged node + // n is not an input when the fusion group is empty + auto inputs = group->inputs(); + for (size_t i = 0; i < n->outputs().size(); ++i) { + auto it = std::find(inputs.begin(), inputs.end(), n->outputs()[i]); + if (it != inputs.end()) { + size_t p = it - inputs.begin(); + group->removeInput(p); + subgraph.inputs()[p]->replaceAllUsesWith(in_graph->outputs()[i]); + subgraph.eraseInput(p); + } + } + return subgraph.insertNode(in_graph); + } + + // turn consumer node n into a fusion group with just n inside + // to prepare for fusion and replace uses of n with the new group + Node* createSingletonCudaFusionGroup(Node* n) { + auto group = block_->owningGraph()->createWithSubgraph(kind_); + // propogate position information for the new node so we can always + // have a valid mapping + group->insertBefore(n); + Node* mergedNode = mergeNodeIntoGroup(group, n); + getSubgraph(group).registerOutput(mergedNode->output()); + auto sel = group->addOutput(); + sel->copyMetadata(n->output()); + n->replaceAllUsesWith(group); + n->destroy(); + return group; + } + + at::optional tryFuse(Node* consumer, Value* producer) { + // this handles cases where producer can be moved _into_ the fusion group of + // consumer. + // TODO: extend to fusion of consumer into _producer's_ fusion blob + // if the consumer allInputsAreThisProducer(consumer,producer) + // we can move the consumer up into the producer. + // but this requires better handling of merging fusion groups so it is not + // done now + bool shouldFuse = isFusable(producer->node()) && + // Rearrange nodes such that all uses of producer are after the + // consumer. Fusion will rewrite those later uses to use the version of + // producer generated by the fused blob. In this case, producer becomes + // an output of the fusion group. + aliasDb_->moveBeforeTopologicallyValid(producer->node(), consumer); + + if (!shouldFuse) { + return at::nullopt; + } + + if ((consumer->inputs().size() + consumer->outputs().size() + + producer->node()->inputs().size() + + producer->node()->outputs().size()) > subgraph_arg_limit_) { + return at::nullopt; + } + + auto group = consumer; + if (consumer->kind() != kind_) { + group = createSingletonCudaFusionGroup(consumer); + } + + if (producer->node()->kind() == kind_) { + mergeCudaFusionGroups(group, producer->node()); + return group; + } + AT_ASSERT(producer->node()->outputs().size() == 1); + Node* merged = mergeNodeIntoGroup(group, producer->node()); + // remaining uses of this producer can occur because we allow + // fusion in cases where uses remain after the consumer + // if these exist, re-route them to the version of producer + // created in CudaFusionGroup + if (producer->uses().size() != 0) { + getSubgraph(group).registerOutput(merged->output()); + Value* new_producer = group->addOutput(); + new_producer->copyMetadata(producer); + producer->replaceAllUsesWith(new_producer); + } + producer->node()->destroy(); + return group; + } + + bool canFuseChunk(Node* consumer, Value* producer) { + if (consumer->kind() != prim::CudaFusionGroup) { + return false; + } + // Does the chunk have constant chunks/dim? + auto* chunk = producer->node(); + if (chunk->kind() != prim::ConstantChunk) + return false; + // And all uses of the chunk are in this consumer + for (auto s : chunk->outputs()) { + for (auto u : s->uses()) { + if (u.user != consumer) { + return false; + } + } + } + // And isn't a no-op chunk (chunks == 1). Have CSE clean this up. + // We could fuse this but it's better to just delete the node. + if (chunk->i(attr::chunks) == 1) { + return false; + } + return true; + } + + c10::optional findFusedChunk(Node* group, Value* input) { + AT_ASSERT(group->kind() == prim::CudaFusionGroup); + auto it = std::find(group->inputs().begin(), group->inputs().end(), input); + if (it == group->inputs().end()) { + return c10::nullopt; + } + size_t input_index = it - group->inputs().begin(); + auto& subgraph = getSubgraph(group); + auto* subgraph_input = subgraph.inputs().at(input_index); + // If subgraph_input is an input to prim::ConstantChunk, it will have 1 use + auto* node = subgraph_input->uses().at(0).user; + if (node->kind() == prim::ConstantChunk) { + AT_ASSERT(subgraph_input->uses().size() == 1); + return node; + } + return c10::nullopt; + } + + void fuseChunkByReusingExistingFusedChunk( + Node* group, + Node* chunk, + Node* existingFusedChunk) { + if (chunk->outputs().size() != existingFusedChunk->outputs().size()) { + return; + } + auto& subgraph = getSubgraph(group); + for (size_t i = 0; i < chunk->outputs().size(); ++i) { + // Find the input to the CudaFusionGroup (group) + auto* replacement_val = existingFusedChunk->outputs().at(i); + auto* val = chunk->outputs().at(i); + auto it = std::find(group->inputs().begin(), group->inputs().end(), val); + auto input_index = it - group->inputs().begin(); + + // Rewrite the graph to use replacement_val + auto group_input = subgraph.inputs().at(input_index); + group_input->replaceAllUsesWith(replacement_val); + + // Remove the input, it's no longer needed + group->removeInput(input_index); + subgraph.eraseInput(input_index); + } + chunk->destroy(); + } + + // There are two invariants for prim::ConstantChunk: + // (1) the tensor input to prim::ConstantChunk must be an input to the fusion + // group (2) no two ConstantChunks in the same CudaFusionGroup can share a tensor + // input. + graph_node_list::iterator fuseChunk(Node* consumer, Value* producer) { + auto* chunk = producer->node(); + AT_ASSERT(consumer->kind() == prim::CudaFusionGroup); + AT_ASSERT(chunk->kind() == prim::ConstantChunk); + + // if producer's input is already an input to a prim::ConstantChunk node, + // we cannot add a new prim::ConstantChunk node because of invariant (2). + auto* chunked_tensor = producer->node()->input(); + if (auto existingFusedChunk = findFusedChunk(consumer, chunked_tensor)) { + fuseChunkByReusingExistingFusedChunk( + consumer, chunk, *existingFusedChunk); + return consumer->reverseIterator(); + } + + // Move prim::ConstantChunk into the CudaFusionGroup + mergeNodeIntoGroup(consumer, chunk); + chunk->destroy(); + return consumer->reverseIterator(); + } + + value_list sortReverseTopological(ArrayRef inputs) { + value_list result; + for (auto i : inputs) { + if (i->node()->owningBlock() == block_) { + result.push_back(i); + } + } + // Sort in reverse topological order + std::sort(result.begin(), result.end(), [&](Value* a, Value* b) { + return a->node()->isAfter(b->node()); + }); + return result; + } + + graph_node_list::iterator scanNodeForChunks(Node* consumer) { + if (consumer->kind() == prim::CudaFusionGroup) { + auto inputs = sortReverseTopological(consumer->inputs()); + for (auto producer : inputs) { + if (!canFuseChunk(consumer, producer)) { + continue; + } + return fuseChunk(consumer, producer); + } + } + return ++consumer->reverseIterator(); + } + + at::ArrayRef broadcast_tensors(value_list inputs) { + AT_ASSERT(inputs.size() > 0); + auto* g = inputs[0]->owningGraph(); + auto* input_list = + g->insertNode(g->createList(TensorType::get(), inputs))->output(); + auto* output_list = g->insert(aten::broadcast_tensors, {input_list}); + auto* unpack_node = g->insertNode( + g->create(prim::ListUnpack, {output_list}, inputs.size())); + return unpack_node->outputs(); + } + + void insertExplicitBroadcast(Node* node) { + WithInsertPoint insert_guard{node}; + auto tensors = tensorInputs(node); + auto new_tensors = broadcast_tensors(tensors); + + // Replace tensors inputs with broadcasted values + auto new_tensors_it = new_tensors.begin(); + for (size_t i = 0; i < node->inputs().size(); ++i) { + if (node->inputs()[i]->type()->isSubtypeOf(TensorType::get())) { + AT_ASSERT(new_tensors_it != new_tensors.end()); + node->replaceInput(i, *(new_tensors_it++)); + } + } + } + + Node* promoteChunkToBroadcastingChunk(Node* chunk) { + AT_ASSERT(chunk->kind() == prim::ConstantChunk); + + size_t nchunks = chunk->i(attr::chunks); + Node* bchunk = + chunk->owningGraph()->create(prim::BroadcastingChunk, nchunks); + bchunk->addInput(chunk->input()); + for (size_t i = 0; i < nchunks; ++i) { + auto* old_output = chunk->outputs().at(i); + auto* new_output = bchunk->outputs().at(i); + new_output->copyMetadata(old_output); + old_output->replaceAllUsesWith(new_output); + } + bchunk->copyAttributes(*chunk); + bchunk->insertAfter(chunk); + chunk->destroy(); + return bchunk; + } + + // in places where op can be fused into a consumer but chunk is in the way + // distribute chunk to op's operands: + // replace a,b = chunk(op(x,y,z)) with: + // x', y', z' = broadcast_tensors([x, y, z]) + // x0,x1 = chunk(x') (x0 has a's type, x1 has b's type) + // y0,y1 = chunk(y') (y0 has a's type, y1 has b's type) + // z0,z1 = chunk(z') (z0 has a's type, z1 has b's type) + // a = op(x0,y0,z0) (a,b have their same size but are now contiguous) + // b = op(x1,y1,x1) + // + // The graph fuser uses an intermediate prim::BroadcastingChunk node to + // represent this behavior concisely. BroadcastingChunk(x, y, z) broadcasts + // all of its inputs and then chunks each input, in order, the same way. + // The above graph is equivalent to: + // x0, x1, y0, y1, z0, z1 = BroadcastingChunk(x, y, z) + // a = op(x0,y0,z0) + // b = op(x1,y1,x1) + // + // NB: The explicit broadcast is important for correctness. + // Let's say we have: + // %z = aten::mul(%x, %y) + // %z.1, %z.2 = aten::chunk(%z, ...) + // ... = prim::CudaFusionGroup(%z.1, %z.2, ...) + // It's possible that %x and %y do not have the same size as %z and + // need to be expanded first so that they can be chunked like %z + // + // NB: Chunk motion only occurs with fusable consumers, which implies + // that there is always some other operation, e.g., a+b, that happens + // after the chunk, and will be put into the fusion group. This is + // important, because distributing the chunk changes the contiguity + // of a and b, and so the results would be invalid, except that we know + // that simple_mappable operations will restore contiguity before + // we exit the fusion group. + // + // NB: The intermediate BroadcastingChunk is important for moving chunks past + // more than one operation: the graph fuser is not able to easily move + // operations around broadcast_tensors + chunk nodes. Let f, g, h be fusible + // ops + // x = f(v, w) + // z = g(x, y) + // a, b = chunk(z) + // c = h(a, b) + // becomes (with the broadcast_tensors + chunk approach): + // x = f(v, w) + // x', y' = broadcast_tensors([x, y]) + // ax, bx = chunk(x') + // ay, by = chunk(y') + // a = g(ax, ay) + // b = g(bx, by) + // c = h(a, b) + // The broadcast_tensors node makes it harder to move f into the resulting + // CudaFusionGroup of g, g, and h. Keeping the broadcasting and chunk behavior + // together results in: + // x = f(v, w) + // ax, bx, ay, by = BroadcastingChunk(x, y) + // a = g(ax, ay) + // b = g(bx, by) + // c = h(a, b) + // making it easier to move f after the BroadcastingChunk: + // ay, by, av, bv, aw, bw = BroadcastingChunk(y, v, w) + // ax = f(av, aw) + // by = f(bv, bw) + // a = g(ax, ay) + // b = g(bx, by) + // c = h(a, b) + + bool tryToMoveChunk(Node* consumer, Value* producer) { + // is the output from a chunk/bchunk node? + auto* chunk = producer->node(); + if (chunk->kind() != prim::ConstantChunk && + chunk->kind() != prim::BroadcastingChunk) + return false; + + // try to find a producer to move after the chunk/bchunk. The producer must + // be fusible into the consumer. + auto it = std::find_if( + chunk->inputs().begin(), + chunk->inputs().end(), + [&](Value* producer_for_chunk) { + return isFusableMap(producer_for_chunk->node()) && + allUsersAreThisConsumerOrCalcSizes(chunk, producer_for_chunk); + }); + if (it == chunk->inputs().end()) { + return false; + } + Value* producer_for_chunk = *it; + size_t producer_index = it - chunk->inputs().begin(); + + // all uses of the chunk must be in in this consumer + for (auto s : chunk->outputs()) { + for (auto u : s->uses()) { + if (u.user != consumer) + return false; + } + } + // multiple return operators + Node* producer_for_chunk_node = producer_for_chunk->node(); + AT_ASSERT(producer_for_chunk_node->outputs().size() == 1); + + // Convert chunk to bchunk, if it isn't one already. The bchunk represents a + // broadcast and one or more chunk operations. + auto* bchunk = chunk; + if (chunk->kind() == prim::ConstantChunk) { + bchunk = promoteChunkToBroadcastingChunk(chunk); + } + size_t nchunks = bchunk->i(attr::chunks); + WithInsertPoint guard(bchunk->next()); + + std::vector producer_chunk_outputs; + for (size_t i = 0; i < nchunks; i++) { + producer_chunk_outputs.push_back( + bchunk->output(nchunks * producer_index + i)); + } + + // Add each of op's operands to the bchunk node. + // chunked_inputs[input_nr][chunk_output_idx] + // = Node* for chunk_output_idx'th output of the chunk(inputs[input_nr]) + std::vector> chunked_inputs; + + for (auto input : producer_for_chunk_node->inputs()) { + // XXX: we only work with pointwise ops in here, so we know it is valid to + // push the concat only through tensor arguments (and all other args can + // be safely ignored). + if (!input->type()->isSubtypeOf(TensorType::get())) + continue; + + // if 'input' is already an input to the bchunk, reuse it. + auto bchunk_inputs = bchunk->inputs(); + auto it = std::find(bchunk_inputs.begin(), bchunk_inputs.end(), input); + if (it != bchunk_inputs.end()) { + chunked_inputs.emplace_back(); + auto input_index = std::distance(bchunk_inputs.begin(), it); + for (size_t chunk = 0; chunk < nchunks; ++chunk) { + chunked_inputs.back().push_back( + bchunk->outputs().at(nchunks * input_index + chunk)); + } + continue; + } + + // NB: I decided not to use cloneFrom here, because if we make cloneFrom + // copy selects one day, it is definitely not what you want here (selects + // have different types). + // TODO: Perhaps we should use cloneFrom now, as it seems unlikely + // to copy select nodes now that we have refactored to have a Value + // distinct from Node. + bchunk->addInput(input); + chunked_inputs.emplace_back(); // alas, to not be C++17 + for (auto chunk_sel : producer_chunk_outputs) { + Value* input_chunk_sel = bchunk->addOutput(); + input_chunk_sel->setType(chunk_sel->type()); + chunked_inputs.back().push_back(input_chunk_sel); + } + } + + // apply the op to each chunk of the chunked operands, + // and then rewrite the graph to use them! + for (auto chunk_sel : producer_chunk_outputs) { + auto original_inputs = producer_for_chunk_node->inputs(); + Node* chunked_op = + block_->owningGraph()->create(producer_for_chunk_node->kind()); + chunked_op->copyAttributes(*producer_for_chunk_node); + chunked_op->output()->setType(chunk_sel->type()); + auto chunked_inputs_it = chunked_inputs.begin(); + for (Value* original_input : original_inputs) { + if (original_input->type()->isSubtypeOf(TensorType::get())) { + AT_ASSERT(chunked_inputs_it != chunked_inputs.end()); + chunked_op->addInput( + chunked_inputs_it->at(chunk_sel->offset() % nchunks)); + ++chunked_inputs_it; + } else { + chunked_op->addInput(original_input); + } + } + bchunk->owningGraph()->insertNode(chunked_op); + chunk_sel->replaceAllUsesWith(chunked_op->output()); + } + + bchunk->removeInput(producer_index); + for (size_t i = 0; i < nchunks; i++) { + bchunk->eraseOutput(nchunks * producer_index); + } + + // The output of producer_for_chunk_node could have been used in some + // aten::size operators, so we need to clean those up as well (we simply + // broadcast all its tensor inputs). + // We need to insert these early in the graph, i.e. immediately after + // the producer_for_chunk_node as we will have the _size_if_not_same + // that may be before the bchunk. + WithInsertPoint guard2(producer_for_chunk_node); + auto size_calc_uses = producer_for_chunk_node->output()->uses(); + if (!size_calc_uses.empty()) { + auto tensor_inputs = filter( + producer_for_chunk_node->inputs(), + [](Value* v) { return v->type()->isSubtypeOf(TensorType::get()); }); + auto tensor_sizes = fmap(tensor_inputs, [](Value* v) { + return v->owningGraph()->insert(aten::size, {v}); + }); + AT_ASSERT(!tensor_sizes.empty()); + Value* output_size = tensor_sizes.size() == 1 + ? tensor_sizes[0] + : broadcastSizes(tensor_sizes); + for (Use u : size_calc_uses) { + u.user->output()->replaceAllUsesWith(output_size); + u.user->destroy(); + } + } + producer_for_chunk_node->destroy(); + return true; + } + + // returns where to continue scanning, and whether any fusion was made + std::pair scanNode(Node* consumer) { + if (isFusable(consumer)) { + // handle inputs in reverse topological order as well... + // otherwise in f(a,a+b) it will appear a is used twice if we consider + // the f-a fusion before the f-(a+b) fusion first. + auto inputs = sortReverseTopological(consumer->inputs()); + for (auto producer : inputs) { + if (tryToMoveChunk(consumer, producer)) { + // the chunk before this consumer was re-arranged to allow fusion, + // we scan this consumer again to perform the fusion + return std::make_pair(consumer->reverseIterator(), true); + } + auto fusion_group = tryFuse(consumer, producer); + if (fusion_group) { + // after fusion, consumer moves into a CudaFusionGroup, so inputs is no + // longer valid so we rescan the new CudaFusionGroup for more fusions... + return std::make_pair(fusion_group.value()->reverseIterator(), true); + } + } + } + return std::make_pair(++consumer->reverseIterator(), false); + } + + void replaceIntermediateBroadcastingChunks() { + for (auto it = block_->nodes().rbegin(); it != block_->nodes().rend();) { + auto* node = *it; + ++it; // We might delete node, so increment the iterator now. + if (node->kind() != prim::BroadcastingChunk) { + continue; + } + auto* bchunk = node; + insertExplicitBroadcast(bchunk); + + auto* graph = block_->owningGraph(); + size_t nchunks = bchunk->i(attr::chunks); + WithInsertPoint guard(bchunk->next()); + + // Split the bchunk into bchunks.inputs().size() number of chunk nodes. + for (size_t input_offset = 0; input_offset < bchunk->inputs().size(); + input_offset++) { + auto* input = bchunk->inputs().at(input_offset); + + Node* new_chunk = + graph->insertNode(graph->create(prim::ConstantChunk, input, 0)); + new_chunk->copyAttributes(*bchunk); + for (size_t output_offset = 0; output_offset < nchunks; + output_offset++) { + auto new_output = new_chunk->addOutput(); + auto old_output = + bchunk->outputs().at(input_offset * nchunks + output_offset); + new_output->copyMetadata(old_output); + old_output->replaceAllUsesWith(new_output); + } + } + bchunk->destroy(); + } + } + + bool usedOnlyInSize(Value* v) { + const auto& uses = v->uses(); + return std::all_of(uses.begin(), uses.end(), [](const Use& u) { + return u.user->matches("aten::size(Tensor self) -> int[]"); + }); + } + + // Builds up expressions that compute shapes of all intermediates (and + // outputs) of the fusion group, based on the sizes of inputs. You should run + // DCE to remove those that you end up not using. + std::unordered_map buildShapeExpressions(Node* fusion_group) { + WithInsertPoint insert_guard{fusion_group->next()}; + std::unordered_map shape_of; + + Graph* graph = fusion_group->owningGraph(); + auto subgraph = fusion_group->g(attr::Subgraph); + + auto inputs = fusion_group->inputs(); + auto sinputs = subgraph->inputs(); + AT_ASSERT(inputs.size() == sinputs.size()); + for (size_t i = 0; i < inputs.size(); ++i) { + if (inputs[i]->type()->isSubtypeOf(TensorType::get())) { + shape_of[sinputs[i]] = graph->insert(aten::size, {inputs[i]}); + } + } + + // When we have a guarantee that an output won't be removed, because it's + // used in expressions that don't involve size checks, we can use its size + // instead of computing a long chain of broadcasts, starting from the + // beginning of the kernel. + auto outputs = fusion_group->outputs(); + auto soutputs = subgraph->outputs(); + AT_ASSERT(outputs.size() == soutputs.size()); + for (size_t i = 0; i < outputs.size(); ++i) { + if (usedOnlyInSize(outputs[i])) + continue; + shape_of[soutputs[i]] = graph->insert(aten::size, {outputs[i]}); + } + + for (Node* n : subgraph->nodes()) { + // XXX: Use of shape_of.emplace is crucial to the output shape + // optimization! + if (n->kind() == prim::FusedConcat) { + // This is a bit more involved, because we have to account for the case + // when inputs have different shapes, but fortunately those tensors are + // always outputs, and so we can simply avoid replacing their queries, + // because it won't help us. + continue; + } + if (n->kind() == prim::Constant) { + continue; + } + if (n->kind() == prim::ConstantChunk) { + Node* sizes_node = graph->insertNode( + graph->create(prim::ChunkSizes, shape_of.at(n->input()), 2)); + sizes_node->i_(attr::dim, n->i(attr::dim)); + sizes_node->i_(attr::chunks, n->i(attr::chunks)); + Value* regular_size = sizes_node->outputs().at(0); + Value* last_size = sizes_node->outputs().at(1); + regular_size->setType(ListType::ofInts()); + last_size->setType(ListType::ofInts()); + auto outputs = n->outputs(); + for (Value* o : outputs.slice(0, outputs.size() - 1)) { + shape_of.emplace(o, regular_size); + } + shape_of.emplace(outputs.at(outputs.size() - 1), last_size); + continue; + } + auto tensor_inputs = filter(n->inputs(), [](Value* v) { + return v->type()->isSubtypeOf(TensorType::get()); + }); + auto shapes = + fmap(tensor_inputs, [&](Value* v) { return shape_of.at(v); }); + AT_ASSERT(!shapes.empty()); + shape_of.emplace( + n->output(), shapes.size() == 1 ? shapes[0] : broadcastSizes(shapes)); + } + return shape_of; + } + + void removeOutputsUsedOnlyInSize(Node* fusion_group) { + if (fusion_group->kind() != prim::CudaFusionGroup) + return; + auto subgraph = fusion_group->g(attr::Subgraph); + + auto shape_of = buildShapeExpressions(fusion_group); + auto outputs = fusion_group->outputs().vec(); + auto soutputs = subgraph->outputs().vec(); + // XXX: Iterating in this order is not only good for performance reasons! + // It is also crucial for correctness (i has to reflect the current true + // index of outputs[i])! + for (int64_t i = static_cast(outputs.size()) - 1; i >= 0; --i) { + auto output = outputs[i]; + auto soutput = soutputs[i]; + if (usedOnlyInSize(output) && shape_of.count(soutput) > 0) { + auto uses = output->uses(); + for (Use u : uses) { + AT_ASSERT(u.user->matches("aten::size(Tensor self) -> int[]")); + u.user->output()->replaceAllUsesWith(shape_of.at(soutput)); + u.user->destroy(); + } + fusion_group->eraseOutput(i); + subgraph->eraseOutput(i); + } + } + } + + void refreshAliasDb() { + aliasDb_ = torch::make_unique(graph_); + } + + bool canFuseWithConcat(Value* producer, Node* before_check) { + if (!isFusable(producer->node())) { + return false; + } + // NB: it is important that this check happens after isFusable, which checks + // that the blocks match, and it's not a special node like prim::Param + if (!aliasDb_->couldMoveBeforeTopologically( + producer->node(), before_check)) { + return false; + } + + // If the number of kernel args could exceed the limit, skip. + if ((before_check->inputs().size() + before_check->outputs().size() + + producer->node()->inputs().size() + + producer->node()->outputs().size()) > subgraph_arg_limit_) { + return false; + } + + // Fusion groups can be merged with concat's group if and only if + // the value they produce isn't already coming from a concat + if (producer->node()->kind() == prim::CudaFusionGroup) { + auto subgraph = producer->node()->g(attr::Subgraph); + auto* node = subgraph->outputs().at(producer->offset())->node(); + return node->kind() != prim::FusedConcat; + } + return true; + } + + Node* createFusedConcat(Node* node) { + AT_ASSERT(node->kind() == aten::cat); + + Graph* graph = node->owningGraph(); + Node* list_construct = node->namedInput(attr::tensors)->node(); + int64_t dim = node->get(attr::dim).value(); + + Node* fused_cat = graph->create(prim::FusedConcat, list_construct->inputs()) + ->i_(attr::dim, dim); + fused_cat->insertBefore(list_construct); + fused_cat->output()->copyMetadata(node->output()); + + // NB: this deletes the fused_cat node from the original graph + return createSingletonCudaFusionGroup(fused_cat); + } + + void fuseConcats() { + for (auto it = block_->nodes().rbegin(); it != block_->nodes().rend(); + ++it) { + Node* cat = *it; + if (!isFusableCatNode(cat)) { + continue; + } + Node* list_construct = cat->namedInput(attr::tensors)->node(); + Node* fused_cat = createFusedConcat(cat); + Value* fused_cat_out = fused_cat->output(); + + auto sorted_inputs = sortReverseTopological(fused_cat->inputs()); + size_t input_idx = 0; + bool any_fused = false; + while (input_idx < sorted_inputs.size()) { + Value* input = sorted_inputs[input_idx++]; + if (!canFuseWithConcat(input, fused_cat)) { + continue; + } + any_fused = true; + auto maybe_group = tryFuse(fused_cat, input); + AT_ASSERT(maybe_group && maybe_group == fused_cat); + // We could have destroyed multiple inputs when performing this fusion, + // so we have to recompute the list and iterate over it again. + sorted_inputs = sortReverseTopological(fused_cat->inputs()); + input_idx = 0; + } + + if (any_fused) { + cat->output()->replaceAllUsesWith(fused_cat_out); + it.destroyCurrent(); + if (list_construct->output()->uses().empty()) { + list_construct->destroy(); + } + } else { + fused_cat->destroy(); + } + } + } + + void optimizeFusedGraphs() { + for (Node* node : block_->nodes()) { + if (node->kind() != prim::CudaFusionGroup) { + continue; + } + auto subgraph = node->g(attr::Subgraph); + EliminateDeadCode(subgraph); + EliminateCommonSubexpression(subgraph); + ConstantPooling(subgraph); + } + } + + void run() { + // Run the pass until no changes are made. + // This is necessary, because the algorithm can miss out on certain fusion + // opportunities if ran only once. Consider this graph: + // + // %1 = f(...) + // %2 = g(%1) + // %3 = h(%1) + // %4 = l(%3) + // return (%4, %2) + // + // where f, g, h, l are simple map ops. + // The first iteration will fuse %4 and %3, and see that %1 is an input, but + // can't be fused, because it has a different use before the fusion group + // in our topological ordering. Then, %2 will be considered, and fused with + // %1. If we do another iteration, the algorithm will consider the fusion of + // these two groups and fix the situation. + bool any_changed = true; + while (any_changed) { + any_changed = false; + refreshAliasDb(); + for (auto it = block_->nodes().rbegin(); it != block_->nodes().rend();) { + bool changed; + std::tie(it, changed) = scanNode(*it); + any_changed |= changed; + } + } + refreshAliasDb(); + + fuseConcats(); + + optimizeFusedGraphs(); + + // The graph fuser can add intermediate prim::BroadcastingChunk nodes. + // Replace them with broadcasts + chunks. + replaceIntermediateBroadcastingChunks(); + + // Fuse starting chunks into the group. + for (auto it = block_->nodes().rbegin(); it != block_->nodes().rend();) { + it = scanNodeForChunks(*it); + } + + // Remove outputs that have been added only because we need their size + for (Node* n : block_->nodes()) { + removeOutputsUsedOnlyInSize(n); + } + + for (Node* node : block_->nodes()) { + for (Block* sub_block : node->blocks()) { + GraphFuser(sub_block, graph_, callback_, kind_).run(); + } + } + } +}; + +void PeepholeOptimizeShapeExpressions(Block* block) { + auto nodes = block->nodes(); + for (auto it = nodes.begin(); it != nodes.end(); ++it) { + Node* node = *it; + for (Block* subblock : node->blocks()) { + PeepholeOptimizeShapeExpressions(subblock); + } + if (node->kind() == prim::BroadcastSizes) { + // Remove no-op broadcasts. + if (node->inputs().size() == 1) { + node->output()->replaceAllUsesWith(node->input()); + it.destroyCurrent(); + continue; + } + // Deduplicate inputs, but use their unique() values to ensure + // this process only depends on the graph. + std::map unique_to_value; + for (Value* input : node->inputs()) { + unique_to_value.emplace(input->unique(), input); + } + if (unique_to_value.size() != node->inputs().size()) { + std::vector inputs; + inputs.reserve(unique_to_value.size()); + for (auto& entry : unique_to_value) { + inputs.push_back(entry.second); + } + if (inputs.size() == 1) { + node->output()->replaceAllUsesWith(inputs[0]); + } else { + WithInsertPoint insert_guard{node}; + node->output()->replaceAllUsesWith(broadcastSizes(inputs)); + } + it.destroyCurrent(); + --it; // Revisit the node with deduplicated inputs + continue; + } + // Remove compose simple chains of broadcasts into a single node. + const auto& uses = node->output()->uses(); + if (uses.size() == 1 && uses[0].user->kind() == prim::BroadcastSizes) { + Node* user = uses[0].user; + user->removeInput(uses[0].offset); + // NB: we don't care about deduplication in here, as we will visit user + // later. + for (Value* i : node->inputs()) { + user->addInput(i); + } + it.destroyCurrent(); + } + } + } +} + +} // anonymous namespace + +void CudaFuseGraph(std::shared_ptr& graph) { + GraphFuser(graph->block(), graph).run(); + // After FuseGraph some common subexpressions may come back + EliminateCommonSubexpression(graph); + // We might have emitted a fair amount of useless shape propagating code, so + // remove it + EliminateDeadCode(graph); + // Improve the quality of shape propagation code that was left + PeepholeOptimizeShapeExpressions(graph->block()); +} + +void registerCudaFuseGraph() { + static bool not_registered = true; + if (not_registered) { + RegisterPostFusionPass pass(CudaFuseGraph); + not_registered = false; + } +} + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/passes/cuda_graph_fuser.h b/torch/csrc/jit/passes/cuda_graph_fuser.h new file mode 100644 index 000000000000..54acf2a6cb3b --- /dev/null +++ b/torch/csrc/jit/passes/cuda_graph_fuser.h @@ -0,0 +1,17 @@ +#pragma once + +#include + +namespace torch { +namespace jit { + +// Register CudaFuseGraph in custom passes +TORCH_API void registerCudaFuseGraph(); + +// NB: Be sure to run DCE before fusion, because dead instructions +// can prevent fusion opportunities from being exploited. +// On Windows will noop, NYI +TORCH_API void CudaFuseGraph(std::shared_ptr& graph); + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/passes/inline_autodiff_subgraphs.cpp b/torch/csrc/jit/passes/inline_autodiff_subgraphs.cpp index 51ce49a5efd1..a142355bc520 100644 --- a/torch/csrc/jit/passes/inline_autodiff_subgraphs.cpp +++ b/torch/csrc/jit/passes/inline_autodiff_subgraphs.cpp @@ -12,7 +12,8 @@ namespace jit { // Autograd-aware bool canRunWithAutograd(Node* node) { auto kind = node->kind(); - return kind != prim::FusionGroup && (kind.is_aten() || kind.is_prim()); + return kind != prim::FusionGroup && kind != prim::CudaFusionGroup && + (kind.is_aten() || kind.is_prim()); } namespace { diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index 0d6daa272bd5..4c52e9bcdaab 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -349,6 +350,8 @@ void initJITBindings(PyObject* module) { auto stack = toTraceableStack(args); checkAliasAnnotation(g, std::move(stack), unqualified_op_name); }) + .def( + "_jit_register_cuda_fuser", ®isterCudaFuseGraph) .def( "_jit_set_profiling_mode", [](bool profiling_flag) { diff --git a/torch/csrc/jit/python/python_ir.cpp b/torch/csrc/jit/python/python_ir.cpp index 81f7a53122d8..2eba95b1084b 100644 --- a/torch/csrc/jit/python/python_ir.cpp +++ b/torch/csrc/jit/python/python_ir.cpp @@ -376,6 +376,9 @@ void initPythonIRBindings(PyObject* module_) { .def( "createFusionGroup", [](Graph& g) { return g.createWithSubgraph(prim::FusionGroup); }) + .def( + "createCudaFusionGroup", + [](Graph& g) { return g.createWithSubgraph(prim::CudaFusionGroup); }) .def( "createClone", [](Graph& g, Node* n, py::object fn) { diff --git a/torch/csrc/jit/runtime/operator.cpp b/torch/csrc/jit/runtime/operator.cpp index 1d9de68e7111..9fef15afb897 100644 --- a/torch/csrc/jit/runtime/operator.cpp +++ b/torch/csrc/jit/runtime/operator.cpp @@ -178,6 +178,7 @@ bool printerHasSpecialCaseFor(Symbol sym) { prim::Drop, // used in interpreter only prim::FusedConcat, // optimization pass adds it prim::FusionGroup, // optimization pass adds it + prim::CudaFusionGroup, // optimization pass adds it prim::Load, // used in interpreter only prim::MMTreeReduce, // used as an optimization prim::MMBatchSide, // used as an optimization @@ -208,6 +209,7 @@ bool aliasAnalysisHasSpecialCaseFor(Symbol symbol) { prim::If, prim::Loop, prim::FusionGroup, + prim::CudaFusionGroup, prim::DifferentiableGraph, prim::Constant, prim::Uninitialized, diff --git a/torch/csrc/jit/runtime/register_prim_ops.cpp b/torch/csrc/jit/runtime/register_prim_ops.cpp index aed269e29c7d..7513cbf28e28 100644 --- a/torch/csrc/jit/runtime/register_prim_ops.cpp +++ b/torch/csrc/jit/runtime/register_prim_ops.cpp @@ -299,6 +299,17 @@ RegisterOperators reg( }; }, aliasAnalysisSpecialCase()), + Operator( + prim::CudaFusionGroup, + [](const Node* node) -> Operation { + const auto key = registerFusion(node); + return [key](Stack& stack) { + RECORD_FUNCTION("CudaFusionGroup", std::vector()); + runFusion(key, stack); + return 0; + }; + }, + aliasAnalysisSpecialCase()), Operator( prim::FusionGroup, [](const Node* node) -> Operation {