mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
move batchnorm and layernorm fusion to decompose (#20337)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/20337 ghimport-source-id: 2196f84f2ef384c1f25587b2fb4bd9dd2f63c2b4 Differential Revision: D15448596 Pulled By: wanchaol fbshipit-source-id: b66e608f1b72471fc0775aaa4e09f9fa1070fc3c
This commit is contained in:
committed by
Facebook Github Bot
parent
cde611a66c
commit
871c9dcb1d
@ -493,7 +493,7 @@ class TestFuser(JitTestCase):
|
||||
# test for layernorm decompose
|
||||
lm = nn.LayerNorm(8)
|
||||
test_norm_decompose(lm, ['aten::batch_norm_stats'],
|
||||
['aten::layer_norm('], ['aten::sub', 'aten::mul', 'aten::addcmul'])
|
||||
['aten::layer_norm('], ['aten::sub', 'aten::mul', 'aten::add'])
|
||||
|
||||
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
||||
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
||||
|
@ -1,4 +1,5 @@
|
||||
#include <torch/csrc/jit/operator.h>
|
||||
#include <torch/csrc/jit/custom_operator.h>
|
||||
#include <torch/csrc/jit/script/compiler.h>
|
||||
#include <torch/csrc/jit/passes/decompose_ops.h>
|
||||
#include <torch/csrc/jit/passes/dead_code_elimination.h>
|
||||
@ -9,6 +10,73 @@
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
// helper to determine if an optional tensor argument/value passed in is
|
||||
// statically defined (neither a None constant nor a Optional[Tensor] type)
|
||||
// return yes, no, or no value if we can't tell
|
||||
c10::optional<bool> isDefined(Value* tensor) {
|
||||
if (tensor->type()->isSubtypeOf(TensorType::get())) {
|
||||
return true;
|
||||
}
|
||||
if (tensor->node()->mustBeNone()) {
|
||||
return false;
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
bool isDecomposableNorm(Node* normalize_op) {
|
||||
static const OperatorSet decomposable_normalization_ops = {
|
||||
"aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor",
|
||||
"aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps, bool cudnn_enable) -> Tensor",
|
||||
};
|
||||
Value* input = normalize_op->namedInput(attr::input);
|
||||
auto tensor_type = input->type()->cast<DimensionedTensorType>();
|
||||
// As of now, we do the decomposition for batchnorm/layernorm on GPU device only
|
||||
if (!tensor_type || tensor_type->device().is_cpu()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (decomposable_normalization_ops.find(normalize_op)) {
|
||||
// If we can't determine if weight and bias is defined statically there's
|
||||
// really no point in decomposing normalization into simpler ops, since it
|
||||
// won't get fused into a single kernel.
|
||||
return isDefined(normalize_op->namedInput(attr::weight)).has_value() &&
|
||||
isDefined(normalize_op->namedInput(attr::bias)).has_value();
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
RegisterOperators reg_bn_unsqueeze({Operator(
|
||||
"aten::_ncf_unsqueeze(Tensor self, int ndim) -> Tensor",
|
||||
[](const Node* node) {
|
||||
return [](Stack& stack) {
|
||||
const int64_t ndim = pop(stack).toInt();
|
||||
auto self = pop(stack).toTensor();
|
||||
c10::SmallVector<int64_t, 8> sizes(ndim, 1);
|
||||
AT_ASSERT(self.dim() == 1);
|
||||
sizes.at(1) = self.size(0);
|
||||
push(stack, self.reshape(sizes));
|
||||
return 0;
|
||||
};
|
||||
})});
|
||||
|
||||
RegisterOperators reg_ln_view({Operator(
|
||||
"aten::_ncf_view(Tensor self, int[] input_shape, int normalized_ndim) -> Tensor",
|
||||
[](const Node* node) {
|
||||
return [](Stack& stack) {
|
||||
const int64_t normalized_ndim = pop(stack).toInt();
|
||||
auto input_shape = pop(stack).toIntListRef();
|
||||
auto self = pop(stack).toTensor();
|
||||
const int64_t input_ndim = input_shape.size();
|
||||
c10::SmallVector<int64_t, 8> sizes(input_ndim, 1);
|
||||
for (int i = 0; i < input_ndim - normalized_ndim; ++i) {
|
||||
sizes.at(i) = input_shape[i];
|
||||
}
|
||||
push(stack, self.reshape(sizes));
|
||||
return 0;
|
||||
};
|
||||
})});
|
||||
|
||||
|
||||
bool DecomposeOps(Block* block, script::CompilationUnit& decompose_funcs) {
|
||||
bool decomposed = false;
|
||||
for (auto it = block->nodes().begin(), end = block->nodes().end(); it != end;
|
||||
@ -39,6 +107,74 @@ bool DecomposeOps(Block* block, script::CompilationUnit& decompose_funcs) {
|
||||
new_output->setType(it->output()->type());
|
||||
it->output()->replaceAllUsesWith(new_output);
|
||||
it.destroyCurrent();
|
||||
} else if (it->matches(
|
||||
"aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor")) {
|
||||
if (!isDecomposableNorm(*it)) {
|
||||
continue;
|
||||
}
|
||||
decomposed = true;
|
||||
WithInsertPoint insert_guard{*it};
|
||||
Graph* graph = it->owningGraph();
|
||||
Value* input = it->namedInput(attr::input);
|
||||
Value* input_dim = graph->insert(aten::dim, {input});
|
||||
std::vector<Value*> inputs {
|
||||
input,
|
||||
it->namedInput(attr::running_mean),
|
||||
it->namedInput(attr::running_var),
|
||||
it->namedInput(attr::training),
|
||||
it->namedInput(attr::momentum),
|
||||
it->namedInput(attr::eps)
|
||||
};
|
||||
|
||||
// inline the compiled decomposed batchnorm
|
||||
std::shared_ptr<Graph> d_graph = decompose_funcs.get_function("batch_norm").graph();
|
||||
Value* new_output = inlineCallTo(*graph, *d_graph, inputs).at(0);
|
||||
|
||||
// post processing the graph
|
||||
Value* weight = it->namedInput(attr::weight);
|
||||
Value* bias = it->namedInput(attr::bias);
|
||||
if (isDefined(weight).value()) {
|
||||
Value* expanded_weight =
|
||||
graph->insert(aten::_ncf_unsqueeze, {weight, input_dim});
|
||||
new_output = graph->insert(aten::mul, {new_output, expanded_weight});
|
||||
}
|
||||
if (isDefined(bias).value()) {
|
||||
Value* expanded_bias =
|
||||
graph->insert(aten::_ncf_unsqueeze, {bias, input_dim});
|
||||
new_output = graph->insert(aten::add, {new_output, expanded_bias});
|
||||
}
|
||||
it->output()->replaceAllUsesWith(new_output);
|
||||
it.destroyCurrent();
|
||||
} else if (it->matches(
|
||||
"aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps, bool cudnn_enable) -> Tensor")) {
|
||||
if (!isDecomposableNorm(*it)) {
|
||||
continue;
|
||||
}
|
||||
decomposed = true;
|
||||
WithInsertPoint insert_guard{*it};
|
||||
Graph* graph = it->owningGraph();
|
||||
std::vector<Value*> inputs {
|
||||
it->namedInput(attr::input),
|
||||
it->namedInput(attr::normalized_shape),
|
||||
it->namedInput(attr::eps),
|
||||
it->namedInput(attr::cudnn_enable)
|
||||
};
|
||||
|
||||
// inline the compiled decomposed layernorm
|
||||
std::shared_ptr<Graph> d_graph = decompose_funcs.get_function("layer_norm").graph();
|
||||
Value* new_output = inlineCallTo(*graph, *d_graph, inputs).at(0);
|
||||
|
||||
// post processing the graph
|
||||
Value* weight = it->namedInput(attr::weight);
|
||||
Value* bias = it->namedInput(attr::bias);
|
||||
if (isDefined(weight).value()) {
|
||||
new_output = graph->insert(aten::mul, {new_output, weight});
|
||||
}
|
||||
if (isDefined(bias).value()) {
|
||||
new_output = graph->insert(aten::add, {new_output, bias});
|
||||
}
|
||||
it->output()->replaceAllUsesWith(new_output);
|
||||
it.destroyCurrent();
|
||||
}
|
||||
}
|
||||
return decomposed;
|
||||
@ -48,6 +184,31 @@ void DecomposeOps(std::shared_ptr<Graph>& graph) {
|
||||
static script::CompilationUnit decompose_funcs(R"SCRIPT(
|
||||
def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta: number = 1.0, alpha: number = 1.0):
|
||||
return self + mat1.mm(mat2)
|
||||
|
||||
def batch_norm(input : Tensor, running_mean : Optional[Tensor], running_var : Optional[Tensor], training : bool, momentum : float, eps : float) -> Tensor:
|
||||
if training:
|
||||
norm_mean, norm_var = torch.batch_norm_update_stats(input, running_mean, running_var, momentum)
|
||||
else:
|
||||
norm_mean = torch._unwrap_optional(running_mean)
|
||||
norm_var = torch._unwrap_optional(running_var)
|
||||
norm_mean = torch._ncf_unsqueeze(norm_mean, input.dim())
|
||||
norm_var = torch._ncf_unsqueeze(norm_var, input.dim())
|
||||
norm_invstd = 1 / (torch.sqrt(norm_var + eps))
|
||||
return ((input - norm_mean) * norm_invstd)
|
||||
|
||||
def layer_norm(input : Tensor, normalized_shape : List[int], eps : float, cudnn_enable : bool) -> Tensor:
|
||||
input_ndim = input.dim()
|
||||
normalized_ndim = len(normalized_shape)
|
||||
n = 1
|
||||
for i in range(input_ndim - normalized_ndim):
|
||||
n *= input.size(i)
|
||||
input_reshape = input.contiguous().view(1, n, -1)
|
||||
mean, invstd = torch.batch_norm_stats(input_reshape, eps)
|
||||
input_shape = input.size()
|
||||
mean = torch._ncf_view(mean, input_shape, normalized_ndim)
|
||||
invstd = torch._ncf_view(invstd, input_shape, normalized_ndim)
|
||||
|
||||
return (input - mean) * invstd
|
||||
)SCRIPT");
|
||||
bool is_decomposed = DecomposeOps(graph->block(), decompose_funcs);
|
||||
if (is_decomposed) {
|
||||
|
@ -120,64 +120,6 @@ bool isSimpleMap(Node* node) {
|
||||
return true;
|
||||
}
|
||||
|
||||
RegisterOperators reg_bn_unsqueeze({Operator(
|
||||
"aten::_ncf_unsqueeze(Tensor self, int ndim) -> Tensor",
|
||||
[](const Node* node) {
|
||||
return [](Stack& stack) {
|
||||
const int64_t ndim = pop(stack).toInt();
|
||||
auto self = pop(stack).toTensor();
|
||||
c10::SmallVector<int64_t, 8> sizes(ndim, 1);
|
||||
AT_ASSERT(self.dim() == 1);
|
||||
sizes.at(1) = self.size(0);
|
||||
push(stack, self.reshape(sizes));
|
||||
return 0;
|
||||
};
|
||||
})});
|
||||
|
||||
RegisterOperators reg_ln_view({Operator(
|
||||
"aten::_ncf_view(Tensor self, int[] input_shape, int normalized_ndim) -> Tensor",
|
||||
[](const Node* node) {
|
||||
return [](Stack& stack) {
|
||||
const int64_t normalized_ndim = pop(stack).toInt();
|
||||
auto input_shape = pop(stack).toIntListRef();
|
||||
auto self = pop(stack).toTensor();
|
||||
const int64_t input_ndim = input_shape.size();
|
||||
c10::SmallVector<int64_t, 8> sizes(input_ndim, 1);
|
||||
for (int i = 0; i < input_ndim - normalized_ndim; ++i) {
|
||||
sizes.at(i) = input_shape[i];
|
||||
}
|
||||
push(stack, self.reshape(sizes));
|
||||
return 0;
|
||||
};
|
||||
})});
|
||||
|
||||
// Yes, no, or no value if we can't tell
|
||||
c10::optional<bool> isDefined(Value* tensor) {
|
||||
if (tensor->type()->isSubtypeOf(TensorType::get())) {
|
||||
return true;
|
||||
}
|
||||
if (tensor->node()->mustBeNone()) {
|
||||
return false;
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
bool isFusableNorm(Node* normalize_op) {
|
||||
static const OperatorSet decomposable_normalization_ops = {
|
||||
"aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor",
|
||||
"aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps, bool cudnn_enable) -> Tensor",
|
||||
};
|
||||
|
||||
if (decomposable_normalization_ops.find(normalize_op)) {
|
||||
// If we can't determine if weight and bias is defined statically there's
|
||||
// really no point in decomposing normalization into simpler ops, since it
|
||||
// won't get fused into a single kernel.
|
||||
return isDefined(normalize_op->namedInput(attr::weight)).has_value() &&
|
||||
isDefined(normalize_op->namedInput(attr::bias)).has_value();
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
Value* broadcastSizes(at::ArrayRef<Value*> sizes) {
|
||||
AT_ASSERT(!sizes.empty());
|
||||
Graph* graph = sizes[0]->owningGraph();
|
||||
@ -250,7 +192,7 @@ struct GraphFuser {
|
||||
fusableDevice &= isFusableDevice(output);
|
||||
}
|
||||
}
|
||||
return fusableDevice && (isFusableMap(node) || isFusableNorm(node));
|
||||
return fusableDevice && isFusableMap(node);
|
||||
}
|
||||
|
||||
bool isFusableMap(Node* node) {
|
||||
@ -312,113 +254,6 @@ struct GraphFuser {
|
||||
return *n->g(attr::Subgraph);
|
||||
}
|
||||
|
||||
Value* decomposeCommonNormalization(
|
||||
Node* normalization_op,
|
||||
const char* source,
|
||||
const std::string& method_name,
|
||||
const std::vector<Value*>& inputs) {
|
||||
std::shared_ptr<Graph> nm_graph;
|
||||
std::once_flag flag;
|
||||
std::call_once(
|
||||
flag,
|
||||
[](std::shared_ptr<Graph>* graph_ptr,
|
||||
const char* source,
|
||||
const std::string& method_name) {
|
||||
script::CompilationUnit cu;
|
||||
cu.define(source, script::nativeResolver(), nullptr);
|
||||
*graph_ptr = cu.get_function(method_name).graph();
|
||||
},
|
||||
&nm_graph,
|
||||
source,
|
||||
method_name);
|
||||
|
||||
WithInsertPoint insert_guard{normalization_op};
|
||||
return inlineCallTo(*normalization_op->owningGraph(), *nm_graph, inputs)
|
||||
.at(0);
|
||||
}
|
||||
|
||||
void decomposeNormalizationOps(Node* normalization_op) {
|
||||
static const char* bm_source = R"SCRIPT(
|
||||
def batch_norm(input : Tensor, running_mean : Optional[Tensor], running_var : Optional[Tensor], training : bool, momentum : float, eps : float) -> Tensor:
|
||||
if training:
|
||||
norm_mean, norm_var = torch.batch_norm_update_stats(input, running_mean, running_var, momentum)
|
||||
else:
|
||||
norm_mean = torch._unwrap_optional(running_mean)
|
||||
norm_var = torch._unwrap_optional(running_var)
|
||||
norm_mean = torch._ncf_unsqueeze(norm_mean, input.dim())
|
||||
norm_var = torch._ncf_unsqueeze(norm_var, input.dim())
|
||||
norm_invstd = 1 / (torch.sqrt(norm_var + eps))
|
||||
return ((input - norm_mean) * norm_invstd)
|
||||
)SCRIPT";
|
||||
static const char* lm_source = R"SCRIPT(
|
||||
def layer_norm(input : Tensor, normalized_shape : List[int], eps : float, cudnn_enable : bool) -> Tensor:
|
||||
input_ndim = input.dim()
|
||||
normalized_ndim = len(normalized_shape)
|
||||
n = 1
|
||||
for i in range(input_ndim - normalized_ndim):
|
||||
n *= input.size(i)
|
||||
input_reshape = input.contiguous().view(1, n, -1)
|
||||
mean, invstd = torch.batch_norm_stats(input_reshape, eps)
|
||||
input_shape = input.size()
|
||||
mean = torch._ncf_view(mean, input_shape, normalized_ndim)
|
||||
invstd = torch._ncf_view(invstd, input_shape, normalized_ndim)
|
||||
|
||||
return (input - mean) * invstd
|
||||
)SCRIPT";
|
||||
AT_ASSERT(isFusableNorm(normalization_op));
|
||||
WithInsertPoint insert_guard{normalization_op};
|
||||
Value* input = normalization_op->namedInput(attr::input);
|
||||
if (normalization_op->kind() == aten::batch_norm) {
|
||||
Value* input_dim = graph_->insert(aten::dim, {input});
|
||||
std::vector<Value*> inputs{
|
||||
input,
|
||||
normalization_op->namedInput(attr::running_mean),
|
||||
normalization_op->namedInput(attr::running_var),
|
||||
normalization_op->namedInput(attr::training),
|
||||
normalization_op->namedInput(attr::momentum),
|
||||
normalization_op->namedInput(attr::eps)};
|
||||
|
||||
Value* new_output = decomposeCommonNormalization(
|
||||
normalization_op, bm_source, "batch_norm", inputs);
|
||||
auto weight = normalization_op->namedInput(attr::weight);
|
||||
auto bias = normalization_op->namedInput(attr::bias);
|
||||
if (isDefined(weight).value()) {
|
||||
Value* expanded_weight =
|
||||
graph_->insert(aten::_ncf_unsqueeze, {weight, input_dim});
|
||||
new_output = graph_->insert(aten::mul, {new_output, expanded_weight});
|
||||
}
|
||||
if (isDefined(bias).value()) {
|
||||
Value* expanded_bias =
|
||||
graph_->insert(aten::_ncf_unsqueeze, {bias, input_dim});
|
||||
new_output = graph_->insert(aten::add, {new_output, expanded_bias});
|
||||
}
|
||||
normalization_op->output()->replaceAllUsesWith(new_output);
|
||||
normalization_op->destroy();
|
||||
|
||||
} else if (normalization_op->kind() == aten::layer_norm) {
|
||||
std::vector<Value*> inputs{
|
||||
input,
|
||||
normalization_op->namedInput(attr::normalized_shape),
|
||||
normalization_op->namedInput(attr::eps),
|
||||
normalization_op->namedInput(attr::cudnn_enable)};
|
||||
Value* new_output = decomposeCommonNormalization(
|
||||
normalization_op, lm_source, "layer_norm", inputs);
|
||||
auto weight = normalization_op->namedInput(attr::weight);
|
||||
auto bias = normalization_op->namedInput(attr::bias);
|
||||
auto weight_defined = isDefined(weight).value();
|
||||
auto bias_defined = isDefined(bias).value();
|
||||
if (weight_defined && bias_defined) {
|
||||
new_output = graph_->insert(aten::addcmul, {bias, new_output, weight});
|
||||
} else if (weight_defined) {
|
||||
new_output = graph_->insert(aten::mul, {new_output, weight});
|
||||
} else if (bias_defined) {
|
||||
new_output = graph_->insert(aten::add, {new_output, bias});
|
||||
}
|
||||
normalization_op->output()->replaceAllUsesWith(new_output);
|
||||
normalization_op->destroy();
|
||||
}
|
||||
}
|
||||
|
||||
void mergeFusionGroups(Node* consumer_group, Node* producer_group) {
|
||||
// Now we have two fusion groups!
|
||||
// Revert the fusion - place all inner nodes of producer back in the outer
|
||||
@ -619,18 +454,6 @@ struct GraphFuser {
|
||||
group = createSingletonFusionGroup(consumer);
|
||||
}
|
||||
|
||||
if (kind_ == prim::FusionGroup &&
|
||||
(producer->node()->matches(
|
||||
"aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps, bool cudnn_enable) -> Tensor") ||
|
||||
producer->node()->matches(
|
||||
"aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor"))) {
|
||||
// We don't do any fusions in here, but simply decompose the normalization
|
||||
// ops into a kernel that computes the stats + pointwise ops which will be
|
||||
// considered in this fusion next.
|
||||
decomposeNormalizationOps(producer->node());
|
||||
return group;
|
||||
}
|
||||
|
||||
if (producer->node()->kind() == kind_) {
|
||||
mergeFusionGroups(group, producer->node());
|
||||
return group;
|
||||
|
Reference in New Issue
Block a user