add quantized layer norm implementation (#35329)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35329

Adds a quantized implementation of LayerNorm for server.

A future PR will add the Python wrapper.

Test Plan:
Verified that the quantized implementation's numerics match the floating point implementation.

benchmarks by input size:
v1 (mean+var non-vectorized): https://gist.github.com/vkuzo/f6d72c04742608112f4c2e612c74bd13
v2 (mean+var vectorized in float): https://gist.github.com/vkuzo/4dd95657c5b5f3654e0965db00eff8d2
v3 (mean+var vectorized in int, current): https://gist.github.com/vkuzo/57a75f75629da9f23b64b38ca0e3d34b

Imported from OSS

Differential Revision: D20768930

fbshipit-source-id: ddf8727e9840c65ead3b890220af0638c5637028
This commit is contained in:
Vasiliy Kuznetsov
2020-04-09 08:58:24 -07:00
committed by Facebook GitHub Bot
parent 23e5f6a7be
commit f813e7184e
10 changed files with 383 additions and 6 deletions

View File

@ -0,0 +1,50 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import operator_benchmark as op_bench
import torch
"""Microbenchmarks for quantized layernorm operator."""
# Benchmark configurations: the cross product of input shapes, quantized
# dtypes, and tags. Each entry in `dims` is a full input tensor shape; the
# benchmark normalizes over every dimension after the first (see
# QLayerNormBenchmark.forward, which passes size()[1:] as the
# normalized shape).
layernorm_configs_short = op_bench.cross_product_configs(
    dims=(
        (1, 8, 16),
        (8, 8, 16),
        (32, 8, 16),
        # Larger conv-style activation shape (N, C, H, W).
        (64, 128, 56, 56),
    ),
    dtype=(torch.qint8,),  # quantized dtype for the input tensor
    tags=["short"],
)
class QLayerNormBenchmark(op_bench.TorchBenchmarkBase):
    """Microbenchmark for the quantized ``layer_norm`` operator.

    Quantizes a random float tensor and times
    ``torch.ops.quantized.layer_norm``, normalizing over every dimension
    except the leading (batch) dimension.
    """

    def init(self, dims, dtype):
        # Random float input centered at zero, spanning roughly [-128, 128).
        float_input = (torch.rand(*dims) - 0.5) * 256
        self.qX = torch.quantize_per_tensor(
            float_input, scale=1.0, zero_point=0, dtype=dtype)
        # Affine parameters cover the normalized shape: all dims but batch.
        normalized_shape = self.qX.size()[1:]
        self.weight = torch.rand(*normalized_shape, dtype=torch.float)
        self.bias = torch.rand(*normalized_shape, dtype=torch.float)
        self.eps = 1e-5
        # Quantization parameters for the output tensor.
        self.Y_scale = 0.1
        self.Y_zero_point = 0

    def forward(self):
        normalized_shape = self.qX.size()[1:]
        return torch.ops.quantized.layer_norm(
            self.qX, normalized_shape, weight=self.weight, bias=self.bias,
            eps=self.eps, output_scale=self.Y_scale,
            output_zero_point=self.Y_zero_point)
# Register the benchmark class so the operator_benchmark harness generates
# one PyTorch test per entry in layernorm_configs_short.
op_bench.generate_pt_test(layernorm_configs_short, QLayerNormBenchmark)
if __name__ == "__main__":
    op_bench.benchmark_runner.main()