Summary: The `2to3` tool has a `future` fixer that you can target specifically to remove these redundant `from __future__` imports; the `caffe2` directory has the most of them:

```2to3 -f future -w caffe2```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/45033

Reviewed By: seemethere

Differential Revision: D23808648

Pulled By: bugra

fbshipit-source-id: 38971900f0fe43ab44a9168e57f2307580d36a38
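For illustration (not taken from this PR), this is the kind of boilerplate the `future` fixer deletes; on Python 3 these imports are no-ops, so removing them does not change behavior:

```
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

# After `2to3 -f future -w <dir>` rewrites a file in place, the four lines
# above are simply removed; the rest of the module is left untouched.
```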
122 lines · 4.2 KiB · Python
from caffe2.python import schema
from caffe2.python.layers.layers import ModelLayer

import numpy as np


class LayerNormalization(ModelLayer):
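    """Applies layer normalization to a schema.Scalar input.

    Normalizes as y = (x - mean) / stdev * scale + bias, where mean and stdev
    are computed over the dimensions from `axis` onward, and `scale` and
    `bias` are learned per-feature parameters created below.
    """
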
    def __init__(
        self,
        model,
        input_record,
        name='layer_normalization',
        scale_optim=None,
        bias_optim=None,
        epsilon=1e-4,
        axis=1,
        use_layer_norm_op=True,
        scale_init_value=1.0,
        **kwargs
    ):
        super(LayerNormalization, self).__init__(
            model, name, input_record, **kwargs)

        assert isinstance(input_record, schema.Scalar), (
            "Incorrect input type: {}".format(input_record))

        self.input_shape = input_record.field_type().shape
        self.axis = axis

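        # input_shape is the per-example shape (the batch dimension is not
        # included in it), so len(input_shape) >= 1 below corresponds to a
        # tensor that is at least 2D once the batch dimension is counted.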
        assert len(self.input_shape) >= 1, (
            "This layer supports only >= 2D tensors")
        input_dims = self.input_shape[0]

        self.output_schema = schema.Scalar(
            (np.float32, self.input_shape),
            self.get_next_blob_reference('output')
        )

        self.scale = self.create_param(param_name='scale',
                                       shape=[input_dims],
                                       initializer=('ConstantFill', {'value': scale_init_value}),
                                       optimizer=scale_optim)
        self.bias = self.create_param(param_name='bias',
                                      shape=[input_dims],
                                      initializer=('ConstantFill', {'value': 0.0}),
                                      optimizer=bias_optim)
        self.use_layer_norm_op = use_layer_norm_op

        if self.use_layer_norm_op:
            self.epsilon = epsilon
        else:
            assert len(self.input_shape) == 1, (
                "When using alternative implementation, "
                "input data can only be 2D"
            )
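            # epsilon is consumed as a blob input by net.Add in
            # add_ops_without_layer_norm_op, so it is materialized here as a
            # global constant blob rather than kept as a Python float.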
            self.epsilon = model.maybe_add_global_constant(
                "%s_epsilon" % self.name, float(epsilon)
            )

    def add_ops_with_layer_norm_op(self, net):
        input_blob = self.input_record.field_blobs()
        ln_output = self.output_schema.field_blobs()

        output_blobs = [net.NextScopedBlob('ln_output'), net.NextScopedBlob('ln_mean'),
                        net.NextScopedBlob('ln_stdev')]

        # The fused LayerNorm op returns the normalized tensor together with
        # the per-row mean and stdev used to produce it.
        normalized, mean, stdev = net.LayerNorm(input_blob,
                                                output_blobs,
                                                axis=self.axis,
                                                epsilon=self.epsilon)

        # Apply the learned scale, broadcasting it along `axis`
        # (legacy broadcast semantics: broadcast=1).
        scaled = net.Mul(
            [normalized, self.scale],
            [net.NextScopedBlob('ln_scaled')],
            broadcast=1,
            axis=self.axis,
        )

        # Add the learned bias, again with legacy broadcast along `axis`.
        net.Add(
            [scaled, self.bias],
            ln_output,
            broadcast=1,
            axis=self.axis,
        )

    def add_ops_without_layer_norm_op(self, net):
        # This path differs from the fused op path in two ways:
        # 1. it composes multiple primitive ops to replicate LayerNorm
        # 2. it does not use legacy broadcast
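        # The op sequence below computes, for each row x:
        #   y = (x - mean(x)) / sqrt(mean((x - mean(x))^2) + epsilon) * scale + bias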
        ln_output = net.NextScopedBlob("ln_output")
        ln_mean = net.NextScopedBlob("ln_mean")
        ln_stdev = net.NextScopedBlob("ln_stdev")
        ln_mean_arr = net.NextScopedBlob("ln_mean_arr")
        # mean over the last dimension, then restore it as a size-1 axis
        net.ReduceBackMean(self.input_record.field_blobs(), [ln_mean_arr])
        net.ExpandDims([ln_mean_arr], [ln_mean], dims=[1])
        ln_centered = net.NextScopedBlob("ln_centered")
        net.Sub(self.input_record.field_blobs() + [ln_mean], [ln_centered])
        # variance = mean of squared deviations
        ln_sqr = net.NextScopedBlob("ln_sqr")
        net.Sqr([ln_centered], [ln_sqr])
        ln_sqr_mean = net.NextScopedBlob("ln_sqr_mean")
        net.ReduceBackMean([ln_sqr], [ln_sqr_mean])
        ln_var = net.NextScopedBlob("ln_var")
        net.Add([ln_sqr_mean, self.epsilon], ln_var)
        # stdev = sqrt(var + epsilon), restored to a size-1 axis for division
        ln_std_arr = net.NextScopedBlob("ln_std_arr")
        net.Pow([ln_var], [ln_std_arr], exponent=0.5)
        net.ExpandDims([ln_std_arr], [ln_stdev], dims=[1])
        net.Div([ln_centered, ln_stdev], [ln_output])
        # apply the learned scale and bias (numpy-style broadcasting)
        ln_scaled = net.NextScopedBlob("ln_scaled")
        net.Mul([ln_output, self.scale], [ln_scaled])
        net.Add([ln_scaled, self.bias], self.output_schema.field_blobs())

    def add_ops(self, net):
        if self.use_layer_norm_op:
            self.add_ops_with_layer_norm_op(net)
        else:
            self.add_ops_without_layer_norm_op(net)
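
As a sanity reference (not part of the file), here is a small NumPy sketch of the computation the non-fused path above builds, mirroring the op sequence ReduceBackMean, Sub, Sqr, ReduceBackMean, Add(epsilon), Pow(0.5), Div, Mul, Add:

```
import numpy as np

def layer_norm_reference(x, scale, bias, epsilon=1e-4):
    # x: (N, D) batch; scale, bias: (D,) learned parameters.
    mean = x.mean(axis=-1, keepdims=True)               # ReduceBackMean + ExpandDims
    centered = x - mean                                 # Sub
    var = (centered ** 2).mean(axis=-1, keepdims=True)  # Sqr + ReduceBackMean
    stdev = np.sqrt(var + epsilon)                      # Add(epsilon) + Pow(0.5)
    return (centered / stdev) * scale + bias            # Div, Mul, Add

x = np.random.randn(4, 8).astype(np.float32)
out = layer_norm_reference(x, np.ones(8, np.float32), np.zeros(8, np.float32))
print(out.mean(axis=-1))  # approximately zero for each row
```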