mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Enable the quantization on XPU devices (#54857)
Summary: Enable the quantization on XPU devices. Keep the model as is if the model is on XPU devices. Pull Request resolved: https://github.com/pytorch/pytorch/pull/54857 Reviewed By: ailzhang Differential Revision: D28501381 Pulled By: jerryzh168 fbshipit-source-id: 6d3e9b04075393248b30776c69881f957a1a837c
This commit is contained in:
committed by
Facebook GitHub Bot
parent
ce3788d6a5
commit
618be18a41
@ -77,9 +77,11 @@ def _convert_jit(model, inplace=False, debug=False, quant_type=QuantType.STATIC,
|
||||
model_c = model._c
|
||||
model_c = torch._C._jit_pass_insert_quant_dequant(model_c, 'forward', inplace, debug, quant_type)
|
||||
if not debug:
|
||||
# Moving model parameters to CPU since quantized operators
|
||||
# are only supported on CPU right now
|
||||
model.cpu()
|
||||
is_xpu = all(p.device.type == 'xpu' for p in model.parameters())
|
||||
if not is_xpu:
|
||||
# Moving model parameters to CPU since quantized operators
|
||||
# are only supported on CPU and XPU right now
|
||||
model.cpu()
|
||||
if preserved_attrs is None:
|
||||
preserved_attrs = []
|
||||
model_c = torch._C._jit_pass_quant_finalize(model_c, quant_type, preserved_attrs)
|
||||
|
Reference in New Issue
Block a user