mirror of
https://github.com/uxlfoundation/oneDNN.git
synced 2025-10-20 10:03:50 +08:00
scripts: update converter with quantization mode
This commit is contained in:
@ -195,6 +195,8 @@ class Converter(metaclass=ConverterMeta):
|
||||
# Set policy to "host_scalar" if is_host_scalar is True
|
||||
if param.is_host_scalar:
|
||||
policy = "host_scalar"
|
||||
if param.quantization_mode == "dynamic_mx":
|
||||
policy = "mx"
|
||||
result = f"{arg}:{policy}"
|
||||
if policy == "common" or policy == "host_scalar":
|
||||
result += f":{def_value}"
|
||||
|
@ -307,6 +307,7 @@ class QuantizationParam(Mapping):
|
||||
mask: int = 0
|
||||
groups: str = ""
|
||||
is_host_scalar: bool = False
|
||||
quantization_mode: str = ""
|
||||
|
||||
def __str__(self):
|
||||
if self.groups:
|
||||
|
@ -411,7 +411,7 @@ class ParserImpl:
|
||||
@staticmethod
|
||||
def parse_quantization_param(spec, read_value, param_type):
|
||||
# Old style: mask[:[value[*]|*]]
|
||||
# New style: mask[:data_type[:groups]]
|
||||
# New style: mask[:data_type[:host_scalar[:groups[:quantization_mode]]]]
|
||||
param = param_type()
|
||||
param.mask = spec.read_uint()
|
||||
if spec.read_literal(":"):
|
||||
@ -429,6 +429,9 @@ class ParserImpl:
|
||||
param.is_host_scalar = True
|
||||
else:
|
||||
param.groups = groups_or_host_flag
|
||||
if spec.read_literal(":"):
|
||||
param.quantization_mode = spec.read_str()
|
||||
|
||||
return param
|
||||
|
||||
# v2.7 and below
|
||||
|
Reference in New Issue
Block a user