[CUDA_FUSER] Fork CUDA fuser (#33527)

Summary:
Separating CUDA fuser from CPU fuser.

1. New node in IR - prim::CudaFusionGroup:
   This enables the CUDA fuser to coexist alongside the old fuser and allows us
   to build and expand the CUDA fuser incrementally.

2. Copied the FuseGraph optimization pass to CudaFuseGraph:
   We will refactor and reuse the Chunk/Concat logic from the old fuser, which
   is currently handled in the optimization pass. Unfortunately, much of the
   code in that pass is tightly bound to the legacy fuser, which makes code
   sharing difficult.
   CudaFuseGraph will support only a subset of operations compared to the
   legacy fuser (CUDA only). It is registered as a custom pass that runs after
   fusion via
     ```torch._C._jit_register_cuda_fuser()```
   For it to take effect, you should also turn off fusion on the GPU via
     ```torch._C._jit_override_can_fuse_on_gpu(False)```
   (see the usage sketch after this list).

3. We don't have codegen in this PR yet (WIP). Currently we just fall back to
   the old fuser.
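
A minimal usage sketch of the two calls above (assuming a CUDA build with this
patch applied; since codegen is not in this PR, execution still falls back to
the legacy fuser):

```python
import torch

# Register CudaFuseGraph as a custom post-fusion pass (binding added in this PR).
torch._C._jit_register_cuda_fuser()
# Turn off the legacy GPU fuser so prim::CudaFusionGroup nodes are produced
# instead of prim::FusionGroup.
torch._C._jit_override_can_fuse_on_gpu(False)

@torch.jit.script
def add_relu(x, y):
    return torch.relu(x + y)

if torch.cuda.is_available():
    x = torch.randn(8, 8, device="cuda")
    y = torch.randn(8, 8, device="cuda")
    add_relu(x, y)  # warm-up run so the optimization passes get a chance to run
    # The optimized graph should now contain a prim::CudaFusionGroup node.
    print(add_relu.graph_for(x, y))
```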
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33527

Differential Revision: D20171598

Pulled By: ZolotukhinM

fbshipit-source-id: 9a3c0f06f46da7eaa80ae7551c04869f5b03ef71
Authored by Jie on 2020-03-04 20:21:24 -08:00, committed by Facebook Github Bot.
commit 2b79bab029, parent e132047f1b
12 changed files with 1274 additions and 1 deletions

@@ -33,6 +33,7 @@ namespace c10 {
_(prim, Eval) \
_(prim, Expand) /* onnx */ \
_(prim, FusionGroup) \
_(prim, CudaFusionGroup) \
_(prim, DifferentiableGraph) \
_(prim, If) \
_(prim, Jump) /* debug */ \

@@ -411,6 +411,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
${TORCH_SRC_DIR}/csrc/jit/passes/fixup_trace_scope_blocks.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/inline_fork_wait.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/graph_fuser.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/cuda_graph_fuser.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/guard_elimination.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/inplace_check.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/liveness.cpp

@@ -121,6 +121,7 @@ libtorch_sources = [
"torch/csrc/jit/passes/erase_number_types.cpp",
"torch/csrc/jit/passes/fixup_trace_scope_blocks.cpp",
"torch/csrc/jit/passes/graph_fuser.cpp",
"torch/csrc/jit/passes/cuda_graph_fuser.cpp",
"torch/csrc/jit/passes/guard_elimination.cpp",
"torch/csrc/jit/passes/inline_autodiff_subgraphs.cpp",
"torch/csrc/jit/passes/inliner.cpp",

@@ -326,6 +326,7 @@ void AliasDb::analyzeImpl(Node* node) {
case prim::Loop:
return analyzeLoop(node);
case prim::FusionGroup:
case prim::CudaFusionGroup:
case prim::DifferentiableGraph:
return analyzeSubgraph(node);
case prim::fork:

@@ -430,6 +430,7 @@ void Node::lint() const {
// longer.
break;
case prim::FusionGroup:
case prim::CudaFusionGroup:
checkSameDevice(this);
// TODO: Typecheck the parameters
g(attr::Subgraph)->lint();

File diff suppressed because it is too large

@@ -0,0 +1,17 @@
#pragma once
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {
// Register CudaFuseGraph in custom passes
TORCH_API void registerCudaFuseGraph();
// NB: Be sure to run DCE before fusion, because dead instructions
// can prevent fusion opportunities from being exploited.
// On Windows will noop, NYI
TORCH_API void CudaFuseGraph(std::shared_ptr<Graph>& graph);
} // namespace jit
} // namespace torch

@@ -12,7 +12,8 @@ namespace jit {
// Autograd-aware
bool canRunWithAutograd(Node* node) {
auto kind = node->kind();
return kind != prim::FusionGroup && (kind.is_aten() || kind.is_prim());
return kind != prim::FusionGroup && kind != prim::CudaFusionGroup &&
(kind.is_aten() || kind.is_prim());
}
namespace {

@@ -20,6 +20,7 @@
#include <torch/csrc/jit/passes/erase_number_types.h>
#include <torch/csrc/jit/passes/fuse_linear.h>
#include <torch/csrc/jit/passes/graph_fuser.h>
#include <torch/csrc/jit/passes/cuda_graph_fuser.h>
#include <torch/csrc/jit/passes/inline_fork_wait.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/passes/loop_unrolling.h>
@@ -349,6 +350,8 @@ void initJITBindings(PyObject* module) {
auto stack = toTraceableStack(args);
checkAliasAnnotation(g, std::move(stack), unqualified_op_name);
})
.def(
"_jit_register_cuda_fuser", &registerCudaFuseGraph)
.def(
"_jit_set_profiling_mode",
[](bool profiling_flag) {

@@ -376,6 +376,9 @@ void initPythonIRBindings(PyObject* module_) {
.def(
"createFusionGroup",
[](Graph& g) { return g.createWithSubgraph(prim::FusionGroup); })
.def(
"createCudaFusionGroup",
[](Graph& g) { return g.createWithSubgraph(prim::CudaFusionGroup); })
.def(
"createClone",
[](Graph& g, Node* n, py::object fn) {
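
For completeness, a small Python sketch (hypothetical test usage) of the new
binding above; like createFusionGroup, it creates a detached node that owns an
empty subgraph and is not yet inserted into the graph:

```python
import torch

@torch.jit.script
def f(x):
    return x + 1

g = f.graph
# Create a prim::CudaFusionGroup node with an empty subgraph; it still has to
# be inserted into the graph and populated by a fusion pass to do anything.
node = g.createCudaFusionGroup()
print(node.kind())  # prim::CudaFusionGroup
```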

@@ -178,6 +178,7 @@ bool printerHasSpecialCaseFor(Symbol sym) {
prim::Drop, // used in interpreter only
prim::FusedConcat, // optimization pass adds it
prim::FusionGroup, // optimization pass adds it
prim::CudaFusionGroup, // optimization pass adds it
prim::Load, // used in interpreter only
prim::MMTreeReduce, // used as an optimization
prim::MMBatchSide, // used as an optimization
@@ -208,6 +209,7 @@ bool aliasAnalysisHasSpecialCaseFor(Symbol symbol) {
prim::If,
prim::Loop,
prim::FusionGroup,
prim::CudaFusionGroup,
prim::DifferentiableGraph,
prim::Constant,
prim::Uninitialized,

@@ -299,6 +299,17 @@ RegisterOperators reg(
};
},
aliasAnalysisSpecialCase()),
Operator(
prim::CudaFusionGroup,
[](const Node* node) -> Operation {
const auto key = registerFusion(node);
return [key](Stack& stack) {
RECORD_FUNCTION("CudaFusionGroup", std::vector<c10::IValue>());
runFusion(key, stack);
return 0;
};
},
aliasAnalysisSpecialCase()),
Operator(
prim::FusionGroup,
[](const Node* node) -> Operation {