Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73120

This is to align our implementation with
https://github.com/pytorch/rfcs/blob/master/RFC-0019-Extending-PyTorch-Quantization-to-Custom-Backends.md

Test Plan:
python test/test_quantization.py TestQuantizeFx
python test/test_quantization.py TestQuantizeFxOps

Imported from OSS

Reviewed By: vkuzo

Differential Revision: D34354038

fbshipit-source-id: 873a867e62bd541ef236974c697fac2334bf02ea
(cherry picked from commit 3fce7cade2f057b985833659c2cb365ee4d6d9f3)
import torch
import torch.nn.quantized.functional


class LayerNorm(torch.nn.LayerNorm):
    r"""This is the quantized version of :class:`~torch.nn.LayerNorm`.

    Additional args:
        * **scale** - quantization scale of the output, type: double.
        * **zero_point** - quantization zero point of the output, type: long.

    """

    def __init__(self, normalized_shape, weight, bias, scale, zero_point, eps=1e-5,
                 elementwise_affine=True, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(LayerNorm, self).__init__(
            normalized_shape, eps=eps, elementwise_affine=elementwise_affine,
            **factory_kwargs)
        self.weight = weight
        self.bias = bias
        self.register_buffer('scale', torch.tensor(scale, **factory_kwargs))
        self.register_buffer('zero_point', torch.tensor(zero_point, **factory_kwargs))

    def forward(self, input):
        return torch.ops.quantized.layer_norm(
            input, self.normalized_shape, weight=self.weight, bias=self.bias,
            eps=self.eps, output_scale=self.scale, output_zero_point=self.zero_point)

    def _get_name(self):
        return 'QuantizedLayerNorm'

    @classmethod
    def from_float(cls, mod):
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        new_mod = cls(
            mod.normalized_shape, mod.weight, mod.bias, float(scale),
            int(zero_point), mod.eps, mod.elementwise_affine)
        return new_mod

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        return cls(
            mod.normalized_shape, mod.weight, mod.bias, float(scale),
            int(zero_point), mod.eps, mod.elementwise_affine)

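# A minimal usage sketch for the class above (not part of the original module);
# the layer shape, scale, and zero_point values below are illustrative assumptions.
#
#     float_ln = torch.nn.LayerNorm(4)
#     qln = LayerNorm(float_ln.normalized_shape, float_ln.weight, float_ln.bias,
#                     scale=0.1, zero_point=0)
#     qx = torch.quantize_per_tensor(torch.randn(2, 4), scale=0.1, zero_point=0,
#                                    dtype=torch.quint8)
#     qy = qln(qx)  # quantized output using the stored output scale/zero_point
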
class GroupNorm(torch.nn.GroupNorm):
    r"""This is the quantized version of :class:`~torch.nn.GroupNorm`.

    Additional args:
        * **scale** - quantization scale of the output, type: double.
        * **zero_point** - quantization zero point of the output, type: long.

    """
    __constants__ = ['num_groups', 'num_channels', 'eps', 'affine']

    def __init__(self, num_groups, num_channels, weight, bias, scale, zero_point, eps=1e-5,
                 affine=True, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(GroupNorm, self).__init__(num_groups, num_channels, eps, affine,
                                        **factory_kwargs)
        self.weight = weight
        self.bias = bias
        self.register_buffer('scale', torch.tensor(scale, **factory_kwargs))
        self.register_buffer('zero_point', torch.tensor(zero_point, **factory_kwargs))

    def forward(self, input):
        return torch.ops.quantized.group_norm(
            input, self.num_groups, self.weight, self.bias, self.eps, self.scale,
            self.zero_point)

    def _get_name(self):
        return 'QuantizedGroupNorm'

    @classmethod
    def from_float(cls, mod):
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        new_mod = cls(
            mod.num_groups, mod.num_channels, mod.weight, mod.bias, float(scale), int(zero_point),
            mod.eps, mod.affine)
        return new_mod

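# A minimal from_float sketch (illustrative, not from this file). It assumes the
# float module already carries an `activation_post_process` observer, as attached
# by the prepare step, so calculate_qparams() has statistics to work with.
#
#     float_gn = torch.nn.GroupNorm(2, 4)
#     float_gn.activation_post_process = torch.ao.quantization.MinMaxObserver()
#     float_gn.activation_post_process(float_gn(torch.randn(1, 4, 8)))
#     qgn = GroupNorm.from_float(float_gn)  # output qparams read from the observer
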
class InstanceNorm1d(torch.nn.InstanceNorm1d):
    r"""This is the quantized version of :class:`~torch.nn.InstanceNorm1d`.

    Additional args:
        * **scale** - quantization scale of the output, type: double.
        * **zero_point** - quantization zero point of the output, type: long.

    """
    def __init__(self, num_features, weight, bias, scale, zero_point,
                 eps=1e-5, momentum=0.1, affine=False,
                 track_running_stats=False, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(InstanceNorm1d, self).__init__(
            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs)
        self.weight = weight
        self.bias = bias
        self.register_buffer('scale', torch.tensor(scale, **factory_kwargs))
        self.register_buffer('zero_point', torch.tensor(zero_point, **factory_kwargs))

    def forward(self, input):
        return torch.ops.quantized.instance_norm(
            input, self.weight, self.bias, self.eps, self.scale,
            self.zero_point)

    def _get_name(self):
        return 'QuantizedInstanceNorm1d'

    @classmethod
    def from_float(cls, mod):
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        new_mod = cls(
            mod.num_features, mod.weight, mod.bias, float(scale), int(zero_point),
            mod.eps, mod.affine)
        return new_mod

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        return cls(
            mod.num_features, mod.weight, mod.bias, float(scale), int(zero_point),
            mod.eps, mod.affine)

class InstanceNorm2d(torch.nn.InstanceNorm2d):
    r"""This is the quantized version of :class:`~torch.nn.InstanceNorm2d`.

    Additional args:
        * **scale** - quantization scale of the output, type: double.
        * **zero_point** - quantization zero point of the output, type: long.

    """
    def __init__(self, num_features, weight, bias, scale, zero_point,
                 eps=1e-5, momentum=0.1, affine=False,
                 track_running_stats=False, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(InstanceNorm2d, self).__init__(
            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs)
        self.weight = weight
        self.bias = bias
        self.register_buffer('scale', torch.tensor(scale, **factory_kwargs))
        self.register_buffer('zero_point', torch.tensor(zero_point, **factory_kwargs))

    def forward(self, input):
        return torch.ops.quantized.instance_norm(
            input, self.weight, self.bias, self.eps, self.scale,
            self.zero_point)

    def _get_name(self):
        return 'QuantizedInstanceNorm2d'

    @classmethod
    def from_float(cls, mod):
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        new_mod = cls(
            mod.num_features, mod.weight, mod.bias, float(scale), int(zero_point),
            mod.eps, mod.affine)
        return new_mod

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        return cls(
            mod.num_features, mod.weight, mod.bias, float(scale), int(zero_point),
            mod.eps, mod.affine)

class InstanceNorm3d(torch.nn.InstanceNorm3d):
    r"""This is the quantized version of :class:`~torch.nn.InstanceNorm3d`.

    Additional args:
        * **scale** - quantization scale of the output, type: double.
        * **zero_point** - quantization zero point of the output, type: long.

    """
    def __init__(self, num_features, weight, bias, scale, zero_point,
                 eps=1e-5, momentum=0.1, affine=False,
                 track_running_stats=False, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(InstanceNorm3d, self).__init__(
            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs)
        self.weight = weight
        self.bias = bias
        self.register_buffer('scale', torch.tensor(scale, **factory_kwargs))
        self.register_buffer('zero_point', torch.tensor(zero_point, **factory_kwargs))

    def forward(self, input):
        return torch.ops.quantized.instance_norm(
            input, self.weight, self.bias, self.eps, self.scale,
            self.zero_point)

    def _get_name(self):
        return 'QuantizedInstanceNorm3d'

    @classmethod
    def from_float(cls, mod):
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        new_mod = cls(
            mod.num_features, mod.weight, mod.bias, float(scale), int(zero_point),
            mod.eps, mod.affine)
        return new_mod

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        return cls(
            mod.num_features, mod.weight, mod.bias, float(scale), int(zero_point),
            mod.eps, mod.affine)
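
# The three InstanceNorm variants above share the same constructor and forward
# signature; a minimal sketch for the 2d case (values below are illustrative
# assumptions, not taken from this file):
#
#     float_in = torch.nn.InstanceNorm2d(3, affine=True)
#     qin = InstanceNorm2d(3, float_in.weight, float_in.bias,
#                          scale=0.1, zero_point=0, affine=True)
#     qx = torch.quantize_per_tensor(torch.randn(1, 3, 8, 8), scale=0.1,
#                                    zero_point=0, dtype=torch.quint8)
#     qy = qin(qx)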