[quant] Support set API for EmbeddingBag quantization (#43433)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43433

Add support for torch.quint8 dtype

Test Plan: Imported from OSS

Reviewed By: radkris-git

Differential Revision: D23277002

fbshipit-source-id: 4204bc62f124b4fd481aaa6aa47b9437978c43ee
This commit is contained in:
Supriya Rao
2020-08-24 14:31:31 -07:00
committed by Facebook GitHub Bot
parent e37f871e87
commit 284ff04792
2 changed files with 20 additions and 0 deletions

View File

@ -309,6 +309,8 @@ def quantize_dynamic(model, qconfig_spec=None, dtype=torch.qint8,
default_qconfig = default_dynamic_qconfig
elif dtype is torch.float16:
default_qconfig = float16_dynamic_qconfig
elif dtype is torch.quint8:
default_qconfig = float_qparams_dynamic_qconfig
else:
raise RuntimeError('Unknown dtype specified for quantize_dynamic: ', str(dtype))
qconfig_spec = dict(zip(qconfig_spec, itertools.repeat(default_qconfig)))