Enable quantization on XPU devices (#54857)

Summary:
Enable quantization on XPU devices. Keep the model as-is (do not move its parameters to CPU) when all of its parameters already live on an XPU device.
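
The core of the change is a single device check before the existing CPU fallback; a minimal sketch of that predicate (the helper name is hypothetical, the expression itself matches the diff below):

    import torch

    def params_all_on_xpu(model: torch.nn.Module) -> bool:
        # True only when every parameter already lives on an XPU device;
        # note that all() over an empty parameter list is also True.
        return all(p.device.type == 'xpu' for p in model.parameters())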

Pull Request resolved: https://github.com/pytorch/pytorch/pull/54857

Reviewed By: ailzhang

Differential Revision: D28501381

Pulled By: jerryzh168

fbshipit-source-id: 6d3e9b04075393248b30776c69881f957a1a837c
Author: johnlu
Date: 2021-05-20 17:00:38 -07:00
Committed by: Facebook GitHub Bot
Parent: ce3788d6a5
Commit: 618be18a41

@@ -77,9 +77,11 @@ def _convert_jit(model, inplace=False, debug=False, quant_type=QuantType.STATIC,
     model_c = model._c
     model_c = torch._C._jit_pass_insert_quant_dequant(model_c, 'forward', inplace, debug, quant_type)
     if not debug:
-        # Moving model parameters to CPU since quantized operators
-        # are only supported on CPU right now
-        model.cpu()
+        is_xpu = all(p.device.type == 'xpu' for p in model.parameters())
+        if not is_xpu:
+            # Moving model parameters to CPU since quantized operators
+            # are only supported on CPU and XPU right now
+            model.cpu()
     if preserved_attrs is None:
         preserved_attrs = []
     model_c = torch._C._jit_pass_quant_finalize(model_c, quant_type, preserved_attrs)
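
A hedged usage sketch of the new behavior (assumes an XPU backend such as intel_extension_for_pytorch is available, that `prepared` stands in for a scripted model that already went through prepare_jit and calibration with its parameters on an XPU device, and that the convert_jit import path reflects PyTorch at the time of this commit):

    import torch
    from torch.quantization.quantize_jit import convert_jit

    # Assumption: `prepared` is a scripted, calibrated model whose
    # parameters all live on an 'xpu' device.
    converted = convert_jit(prepared)

    # Before this change convert_jit unconditionally called model.cpu();
    # with it, an all-XPU model is left on its device.
    assert all(p.device.type == 'xpu' for p in converted.parameters())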