Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00.
[CUDA_FUSER] Fork CUDA fuser (#33527)
Summary: Separating the CUDA fuser from the CPU fuser.

1. New node in the IR - prim::CudaFusionGroup: this enables the CUDA fuser to coexist alongside the old fuser and allows us to build and expand the CUDA fuser incrementally.

2. Copied the FuseGraph optimization pass to CudaFuseGraph: we will refactor and reuse the Chunk/Concat logic from the old fuser, which is handled in the optimization pass at the moment. Unfortunately, much of the code in the pass is tightly bound to the legacy fuser, which makes code sharing difficult. CudaFuseGraph will support only a subset of the operations the legacy fuser supports (CUDA only). It is registered as a custom post-fusion pass via ```torch._C._jit_register_cuda_fuser()```. For it to take effect, you should also turn off fusion on GPU via ```torch._C._jit_override_can_fuse_on_gpu(False)```.

3. We don't have codegen in this PR yet (WIP); currently we just fall back to the old fuser.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/33527
Differential Revision: D20171598
Pulled By: ZolotukhinM
fbshipit-source-id: 9a3c0f06f46da7eaa80ae7551c04869f5b03ef71
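A minimal sketch of the opt-in flow described above, assuming a CUDA-enabled build that includes this patch; ```demo``` is a hypothetical function used only for illustration:

```
import torch

# Register CudaFuseGraph as a custom post-fusion pass (added by this PR).
torch._C._jit_register_cuda_fuser()
# Disable the legacy fuser on GPU so fusable regions become
# prim::CudaFusionGroup nodes rather than prim::FusionGroup nodes.
torch._C._jit_override_can_fuse_on_gpu(False)

@torch.jit.script
def demo(x, y):
    return x * y + y  # simple elementwise chain, a fusion candidate

out = demo(torch.randn(4, device="cuda"), torch.randn(4, device="cuda"))
```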
@@ -33,6 +33,7 @@ namespace c10 {
   _(prim, Eval) \
   _(prim, Expand) /* onnx */ \
   _(prim, FusionGroup) \
+  _(prim, CudaFusionGroup) \
   _(prim, DifferentiableGraph) \
   _(prim, If) \
   _(prim, Jump) /* debug */ \
@@ -411,6 +411,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
   ${TORCH_SRC_DIR}/csrc/jit/passes/fixup_trace_scope_blocks.cpp
   ${TORCH_SRC_DIR}/csrc/jit/passes/inline_fork_wait.cpp
   ${TORCH_SRC_DIR}/csrc/jit/passes/graph_fuser.cpp
+  ${TORCH_SRC_DIR}/csrc/jit/passes/cuda_graph_fuser.cpp
   ${TORCH_SRC_DIR}/csrc/jit/passes/guard_elimination.cpp
   ${TORCH_SRC_DIR}/csrc/jit/passes/inplace_check.cpp
   ${TORCH_SRC_DIR}/csrc/jit/passes/liveness.cpp
@@ -121,6 +121,7 @@ libtorch_sources = [
     "torch/csrc/jit/passes/erase_number_types.cpp",
     "torch/csrc/jit/passes/fixup_trace_scope_blocks.cpp",
     "torch/csrc/jit/passes/graph_fuser.cpp",
+    "torch/csrc/jit/passes/cuda_graph_fuser.cpp",
     "torch/csrc/jit/passes/guard_elimination.cpp",
     "torch/csrc/jit/passes/inline_autodiff_subgraphs.cpp",
     "torch/csrc/jit/passes/inliner.cpp",
@@ -326,6 +326,7 @@ void AliasDb::analyzeImpl(Node* node) {
     case prim::Loop:
       return analyzeLoop(node);
     case prim::FusionGroup:
+    case prim::CudaFusionGroup:
     case prim::DifferentiableGraph:
       return analyzeSubgraph(node);
     case prim::fork:
@@ -430,6 +430,7 @@ void Node::lint() const {
       // longer.
       break;
     case prim::FusionGroup:
+    case prim::CudaFusionGroup:
       checkSameDevice(this);
       // TODO: Typecheck the parameters
       g(attr::Subgraph)->lint();
1231	torch/csrc/jit/passes/cuda_graph_fuser.cpp (new file; diff suppressed because it is too large)
17	torch/csrc/jit/passes/cuda_graph_fuser.h (new file)
@@ -0,0 +1,17 @@
+#pragma once
+
+#include <torch/csrc/jit/ir/ir.h>
+
+namespace torch {
+namespace jit {
+
+// Register CudaFuseGraph in custom passes
+TORCH_API void registerCudaFuseGraph();
+
+// NB: Be sure to run DCE before fusion, because dead instructions
+// can prevent fusion opportunities from being exploited.
+// On Windows will noop, NYI
+TORCH_API void CudaFuseGraph(std::shared_ptr<Graph>& graph);
+
+} // namespace jit
+} // namespace torch
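The header's note about running DCE first can be exercised from Python, where the DCE pass is already exposed. A hedged sketch (only ```torch._C._jit_pass_dce``` is an existing binding here; ```f``` is a hypothetical function, and the fuser itself still runs through the registered custom pass):

```
import torch

@torch.jit.script
def f(x):
    unused = x - 1  # dead value; DCE removes it so it cannot block fusion
    return (x + 1) * 2

g = f.graph
torch._C._jit_pass_dce(g)  # clear dead instructions before any fusion pass
print(g)
```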
@@ -12,7 +12,8 @@ namespace jit {
 // Autograd-aware
 bool canRunWithAutograd(Node* node) {
   auto kind = node->kind();
-  return kind != prim::FusionGroup && (kind.is_aten() || kind.is_prim());
+  return kind != prim::FusionGroup && kind != prim::CudaFusionGroup &&
+      (kind.is_aten() || kind.is_prim());
 }
 
 namespace {
@@ -20,6 +20,7 @@
 #include <torch/csrc/jit/passes/erase_number_types.h>
 #include <torch/csrc/jit/passes/fuse_linear.h>
 #include <torch/csrc/jit/passes/graph_fuser.h>
+#include <torch/csrc/jit/passes/cuda_graph_fuser.h>
 #include <torch/csrc/jit/passes/inline_fork_wait.h>
 #include <torch/csrc/jit/passes/inliner.h>
 #include <torch/csrc/jit/passes/loop_unrolling.h>
@@ -349,6 +350,8 @@ void initJITBindings(PyObject* module) {
         auto stack = toTraceableStack(args);
         checkAliasAnnotation(g, std::move(stack), unqualified_op_name);
       })
+      .def(
+          "_jit_register_cuda_fuser", &registerCudaFuseGraph)
       .def(
           "_jit_set_profiling_mode",
           [](bool profiling_flag) {
@@ -376,6 +376,9 @@ void initPythonIRBindings(PyObject* module_) {
       .def(
           "createFusionGroup",
           [](Graph& g) { return g.createWithSubgraph(prim::FusionGroup); })
+      .def(
+          "createCudaFusionGroup",
+          [](Graph& g) { return g.createWithSubgraph(prim::CudaFusionGroup); })
       .def(
           "createClone",
           [](Graph& g, Node* n, py::object fn) {
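With the binding above, the new node kind can be created directly on a ```torch._C.Graph```, mirroring the existing ```createFusionGroup``` helper. An illustrative sketch (the node comes back detached; a real pass would still wire up its inputs, outputs, and subgraph):

```
import torch

g = torch._C.Graph()
n = g.createCudaFusionGroup()  # node kind: prim::CudaFusionGroup
print(n.kind())                # expected: 'prim::CudaFusionGroup'
```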
@@ -178,6 +178,7 @@ bool printerHasSpecialCaseFor(Symbol sym) {
       prim::Drop, // used in interpreter only
       prim::FusedConcat, // optimization pass adds it
       prim::FusionGroup, // optimization pass adds it
+      prim::CudaFusionGroup, // optimization pass adds it
       prim::Load, // used in interpreter only
       prim::MMTreeReduce, // used as an optimization
       prim::MMBatchSide, // used as an optimization
@@ -208,6 +209,7 @@ bool aliasAnalysisHasSpecialCaseFor(Symbol symbol) {
       prim::If,
       prim::Loop,
       prim::FusionGroup,
+      prim::CudaFusionGroup,
       prim::DifferentiableGraph,
       prim::Constant,
       prim::Uninitialized,
@@ -299,6 +299,17 @@ RegisterOperators reg(
           };
         },
         aliasAnalysisSpecialCase()),
+    Operator(
+        prim::CudaFusionGroup,
+        [](const Node* node) -> Operation {
+          const auto key = registerFusion(node);
+          return [key](Stack& stack) {
+            RECORD_FUNCTION("CudaFusionGroup", std::vector<c10::IValue>());
+            runFusion(key, stack);
+            return 0;
+          };
+        },
+        aliasAnalysisSpecialCase()),
     Operator(
         prim::FusionGroup,
         [](const Node* node) -> Operation {
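Since codegen is not part of this PR, executing a prim::CudaFusionGroup goes through the legacy fuser's ```registerFusion```/```runFusion```, exactly as the operator above shows. A hedged end-to-end check (assumes the pass actually fires for this graph on this build; ```fn``` is a hypothetical example, and on some executors the group may sit inside a nested subgraph rather than at the top level):

```
import torch

torch._C._jit_register_cuda_fuser()
torch._C._jit_override_can_fuse_on_gpu(False)

@torch.jit.script
def fn(x, y):
    return torch.sigmoid(x + y) * y

x = torch.randn(8, device="cuda")
y = torch.randn(8, device="cuda")
fn(x, y)  # warm-up runs so the optimized graph is materialized
fn(x, y)

# Look for the new node kind in the optimized graph.
opt = fn.graph_for(x, y)
print(any(n.kind() == "prim::CudaFusionGroup" for n in opt.nodes()))
```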