# mypy: ignore-errors r"""Importing this file includes common utility methods and base classes for checking quantization api and properties of resulting modules. """ import torch import torch.ao.nn.intrinsic.quantized.dynamic as nniqd import torch.ao.nn.quantized as nnq import torch.ao.nn.quantized.dynamic as nnqd import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F from functorch.experimental import control_flow from torch.ao.nn.intrinsic import _FusedModule from torch.ao.quantization import ( convert, default_dynamic_qat_qconfig, default_dynamic_qconfig, default_dynamic_quant_observer, default_embedding_qat_qconfig, default_observer, default_per_channel_qconfig, default_qconfig, default_symmetric_qnnpack_qat_qconfig, default_weight_observer, DeQuantStub, float_qparams_weight_only_qconfig, get_default_qat_qconfig, get_default_qat_qconfig_mapping, get_default_qconfig, get_default_qconfig_mapping, PerChannelMinMaxObserver, propagate_qconfig_, QConfig, QConfigMapping, quantize, quantize_dynamic_jit, quantize_jit, QuantStub, QuantType, QuantWrapper, ) from torch.ao.quantization.backend_config import get_executorch_backend_config from torch.ao.quantization.quantization_mappings import ( get_default_dynamic_quant_module_mappings, get_default_qat_module_mappings, get_default_qconfig_propagation_list, ) from torch.ao.quantization.quantize_pt2e import ( _convert_to_reference_decomposed_fx, convert_pt2e, prepare_pt2e, prepare_qat_pt2e, ) from torch.ao.quantization.quantizer.xnnpack_quantizer import ( get_symmetric_quantization_config, XNNPACKQuantizer, ) from torch.export import export_for_training from torch.jit.mobile import _load_for_lite_interpreter from torch.testing._internal.common_quantized import override_quantized_engine from torch.testing._internal.common_utils import TEST_WITH_ROCM, TestCase try: from torch.ao.ns.fx.ns_types import NSSingleResultValuesType, NSSubgraph # graph mode quantization based on fx from torch.ao.quantization.quantize_fx import ( convert_fx, convert_to_reference_fx, prepare_fx, prepare_qat_fx, ) from torch.fx import GraphModule from torch.fx.graph import Node HAS_FX = True except ImportError: HAS_FX = False import contextlib import copy import functools import io import os import unittest from typing import Any, Callable, Optional, Union import numpy as np import torch._dynamo as torchdynamo import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq import torch.ao.quantization.quantizer.xpu_inductor_quantizer as xpuiq from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer from torch.ao.quantization.quantizer.xpu_inductor_quantizer import XPUInductorQuantizer from torch.testing import FileCheck class NodeSpec: """Used for checking GraphModule Node""" def __init__(self, op, target): """ op: call_function | call_module target: for call_function, target would be a function for call_module, target would be the type of PyTorch module """ self.op = op self.target = target @classmethod def call_function(cls, target): return NodeSpec("call_function", target) @classmethod def call_method(cls, target): return NodeSpec("call_method", target) @classmethod def call_module(cls, target): return NodeSpec("call_module", target) def __hash__(self): return hash((self.op, self.target)) def __eq__(self, other): if not isinstance(other, NodeSpec): return NotImplemented return self.op == other.op and self.target == other.target def __repr__(self): return repr(self.op) + " " + repr(self.target) def 
get_supported_device_types():
    return (
        ["cpu", "cuda"]
        if torch.cuda.is_available() and not TEST_WITH_ROCM
        else ["cpu"]
    )


def test_only_eval_fn(model, calib_data):
    r"""
    Default evaluation function takes a torch.utils.data.Dataset or a list of
    input Tensors and runs the model on the dataset
    """
    for inp in calib_data:
        model(*inp)


_default_loss_fn = torch.nn.CrossEntropyLoss()


def test_only_train_fn(model, train_data, loss_fn=_default_loss_fn):
    r"""
    Default train function takes a torch.utils.data.Dataset and trains the model
    on the dataset
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    train_loss, correct, total = 0, 0, 0
    for _ in range(10):
        model.train()

        for data, target in train_data:
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            _, predicted = torch.max(output, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    return train_loss, correct, total


class AverageMeter:
    """Computes and stores the average and current value"""

    def __init__(self, name, fmt=":f"):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
        return fmtstr.format(**self.__dict__)


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


def train_one_epoch(model, criterion, optimizer, data_loader, device, ntrain_batches):
    model.train()
    for cnt, (image, target) in enumerate(data_loader, start=1):
        print(".", end="")
        image, target = image.to(device), target.to(device)
        output = model(image)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        accuracy(output, target, topk=(1, 5))
        if cnt >= ntrain_batches:
            return
    return


def ddp_setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)


def ddp_cleanup():
    dist.destroy_process_group()


def run_ddp(rank, world_size, prepared):
    ddp_setup(rank, world_size)
    prepared.cuda()
    prepared = torch.nn.parallel.DistributedDataParallel(prepared, device_ids=[rank])
    prepared.to(rank)
    model_with_ddp = prepared
    optimizer = torch.optim.SGD(model_with_ddp.parameters(), lr=0.0001)
    train_one_epoch(model_with_ddp, criterion, optimizer, dataset, rank, 1)  # noqa: F821
    ddp_cleanup()


def convert_dynamic(module):
    convert(module, get_default_dynamic_quant_module_mappings(), inplace=True)


def prepare_dynamic(model, qconfig_dict=None):
    propagate_qconfig_(model, qconfig_dict)


def _make_conv_test_input(
    batch_size,
    in_channels_per_group,
    input_feature_map_size,
    out_channels_per_group,
    groups,
    kernel_size,
    X_scale,
    X_zero_point,
    W_scale,
    W_zero_point,
    use_bias,
    use_channelwise,
):
    in_channels = in_channels_per_group * groups
    out_channels = out_channels_per_group * groups

    (X_value_min, X_value_max) = (0, 4)
    X_init = torch.randint(
        X_value_min,
        X_value_max,
        (
            batch_size,
            in_channels,
        )
        +
        input_feature_map_size,
    )
    X = X_scale * (X_init - X_zero_point).float()
    X_q = torch.quantize_per_tensor(
        X, scale=X_scale, zero_point=X_zero_point, dtype=torch.quint8
    )

    W_scale = W_scale * out_channels
    W_zero_point = W_zero_point * out_channels
    # Resize the W_scale and W_zero_point arrays to out_channels elements
    W_scale = W_scale[:out_channels]
    W_zero_point = W_zero_point[:out_channels]
    # For testing, we use small values for weights and activations so that no
    # overflow occurs in the vpmaddubsw instruction. If overflow occurs in the
    # qconv implementation but not in the reference, we can't exactly match the
    # results against the reference.
    # Please see the comment in the qconv implementation file
    #   aten/src/ATen/native/quantized/cpu/qconv.cpp for more details.
    (W_value_min, W_value_max) = (-5, 5)
    # The operator expects them in the format
    # (out_channels, in_channels/groups,) + kernel_size
    W_init = torch.randint(
        W_value_min,
        W_value_max,
        (
            out_channels,
            in_channels_per_group,
        )
        + kernel_size,
    )
    b_init = torch.randint(0, 10, (out_channels,))

    if use_channelwise:
        W_shape = (-1, 1) + (1,) * len(kernel_size)
        W_scales_tensor = torch.tensor(W_scale, dtype=torch.float)
        W_zero_points_tensor = torch.tensor(W_zero_point, dtype=torch.float)
        W = (
            W_scales_tensor.reshape(*W_shape)
            * (W_init.float() - W_zero_points_tensor.reshape(*W_shape)).float()
        )
        b = X_scale * W_scales_tensor * b_init.float()
        W_q = torch.quantize_per_channel(
            W,
            W_scales_tensor.double(),
            W_zero_points_tensor.long(),
            0,
            dtype=torch.qint8,
        )
    else:
        W = W_scale[0] * (W_init - W_zero_point[0]).float()
        b = X_scale * W_scale[0] * b_init.float()
        W_q = torch.quantize_per_tensor(
            W, scale=W_scale[0], zero_point=W_zero_point[0], dtype=torch.qint8
        )

    return (X, X_q, W, W_q, b if use_bias else None)


def _make_conv_add_extra_input_tensor(scale, zero_point, sizes):
    (X_value_min, X_value_max) = (0, 4)
    X_init = torch.randint(
        X_value_min,
        X_value_max,
        sizes,  # Infer the size of tensor to do the add
    )
    X = scale * (X_init - zero_point).float()
    X_q = torch.quantize_per_tensor(
        X, scale=scale, zero_point=zero_point, dtype=torch.quint8
    )
    return X, X_q


def skipIfNoFBGEMM(fn):
    reason = "Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs with AVX2 or newer instruction set support."
    if isinstance(fn, type):
        if "fbgemm" not in torch.backends.quantized.supported_engines:
            fn.__unittest_skip__ = True
            fn.__unittest_skip_why__ = reason
        return fn

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if "fbgemm" not in torch.backends.quantized.supported_engines:
            raise unittest.SkipTest(reason)
        else:
            fn(*args, **kwargs)

    return wrapper


def skipIfNoQNNPACK(fn):
    reason = "Quantized operations require QNNPACK."
    if isinstance(fn, type):
        if "qnnpack" not in torch.backends.quantized.supported_engines:
            fn.__unittest_skip__ = True
            fn.__unittest_skip_why__ = reason
        return fn

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if "qnnpack" not in torch.backends.quantized.supported_engines:
            raise unittest.SkipTest(reason)
        else:
            fn(*args, **kwargs)

    return wrapper


def withQNNPACKBackend(fn):
    # TODO(future PR): consider combining with skipIfNoQNNPACK,
    # will require testing of existing callsites
    reason = "Quantized operations require QNNPACK."
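    # Usage sketch (hypothetical test name, for illustration only): unlike
    # skipIfNoQNNPACK above, when this decorator wraps a test function it also
    # runs the body under override_quantized_engine("qnnpack"):
    #
    #   @withQNNPACKBackend
    #   def test_linear_qnnpack(self):
    #       # torch.backends.quantized.engine is "qnnpack" inside this test
    #       ...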
    if isinstance(fn, type):
        if "qnnpack" not in torch.backends.quantized.supported_engines:
            fn.__unittest_skip__ = True
            fn.__unittest_skip_why__ = reason
        return fn

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if "qnnpack" not in torch.backends.quantized.supported_engines:
            raise unittest.SkipTest(reason)
        with override_quantized_engine("qnnpack"):
            fn(*args, **kwargs)

    return wrapper


def skipIfNoONEDNN(fn):
    reason = "Quantized operations require ONEDNN."
    if isinstance(fn, type):
        if "onednn" not in torch.backends.quantized.supported_engines:
            fn.__unittest_skip__ = True
            fn.__unittest_skip_why__ = reason
        return fn

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if "onednn" not in torch.backends.quantized.supported_engines:
            raise unittest.SkipTest(reason)
        else:
            fn(*args, **kwargs)

    return wrapper


def skipIfNoONEDNNBF16(fn):
    reason = "Quantized operations require BF16 support."
    if isinstance(fn, type):
        if not torch.ops.mkldnn._is_mkldnn_bf16_supported():
            fn.__unittest_skip__ = True
            fn.__unittest_skip_why__ = reason
        return fn

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if not torch.ops.mkldnn._is_mkldnn_bf16_supported():
            raise unittest.SkipTest(reason)
        else:
            fn(*args, **kwargs)

    return wrapper


def skipIfNoX86(fn):
    reason = "Quantized operations require X86."
    if isinstance(fn, type):
        if "x86" not in torch.backends.quantized.supported_engines:
            fn.__unittest_skip__ = True
            fn.__unittest_skip_why__ = reason
        return fn

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if "x86" not in torch.backends.quantized.supported_engines:
            raise unittest.SkipTest(reason)
        else:
            fn(*args, **kwargs)

    return wrapper


def skipIfNoDynamoSupport(fn):
    reason = "Dynamo is not supported."
    if isinstance(fn, type):
        if not torchdynamo.is_dynamo_supported():
            fn.__unittest_skip__ = True
            fn.__unittest_skip_why__ = reason
        return fn

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if not torchdynamo.is_dynamo_supported():
            raise unittest.SkipTest(reason)
        else:
            fn(*args, **kwargs)

    return wrapper


def skipIfNoInductorSupport(fn):
    reason = "Inductor is not supported."
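    # Like the skip decorators above, this helper handles both cases: wrapping a
    # TestCase subclass (by setting __unittest_skip__ / __unittest_skip_why__ on
    # the class) and wrapping a single test function (by raising
    # unittest.SkipTest when the test is invoked). Hypothetical usage, for
    # illustration only:
    #
    #   @skipIfNoInductorSupport
    #   def test_compile_with_inductor(self):
    #       ...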
if isinstance(fn, type): if not torchdynamo.is_inductor_supported(): fn.__unittest_skip__ = True fn.__unittest_skip_why__ = reason return fn @functools.wraps(fn) def wrapper(*args, **kwargs): if not torchdynamo.is_inductor_supported(): raise unittest.SkipTest(reason) else: fn(*args, **kwargs) return wrapper try: import torchvision # noqa: F401 HAS_TORCHVISION = True except ImportError: HAS_TORCHVISION = False skip_if_no_torchvision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision") def get_script_module(model, tracing, data): return torch.jit.trace(model, data) if tracing else torch.jit.script(model) def lengths_to_offsets(t, offset_type=np.int64, use_begin_offset=True): """ Convert lengths to offsets for embedding_bag """ tt = np.zeros((t.shape[0] + 1,), dtype=offset_type) tt[1:] = t tt = torch.from_numpy(np.cumsum(tt, dtype=offset_type)) if use_begin_offset: return tt[:-1] return tt[1:] def _group_quantize_tensor(w, n_bit=4, q_group_size=16): assert w.dim() == 2 w = w.transpose(0, 1).contiguous() assert q_group_size > 1 assert w.shape[-1] % q_group_size == 0 to_quant = w.reshape(-1, q_group_size) assert torch.isnan(to_quant).sum() == 0 max_val = to_quant.amax(dim=1, keepdim=True) min_val = to_quant.amin(dim=1, keepdim=True) max_int = 2**n_bit - 1 min_int = 0 scales = (max_val - min_val).clamp(min=1e-6) / max_int assert torch.isnan(scales).sum() == 0 zeros = min_val + scales * (2 ** (n_bit - 1)) assert torch.isnan(zeros).sum() == 0 out = to_quant.sub(min_val).div(scales).round().clamp_(min_int, max_int) assert torch.isnan(out).sum() == 0 out = out.to(dtype=torch.int32).reshape(w.shape) if out.device != torch.device("cpu"): out = (out[::, ::2] << 4 | out[::, 1::2]).to(torch.uint8) # Scales and zeros for the same q-group should be contiguous, so we can # load as a 32-bit word scales = scales.view(w.shape[0], -1) zeros = zeros.view(w.shape[0], -1) scales_and_zeros = ( torch.cat( [ scales.reshape(scales.size(0), scales.size(1), 1), zeros.reshape(zeros.size(0), zeros.size(1), 1), ], 2, ) .transpose(0, 1) .contiguous() ) return out, scales_and_zeros def _group_quantize_tensor_symmetric(w, n_bit=4, groupsize=32): # W is of shape [K x N] # We transpose W as Quantization is applied on [N x K] w = w.transpose(0, 1).contiguous() assert w.dim() == 2 assert groupsize > 1 assert w.shape[-1] % groupsize == 0 # Calculate scale and zeros to_quant = w.reshape(-1, groupsize) max_val = to_quant.abs().amax(dim=1, keepdim=True) eps = torch.finfo(max_val.dtype).eps max_int = 2 ** (n_bit - 1) - 1 # For 4-bit, this is 7 scales = max_val.clamp(min=eps) / max_int zeros = torch.zeros_like(scales) # Quantize the weight scales = scales.to(torch.float32).reshape(w.shape[0], -1) zeros = zeros.to(torch.float32).reshape(w.shape[0], -1) scales = scales.reshape(-1, 1) zeros = zeros.reshape(-1, 1) max_int = 2**n_bit - 1 w_int8 = to_quant.div(scales).add(8.5).to(torch.int8).clamp(max=max_int) # We pack 2 signed int4 values in unsigned uint8 container. 
# This reduces the weight size by half and improves load perf out_uint8 = (w_int8[::, 1::2] << 4 | w_int8[::, ::2]).to(torch.uint8) scales_and_zeros = scales.squeeze().contiguous() return out_uint8, scales_and_zeros def _dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype): # source: https://github.com/meta-pytorch/gpt-fast/blob/main/quantize.py # default setup for affine quantization of activations x_dtype = x.dtype x = x.float() eps = torch.finfo(torch.float32).eps # get min and max min_val, max_val = torch.aminmax(x, dim=1) # calculate scales and zero_points based on min and max # reference: https://fburl.com/code/srbiybme min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) device = min_val_neg.device # reference: https://fburl.com/code/4wll53rk max_val_pos = torch.max(-min_val_neg, max_val_pos) scales = max_val_pos / (float(quant_max - quant_min) / 2) # ensure scales is the same dtype as the original tensor scales = torch.clamp(scales, min=eps).to(x.dtype) zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device) # quantize based on qmin/qmax/scales/zp x_div = x / scales.unsqueeze(-1) x_round = torch.round(x_div) x_zp = x_round + zero_points.unsqueeze(-1) quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype) return quant, scales.to(x_dtype), zero_points # QuantizationTestCase used as a base class for testing quantization on modules class QuantizationTestCase(TestCase): def setUp(self): super().setUp() self.calib_data = [[torch.rand(2, 5, dtype=torch.float)] for _ in range(2)] self.train_data = [ [ torch.rand(2, 5, dtype=torch.float), torch.randint(0, 1, (2,), dtype=torch.long), ] for _ in range(2) ] self.img_data_1d = [[torch.rand(2, 3, 10, dtype=torch.float)] for _ in range(2)] self.img_data_2d = [ [torch.rand(1, 3, 10, 10, dtype=torch.float)] for _ in range(2) ] self.img_data_3d = [ [torch.rand(1, 3, 5, 5, 5, dtype=torch.float)] for _ in range(2) ] self.img_data_1d_train = [ [ torch.rand(2, 3, 10, dtype=torch.float), torch.randint(0, 1, (1,), dtype=torch.long), ] for _ in range(2) ] self.img_data_2d_train = [ [ torch.rand(1, 3, 10, 10, dtype=torch.float), torch.randint(0, 1, (1,), dtype=torch.long), ] for _ in range(2) ] self.img_data_3d_train = [ [ torch.rand(1, 3, 5, 5, 5, dtype=torch.float), torch.randint(0, 1, (1,), dtype=torch.long), ] for _ in range(2) ] self.img_data_dict = { 1: self.img_data_1d, 2: self.img_data_2d, 3: self.img_data_3d, } # Quant types that produce statically quantized ops self.static_quant_types = [QuantType.STATIC, QuantType.QAT] # All quant types for (fx based) graph mode quantization self.all_quant_types = [QuantType.DYNAMIC, QuantType.STATIC, QuantType.QAT] def checkNoPrepModules(self, module): r"""Checks the module does not contain child modules for quantization preparation, e.g. quant, dequant and observer """ self.assertFalse(hasattr(module, "quant")) self.assertFalse(hasattr(module, "dequant")) def checkNoQconfig(self, module): r"""Checks the module does not contain qconfig""" self.assertFalse(hasattr(module, "qconfig")) for child in module.children(): self.checkNoQconfig(child) def checkHasPrepModules(self, module): r"""Checks the module contains child modules for quantization preparation, e.g. 
quant, dequant and observer """ self.assertTrue(hasattr(module, "module")) self.assertTrue(hasattr(module, "quant")) self.assertTrue(hasattr(module, "dequant")) def checkObservers( self, module, propagate_qconfig_list=None, prepare_custom_config_dict=None ): r"""Checks the module or module's leaf descendants have observers in preparation for quantization """ if propagate_qconfig_list is None: propagate_qconfig_list = get_default_qconfig_propagation_list() if prepare_custom_config_dict is None: prepare_custom_config_dict = {} float_to_observed_module_class_mapping = prepare_custom_config_dict.get( "float_to_observed_custom_module_class", {} ) # check if a module is a leaf module, ignoring activation_post_process attribute def is_leaf_module(module): submodule_name_count = 0 for name, _ in module.named_children(): if name != "activation_post_process": submodule_name_count += 1 return submodule_name_count == 0 if ( hasattr(module, "qconfig") and module.qconfig is not None and ( ( is_leaf_module(module) and not isinstance(module, torch.nn.Sequential) and type(module) in propagate_qconfig_list ) or type(module) in float_to_observed_module_class_mapping.keys() ) and not isinstance(module, torch.ao.quantization.DeQuantStub) ): self.assertTrue( hasattr(module, "activation_post_process"), "module: " + str(type(module)) + " do not have observer", ) # we don't need to check observers for child modules of the # qat modules if ( type(module) not in get_default_qat_module_mappings().values() and type(module) not in float_to_observed_module_class_mapping.values() and not isinstance(module, _FusedModule) ): for child in module.children(): if type(child) in [nn.Dropout]: continue self.checkObservers( child, propagate_qconfig_list, prepare_custom_config_dict ) def checkQuantDequant(self, mod): r"""Checks that mod has nn.Quantize and nn.DeQuantize submodules inserted """ self.assertEqual(type(mod.quant), nnq.Quantize) self.assertEqual(type(mod.dequant), nnq.DeQuantize) def checkWrappedQuantizedLinear(self, mod): r"""Checks that mod has been swapped for an nnq.Linear module, the bias is qint32, and that the module has Quantize and DeQuantize submodules """ self.assertEqual(type(mod.module), nnq.Linear) self.checkQuantDequant(mod) def checkQuantizedLinear(self, mod): self.assertEqual(type(mod), nnq.Linear) def checkDynamicQuantizedLinear(self, mod, dtype): r"""Checks that mod has been swapped for an nnqd.Linear module, the bias is float. """ self.assertEqual(type(mod), nnqd.Linear) self.assertEqual(mod._packed_params.dtype, dtype) def checkDynamicQuantizedLinearRelu(self, mod, dtype): r"""Checks that mod has been swapped for an nnqd.Linear module, the bias is float. 
""" self.assertEqual(type(mod), nniqd.LinearReLU) self.assertEqual(mod._packed_params.dtype, dtype) def check_eager_serialization(self, ref_model, loaded_model, x): # Check state dict serialization and torch.save APIs model_dict = ref_model.state_dict() b = io.BytesIO() torch.save(model_dict, b) b.seek(0) # weights_only=False as we sometimes get a ScriptObect here (weird) loaded_dict = torch.load(b, weights_only=False) loaded_model.load_state_dict(loaded_dict) ref_out = ref_model(*x) load_out = loaded_model(*x) def check_outputs(ref_out, load_out): self.assertEqual(ref_out[0], load_out[0]) if isinstance(ref_out[1], tuple): self.assertEqual(ref_out[1][0], load_out[1][0]) self.assertEqual(ref_out[1][1], load_out[1][1]) else: self.assertEqual(ref_out[1], load_out[1]) check_outputs(ref_out, load_out) b = io.BytesIO() torch.save(ref_model, b) b.seek(0) # weights_only=False as this is legacy code that saves the model loaded = torch.load(b, weights_only=False) load_out = loaded(*x) check_outputs(ref_out, load_out) def check_weight_bias_api(self, ref_model, weight_keys, bias_keys): weight = ref_model.get_weight() bias = ref_model.get_bias() self.assertEqual(weight_keys ^ weight.keys(), set()) self.assertEqual(bias_keys ^ bias.keys(), set()) def checkDynamicQuantizedLSTM(self, mod, reference_module_type, dtype): r"""Checks that mod has been swapped for an nnqd.LSTM type module, the bias is float. """ wt_dtype_map = { torch.qint8: "quantized_dynamic", torch.float16: "quantized_fp16", } self.assertEqual(type(mod), reference_module_type) for packed_params in mod._all_weight_values: self.assertEqual( packed_params.param.__getstate__()[0][0], wt_dtype_map[dtype] ) def checkLinear(self, mod): self.assertEqual(type(mod), torch.nn.Linear) def checkDynamicQuantizedModule(self, mod, reference_module_type, dtype): r"""Checks that mod has been swapped for an nnqd.Linear module, the bias is float. 
""" wt_dtype_map = { torch.qint8: "quantized_dynamic", torch.float16: "quantized_fp16", } self.assertEqual(type(mod), reference_module_type) if hasattr(mod, "_all_weight_values"): for packed_params in mod._all_weight_values: self.assertEqual( packed_params.param.__getstate__()[0][0], wt_dtype_map[dtype] ) def checkScriptable(self, orig_mod, calib_data, check_save_load=False): scripted = torch.jit.script(orig_mod) self._checkScriptable(orig_mod, scripted, calib_data, check_save_load) # Use first calib_data entry as trace input traced = torch.jit.trace(orig_mod, calib_data[0]) self._checkScriptable(orig_mod, traced, calib_data, check_save_load) # Call this twice: once for a scripted module and once for a traced module def _checkScriptable(self, orig_mod, script_mod, calib_data, check_save_load): self._checkModuleCorrectnessAgainstOrig(orig_mod, script_mod, calib_data) # Test save/load buffer = io.BytesIO() torch.jit.save(script_mod, buffer) buffer.seek(0) loaded_mod = torch.jit.load(buffer) # Pending __get_state_ and __set_state__ support # See tracking task https://github.com/pytorch/pytorch/issues/23984 if check_save_load: self._checkModuleCorrectnessAgainstOrig(orig_mod, loaded_mod, calib_data) def _checkModuleCorrectnessAgainstOrig(self, orig_mod, test_mod, calib_data): for inp in calib_data: ref_output = orig_mod(*inp) scripted_output = test_mod(*inp) self.assertEqual(scripted_output, ref_output) def checkGraphModeOp( self, module, inputs, quantized_op, tracing=False, debug=False, check=True, eval_mode=True, dynamic=False, qconfig=None, ): if debug: print("Testing:", str(module)) qconfig_dict = {"": get_default_qconfig(torch.backends.quantized.engine)} if eval_mode: module = module.eval() if dynamic: qconfig_dict = {"": default_dynamic_qconfig if qconfig is None else qconfig} model = get_script_module(module, tracing, inputs[0]).eval() if debug: print("input graph:", model.graph) models = {} outputs = {} for debug in [True, False]: if dynamic: models[debug] = quantize_dynamic_jit(model, qconfig_dict, debug=debug) # make sure it runs outputs[debug] = models[debug](inputs) else: # module under test can contain in-place ops, and we depend on # input data staying constant for comparisons inputs_copy = copy.deepcopy(inputs) models[debug] = quantize_jit( model, qconfig_dict, test_only_eval_fn, [inputs_copy], inplace=False, debug=debug, ) # make sure it runs outputs[debug] = models[debug](*inputs[0]) if debug: print("debug graph:", models[True].graph) print("non debug graph:", models[False].graph) if check: # debug and non-debug option should have the same numerics self.assertEqual(outputs[True], outputs[False]) # non debug graph should produce quantized op FileCheck().check(quantized_op).run(models[False].graph) return models[False] def checkGraphModuleNodes( self, graph_module, expected_node=None, expected_node_occurrence=None, expected_node_list=None, ): """Check if GraphModule contains the target node Args: graph_module: the GraphModule instance we want to check expected_node, expected_node_occurrence, expected_node_list: see docs for checkGraphModeFxOp """ nodes_in_graph = {} node_list = [] modules = dict(graph_module.named_modules(remove_duplicate=False)) for node in graph_module.graph.nodes: n = None if node.op == "call_function" or node.op == "call_method": n = NodeSpec(node.op, node.target) elif node.op == "call_module": n = NodeSpec(node.op, type(modules[node.target])) if n is not None: node_list.append(n) if n in nodes_in_graph: nodes_in_graph[n] += 1 else: nodes_in_graph[n] = 1 if 
expected_node is not None: self.assertTrue( expected_node in nodes_in_graph, "node:" + str(expected_node) + " not found in the graph module", ) if expected_node_occurrence is not None: for expected_node, occurrence in expected_node_occurrence.items(): if occurrence != 0: self.assertTrue( expected_node in nodes_in_graph, "Check failed for node:" + str(expected_node) + " not found", ) self.assertTrue( nodes_in_graph[expected_node] == occurrence, "Check failed for node:" + str(expected_node) + " Expected occurrence:" + str(occurrence) + " Found occurrence:" + str(nodes_in_graph[expected_node]), ) else: self.assertTrue( expected_node not in nodes_in_graph, "Check failed for node:" + str(expected_node) + " expected no occurrence but found", ) if expected_node_list is not None: cur_index = 0 for n in node_list: if cur_index == len(expected_node_list): return if n == expected_node_list[cur_index]: cur_index += 1 self.assertTrue( cur_index == len(expected_node_list), "Check failed for graph:" + self.printGraphModule(graph_module, print_str=False) + "Expected ordered list:" + str(expected_node_list), ) def printGraphModule(self, graph_module, print_str=True): modules = dict(graph_module.named_modules(remove_duplicate=False)) node_infos = [] for n in graph_module.graph.nodes: node_info = " ".join(map(repr, [n.op, n.name, n.target, n.args, n.kwargs])) if n.op == "call_module": node_info += " module type: " + repr(type(modules[n.target])) node_infos.append(node_info) str_to_print = "\n".join(node_infos) if print_str: print(str_to_print) return str_to_print if HAS_FX: def assert_types_for_matched_subgraph_pairs( self, matched_subgraph_pairs: dict[str, tuple[NSSubgraph, NSSubgraph]], expected_types: dict[ str, tuple[tuple[Callable, Callable], tuple[Callable, Callable]] ], gm_a: GraphModule, gm_b: GraphModule, ) -> None: """ Verifies that the types specified in expected_types match the underlying objects pointed to by the nodes in matched_subgraph_pairs. An example successful test case: matched_subgraph_pairs = {'x0': (graph_a_conv_0_node, graph_b_conv_0_node)} expected_types = {'x0': (nn.Conv2d, nnq.Conv2d)} The function tests for key equivalence, and verifies types with instance checks. 
""" def _get_underlying_op_type( node: Node, gm: GraphModule ) -> Union[Callable, str]: if node.op == "call_module": mod = getattr(gm, node.target) return type(mod) else: assert node.op in ("call_function", "call_method") return node.target self.assertTrue( len(matched_subgraph_pairs) == len(expected_types), f"Expected length of results to match, but got {len(matched_subgraph_pairs)} and {len(expected_types)}", ) for k, v in expected_types.items(): expected_types_a, expected_types_b = v exp_type_start_a, exp_type_end_a = expected_types_a exp_type_start_b, exp_type_end_b = expected_types_b subgraph_a, subgraph_b = matched_subgraph_pairs[k] act_type_start_a = _get_underlying_op_type(subgraph_a.start_node, gm_a) act_type_start_b = _get_underlying_op_type(subgraph_b.start_node, gm_b) act_type_end_a = _get_underlying_op_type(subgraph_a.end_node, gm_a) act_type_end_b = _get_underlying_op_type(subgraph_b.end_node, gm_b) types_match = ( (exp_type_start_a is act_type_start_a) and (exp_type_end_a is act_type_end_a) and (exp_type_start_b is act_type_start_b) and (exp_type_end_b is act_type_end_b) ) self.assertTrue( types_match, f"Type mismatch at {k}: expected {(exp_type_start_a, exp_type_end_a, exp_type_start_b, exp_type_end_b)}, " f"got {(act_type_start_a, act_type_end_a, act_type_start_b, act_type_end_b)}", ) def assert_ns_compare_dict_valid( self, act_compare_dict: dict[str, dict[str, dict[str, Any]]], ) -> None: """ Verifies that the act_compare_dict (output of Numeric Suite APIs) is valid: 1. for each layer, results are recorded for two models 2. number of seen tensors match 3. shapes of each pair of seen tensors match """ for layer_name, result_type_to_data in act_compare_dict.items(): for result_type, layer_data in result_type_to_data.items(): self.assertTrue( len(layer_data) == 2, f"Layer {layer_name} does not have exactly two model results.", ) model_name_0, model_name_1 = layer_data.keys() for res_idx in range(len(layer_data[model_name_0])): layer_data_0 = layer_data[model_name_0][res_idx] layer_data_1 = layer_data[model_name_1][res_idx] self.assertTrue( layer_data_0["type"] == layer_data_0["type"], f"Layer {layer_name}, {model_name_0} and {model_name_1} do not have the same type.", ) self.assertTrue( len(layer_data_0["values"]) == len(layer_data_1["values"]), f"Layer {layer_name}, {model_name_0} and {model_name_1} do not have the same number of seen Tensors.", ) # F.conv1d weight has rank 3, and toq.conv1d unpacked weight # has rank 4. For now, skip the length check for conv1d only. 
is_weight_functional_conv1d = ( result_type == NSSingleResultValuesType.WEIGHT.value and ( "conv1d" in layer_data_0["prev_node_target_type"] or "conv1d" in layer_data_1["prev_node_target_type"] ) ) if not is_weight_functional_conv1d: for idx in range(len(layer_data_0["values"])): values_0 = layer_data_0["values"][idx] values_1 = layer_data_1["values"][idx] if isinstance(values_0, torch.Tensor): self.assertTrue( values_0.shape == values_1.shape, f"Layer {layer_name}, {model_name_0} and {model_name_1} " + f"have a shape mismatch at idx {idx}.", ) elif isinstance(values_0, list): values_0 = values_0[0] values_1 = values_1[0] self.assertTrue( values_0.shape == values_1.shape, f"Layer {layer_name}, {model_name_0} and {model_name_1} " + f"have a shape mismatch at idx {idx}.", ) else: assert isinstance( values_0, tuple ), f"unhandled type {type(values_0)}" assert len(values_0) == 2 assert len(values_0[1]) == 2 assert values_0[0].shape == values_1[0].shape assert values_0[1][0].shape == values_1[1][0].shape assert values_0[1][1].shape == values_1[1][1].shape # verify that ref_node_name is valid ref_node_name_0 = layer_data_0["ref_node_name"] ref_node_name_1 = layer_data_1["ref_node_name"] prev_node_name_0 = layer_data_0["prev_node_name"] prev_node_name_1 = layer_data_1["prev_node_name"] if ( layer_data_0["type"] == NSSingleResultValuesType.NODE_OUTPUT.value ): self.assertTrue(ref_node_name_0 == prev_node_name_0) self.assertTrue(ref_node_name_1 == prev_node_name_1) elif ( layer_data_0["type"] == NSSingleResultValuesType.NODE_INPUT.value ): self.assertTrue(ref_node_name_0 != prev_node_name_0) self.assertTrue(ref_node_name_1 != prev_node_name_1) def checkGraphModeFxOp( self, model, inputs, quant_type, expected_node=None, expected_node_occurrence=None, expected_node_list=None, is_reference=False, print_debug_info=False, custom_qconfig_dict=None, prepare_expected_node=None, prepare_expected_node_occurrence=None, prepare_expected_node_list=None, prepare_custom_config=None, backend_config=None, ): """Quantizes model with graph mode quantization on fx and check if the quantized model contains the quantized_node Args: model: floating point torch.nn.Module inputs: one positional sample input arguments for model expected_node: NodeSpec e.g. NodeSpec.call_function(torch.quantize_per_tensor) expected_node_occurrence: a dict from NodeSpec to expected number of occurrences (int) e.g. {NodeSpec.call_function(torch.quantize_per_tensor) : 1, NodeSpec.call_method('dequantize'): 1} expected_node_list: a list of NodeSpec, used to check the order of the occurrence of Node e.g. 
[NodeSpec.call_function(torch.quantize_per_tensor), NodeSpec.call_module(nnq.Conv2d), NodeSpec.call_function(F.hardtanh_), NodeSpec.call_method('dequantize')] is_reference: if True, enables reference mode print_debug_info: if True, prints debug info custom_qconfig_dict: overrides default qconfig_dict prepare_expected_node: same as expected_node, but for prepare prepare_expected_node_occurrence: same as expected_node_occurrence, but for prepare prepare_expected_node_list: same as expected_node_list, but for prepare Returns: A dictionary with the following structure: { "prepared": ..., # the prepared model "quantized": ..., # the quantized non-reference model "quantized_reference": ..., # the quantized reference model "result": ..., # the result for either quantized or # quantized_reference model depending on the # is_reference argument } """ # TODO: make img_data a single example instead of a list if type(inputs) == list: inputs = inputs[0] if quant_type == QuantType.QAT: qconfig_mapping = get_default_qat_qconfig_mapping( torch.backends.quantized.engine ) model.train() elif quant_type == QuantType.STATIC: qconfig_mapping = get_default_qconfig_mapping( torch.backends.quantized.engine ) model.eval() else: qconfig = default_dynamic_qconfig qconfig_mapping = QConfigMapping().set_global(qconfig) model.eval() if quant_type == QuantType.QAT: prepare = prepare_qat_fx else: prepare = prepare_fx # overwrite qconfig_dict with custom_qconfig_dict if custom_qconfig_dict is not None: assert type(custom_qconfig_dict) in ( QConfigMapping, dict, ), "custom_qconfig_dict should be a QConfigMapping or a dict" if isinstance(custom_qconfig_dict, QConfigMapping): qconfig_mapping = custom_qconfig_dict else: qconfig_mapping = QConfigMapping.from_dict(custom_qconfig_dict) prepared = prepare( model, qconfig_mapping, example_inputs=inputs, prepare_custom_config=prepare_custom_config, backend_config=backend_config, ) if not quant_type == QuantType.DYNAMIC: prepared(*inputs) if print_debug_info: print() print("quant type:\n", quant_type) print("original model:\n", model) print() print("prepared model:\n", prepared) self.checkGraphModuleNodes( prepared, prepare_expected_node, prepare_expected_node_occurrence, prepare_expected_node_list, ) prepared_copy = copy.deepcopy(prepared) qgraph = convert_fx(copy.deepcopy(prepared)) qgraph_reference = convert_to_reference_fx(copy.deepcopy(prepared)) result = qgraph(*inputs) result_reference = qgraph_reference(*inputs) qgraph_copy = copy.deepcopy(qgraph) qgraph_reference_copy = copy.deepcopy(qgraph_reference) qgraph_to_check = qgraph_reference if is_reference else qgraph if print_debug_info: print() print("quantized model:\n", qgraph_to_check) self.printGraphModule(qgraph_to_check) print() self.checkGraphModuleNodes( qgraph_to_check, expected_node, expected_node_occurrence, expected_node_list, ) return { "prepared": prepared_copy, "quantized": qgraph_copy, "quantized_reference": qgraph_reference_copy, "quantized_output": result, "quantized_reference_output": result_reference, } def checkEmbeddingSerialization( self, qemb, num_embeddings, embedding_dim, indices, offsets, set_qconfig, is_emb_bag, dtype=torch.quint8, ): # Test serialization of dynamic EmbeddingBag module using state_dict if is_emb_bag: inputs = [indices, offsets] else: inputs = [indices] emb_dict = qemb.state_dict() b = io.BytesIO() torch.save(emb_dict, b) b.seek(0) loaded_dict = torch.load(b) embedding_unpack = torch.ops.quantized.embedding_bag_unpack # Check unpacked weight values explicitly for key in emb_dict: if 
isinstance(emb_dict[key], torch._C.ScriptObject): assert isinstance(loaded_dict[key], torch._C.ScriptObject) emb_weight = embedding_unpack(emb_dict[key]) loaded_weight = embedding_unpack(loaded_dict[key]) self.assertEqual(emb_weight, loaded_weight) # Check state dict serialization and torch.save APIs if is_emb_bag: loaded_qemb = nnq.EmbeddingBag( num_embeddings=num_embeddings, embedding_dim=embedding_dim, include_last_offset=True, mode="sum", dtype=dtype, ) else: loaded_qemb = nnq.Embedding( num_embeddings=num_embeddings, embedding_dim=embedding_dim, dtype=dtype ) self.check_eager_serialization(qemb, loaded_qemb, inputs) loaded_qemb.load_state_dict(loaded_dict) self.assertEqual( embedding_unpack(qemb._packed_params._packed_weight), embedding_unpack(loaded_qemb._packed_params._packed_weight), ) # Test JIT serialization self.checkScriptable(qemb, [inputs], check_save_load=True) # Test from_float call if is_emb_bag: float_embedding = torch.nn.EmbeddingBag( num_embeddings=num_embeddings, embedding_dim=embedding_dim, include_last_offset=True, scale_grad_by_freq=False, mode="sum", ) else: float_embedding = torch.nn.Embedding( num_embeddings=num_embeddings, embedding_dim=embedding_dim ) if set_qconfig: float_qparams_observer = PerChannelMinMaxObserver.with_args( dtype=dtype, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0 ) float_embedding.qconfig = QConfig( activation=default_dynamic_quant_observer, weight=float_qparams_observer ) prepare_dynamic(float_embedding) float_embedding(*inputs) if is_emb_bag: q_embeddingbag = nnq.EmbeddingBag.from_float(float_embedding) expected_name = "QuantizedEmbeddingBag" else: q_embeddingbag = nnq.Embedding.from_float(float_embedding) expected_name = "QuantizedEmbedding" q_embeddingbag(*inputs) self.assertTrue(expected_name in str(q_embeddingbag)) class QuantizationLiteTestCase(QuantizationTestCase): def _create_quantized_model(self, model_class: type[torch.nn.Module], **kwargs): # Creates quantized model for testing mobile script modules qengine = "qnnpack" with override_quantized_engine(qengine): # FIXME(rec): shouldn't qconfig be passed to quantize? qconfig = torch.ao.quantization.get_default_qconfig(qengine) # noqa: F841 model = model_class(**kwargs) model = quantize(model, test_only_eval_fn, [self.calib_data]) return model def _compare_script_and_mobile(self, model: torch.nn.Module, input: torch.Tensor): # Compares the numerical outputs for script and lite modules qengine = "qnnpack" with override_quantized_engine(qengine): script_module = torch.jit.script(model) script_module_result = script_module(input) max_retry = 5 for retry in range(1, max_retry + 1): # retries `max_retry` times; breaks iff succeeds else throws exception try: buffer = io.BytesIO( script_module._save_to_buffer_for_lite_interpreter() ) buffer.seek(0) mobile_module = _load_for_lite_interpreter(buffer) mobile_module_result = mobile_module(input) torch.testing.assert_close( script_module_result, mobile_module_result ) mobile_module_forward_result = mobile_module.forward(input) torch.testing.assert_close( script_module_result, mobile_module_forward_result ) mobile_module_run_method_result = mobile_module.run_method( "forward", input ) torch.testing.assert_close( script_module_result, mobile_module_run_method_result ) except AssertionError as e: if retry == max_retry: raise e else: continue break class PT2EQuantizationTestCase(QuantizationTestCase): """ Base QuantizationTestCase for PT2 with some helper methods. 
""" _MAP_TO_FX_TRACED_OPS = { torch.ops.quantized_decomposed.quantize_per_tensor: torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor: torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.quantized_decomposed.quantize_per_channel: torch.ops.quantized_decomposed.quantize_per_channel.default, torch.ops.quantized_decomposed.dequantize_per_channel: torch.ops.quantized_decomposed.dequantize_per_channel.default, torch.ops.quantized_decomposed.quantize_per_tensor.tensor: torch.ops.quantized_decomposed.quantize_per_tensor.tensor, torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, } def _test_quantizer( self, model, example_inputs, quantizer, expected_node_occurrence, expected_node_list=None, check_against_fx_quant=False, fx_qconfig_mapping=None, export_with_dynamic_shape=False, is_qat=False, is_debug_mode=False, training_ir_node_occurrence=None, ): # resetting dynamo cache torch._dynamo.reset() m_eager = model.eval() # program capture m = copy.deepcopy(m_eager) dynamic_shapes = tuple( {0: torch.export.Dim("dim")} if i == 0 else None for i in range(len(example_inputs)) ) m = export_for_training( m, example_inputs, dynamic_shapes=dynamic_shapes if export_with_dynamic_shape else None, strict=True, ).module() if is_qat: m = prepare_qat_pt2e(m, quantizer) else: m = prepare_pt2e(m, quantizer) if is_debug_mode: print("prepared model:", m) # Calibrate m(*example_inputs) m = convert_pt2e(m) if is_debug_mode: print("quantized model", m) pt2_quant_output = m(*example_inputs) ns = NodeSpec node_occurrence = { ns.call_function(k): v for k, v in expected_node_occurrence.items() } if expected_node_list is None: expected_node_list = [] node_list = [ns.call_function(n) for n in expected_node_list] self.checkGraphModuleNodes( m, expected_node_occurrence=node_occurrence, expected_node_list=node_list ) if check_against_fx_quant: qconfig_mapping = fx_qconfig_mapping backend_config = get_executorch_backend_config() m_copy = copy.deepcopy(m_eager) m_fx = prepare_fx( m_copy, qconfig_mapping, example_inputs, backend_config=backend_config ) m_fx(*example_inputs) m_fx = _convert_to_reference_decomposed_fx( m_fx, backend_config=backend_config ) m_fx = export_for_training( m_fx, example_inputs, dynamic_shapes=dynamic_shapes if export_with_dynamic_shape else None, strict=True, ).module() node_occurrence = {} for k, v in PT2EQuantizationTestCase._MAP_TO_FX_TRACED_OPS.items(): if k in expected_node_occurrence: node_occurrence[ns.call_function(v)] = expected_node_occurrence[k] if training_ir_node_occurrence is not None: node_occurrence = { ns.call_function(k): v for k, v in training_ir_node_occurrence.items() } self.checkGraphModuleNodes(m_fx, expected_node_occurrence=node_occurrence) fx_quant_output = m_fx(*example_inputs) self.assertEqual(fx_quant_output, pt2_quant_output) return m def _quantize(self, m, quantizer, example_inputs, is_qat: bool = False): # resetting dynamo cache torch._dynamo.reset() m = export_for_training(m, example_inputs, strict=True).module() if is_qat: m = prepare_qat_pt2e(m, quantizer) else: m = prepare_pt2e(m, quantizer) m(*example_inputs) m = convert_pt2e(m) return m def _get_pt2e_quantized_linear(self, is_per_channel=False) -> torch.fx.GraphModule: class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear = torch.nn.Linear(2, 2) def forward(self, x): return self.linear(x) quantizer = XNNPACKQuantizer() operator_config = 
get_symmetric_quantization_config( is_per_channel=is_per_channel ) quantizer.set_global(operator_config) example_inputs = (torch.randn(2, 2),) m = M().eval() return self._quantize(m, quantizer, example_inputs) # Below are a series of toy models to use in testing quantization class SingleLayerLinearModel(torch.nn.Module): def __init__(self) -> None: super().__init__() self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float) def forward(self, x): x = self.fc1(x) return x def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 5),) class AnnotatedSingleLayerLinearModel(torch.nn.Module): def __init__(self, qengine="fbgemm"): super().__init__() self.qconfig = torch.ao.quantization.get_default_qconfig(qengine) self.fc1 = QuantWrapper(torch.nn.Linear(5, 5).to(dtype=torch.float)) def forward(self, x): x = self.fc1(x) return x def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 5),) class SingleLayerLinearDynamicModel(torch.nn.Module): def __init__(self, qengine="fbgemm"): super().__init__() self.qconfig = torch.ao.quantization.get_default_qconfig(qengine) self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float) def forward(self, x): x = self.fc1(x) return x def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 5),) class LinearAddModel(nn.Module): def __init__(self) -> None: super().__init__() self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float) self.fc2 = torch.nn.Linear(8, 5).to(dtype=torch.float) def forward(self, x): x = self.fc1(x) x = torch.add(x, 5) x = self.fc2(x) return x def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 5),) class RNNDynamicModel(torch.nn.Module): def __init__(self, mod_type): super().__init__() self.qconfig = default_dynamic_qconfig if mod_type == "GRU": self.mod = torch.nn.GRU(2, 2).to(dtype=torch.float) if mod_type == "LSTM": self.mod = torch.nn.LSTM(2, 2).to(dtype=torch.float) def forward(self, x): x = self.mod(x) return x class RNNCellDynamicModel(torch.nn.Module): def __init__(self, mod_type): super().__init__() self.qconfig = default_dynamic_qconfig if mod_type == "GRUCell": self.mod = torch.nn.GRUCell(2, 2).to(dtype=torch.float) if mod_type == "LSTMCell": self.mod = torch.nn.LSTMCell(2, 2).to(dtype=torch.float) if mod_type == "RNNReLU": self.mod = torch.nn.RNNCell(2, 2, nonlinearity="relu").to(dtype=torch.float) if mod_type == "RNNTanh": self.mod = torch.nn.RNNCell(2, 2, nonlinearity="tanh").to(dtype=torch.float) def forward(self, x): x = self.mod(x) return x class LSTMwithHiddenDynamicModel(torch.nn.Module): def __init__(self, qengine="fbgemm"): super().__init__() self.qconfig = torch.ao.quantization.get_default_qconfig(qengine) self.lstm = torch.nn.LSTM(2, 2).to(dtype=torch.float) def forward(self, x, hid): x, hid = self.lstm(x, hid) return x, hid class ConvModel(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float) def forward(self, x): x = self.conv(x) return x def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 3, 5, 5),) class ConvTransposeModel(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv = torch.nn.ConvTranspose2d(3, 5, 3, bias=False).to(dtype=torch.float) def forward(self, x): x = self.conv(x) return x def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 3, 5, 5),) class AnnotatedConvModel(torch.nn.Module): def __init__(self, qengine): super().__init__() self.qconfig = torch.ao.quantization.get_default_qconfig(qengine) self.conv = 
torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float) self.quant = QuantStub() self.dequant = DeQuantStub() def forward(self, x): x = self.quant(x) x = self.conv(x) x = self.dequant(x) return x def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 3, 5, 5),) class AnnotatedConvTransposeModel(torch.nn.Module): def __init__(self, qengine): super().__init__() self.qconfig = torch.ao.quantization.get_default_qconfig(qengine) self.conv = torch.nn.ConvTranspose2d(3, 5, 3, bias=False).to(dtype=torch.float) self.quant = QuantStub() self.dequant = DeQuantStub() def forward(self, x): x = self.quant(x) x = self.conv(x) x = self.dequant(x) return x def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 3, 5, 5),) class ConvBnModel(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float) self.bn = torch.nn.BatchNorm2d(5).to(dtype=torch.float) def forward(self, x): x = self.conv(x) x = self.bn(x) return x def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 3, 5, 5),) class AnnotatedConvBnModel(torch.nn.Module): def __init__(self) -> None: super().__init__() self.qconfig = default_qconfig self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float) self.bn = torch.nn.BatchNorm2d(5).to(dtype=torch.float) self.quant = QuantStub() self.dequant = DeQuantStub() def forward(self, x): x = self.quant(x) x = self.conv(x) x = self.bn(x) x = self.dequant(x) return x def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 3, 5, 5),) class ConvBnReLUModel(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float) self.bn = torch.nn.BatchNorm2d(5).to(dtype=torch.float) self.relu = nn.ReLU(inplace=True) def forward(self, x): x = self.conv(x) x = self.bn(x) x = self.relu(x) return x def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 3, 5, 5),) class AnnotatedConvBnReLUModel(torch.nn.Module): def __init__(self, qengine="fbgemm"): super().__init__() self.qconfig = torch.ao.quantization.get_default_qconfig(qengine) self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float) self.bn = torch.nn.BatchNorm2d(5).to(dtype=torch.float) self.relu = nn.ReLU(inplace=True) self.quant = QuantStub() self.dequant = DeQuantStub() def forward(self, x): x = self.quant(x) x = self.conv(x) x = self.bn(x) x = self.relu(x) x = self.dequant(x) return x def fuse_model(self): # TODO: remove this check and define two fuse_modules function on this module if self.training: torch.ao.quantization.fuse_modules_qat( self, [["conv", "bn", "relu"]], inplace=True ) else: torch.ao.quantization.fuse_modules( self, [["conv", "bn", "relu"]], inplace=True ) def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 3, 5, 5),) class TwoLayerConvModel(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv1 = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float) self.conv2 = torch.nn.Conv2d(5, 5, 1, bias=False).to(dtype=torch.float) def forward(self, x): x = self.conv1(x) x = self.conv2(x) return x def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 3, 5, 5),) class TwoLayerLinearModel(torch.nn.Module): def __init__(self) -> None: super().__init__() self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float) self.fc2 = torch.nn.Linear(8, 5).to(dtype=torch.float) def forward(self, x): x = self.fc1(x) x = self.fc2(x) return x def 
get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 5),) class LinearModelWithSubmodule(nn.Module): def __init__(self) -> None: super().__init__() self.subm = TwoLayerLinearModel() self.fc = nn.Linear(5, 5) def forward(self, x): x = self.subm(x) x = self.fc(x) return x def get_example_inputs(self) -> tuple[Any, ...]: return self.subm.get_example_inputs() class AnnotatedTwoLayerLinearModel(torch.nn.Module): def __init__(self) -> None: super().__init__() self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float) self.fc2 = QuantWrapper(torch.nn.Linear(8, 5).to(dtype=torch.float)) self.fc2.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm") def forward(self, x): x = self.fc1(x) x = self.fc2(x) return x def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 5),) class ActivationsTestModel(torch.nn.Module): def __init__(self) -> None: super().__init__() self.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm") self.quant = torch.ao.quantization.QuantStub() self.hardswish = torch.nn.Hardswish().to(dtype=torch.float) self.elu = torch.nn.ELU().to(dtype=torch.float) self.dequant = torch.ao.quantization.DeQuantStub() def forward(self, x): x = self.quant(x) x = self.hardswish(x) x = self.elu(x) x = self.dequant(x) return x class LinearReluModel(torch.nn.Module): def __init__(self) -> None: super().__init__() self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float) self.relu = torch.nn.ReLU() def forward(self, x): x = self.relu(self.fc(x)) return x def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 5),) class LinearReluLinearModel(torch.nn.Module): def __init__(self) -> None: super().__init__() self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float) self.relu = torch.nn.ReLU() self.fc2 = torch.nn.Linear(8, 5).to(dtype=torch.float) def forward(self, x): x = self.fc1(x) x = self.relu(x) x = self.fc2(x) return x def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 5),) class LinearReluAddModel(torch.nn.Module): def __init__(self) -> None: super().__init__() self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float) self.relu = torch.nn.ReLU() self.fc2 = torch.nn.Linear(5, 5).to(dtype=torch.float) def forward(self, x): x = self.fc1(x) x = self.relu(x) x = torch.add(x, 5) x = self.fc2(x) self.relu = torch.nn.ReLU() return x def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 5),) class LinearBnLeakyReluModel(torch.nn.Module): def __init__(self, with_bn=True): super().__init__() self.linear = nn.Linear(5, 5) self.bn1d = nn.BatchNorm1d(5) self.leaky_relu = nn.LeakyReLU(0.01) self.with_bn = with_bn def forward(self, x): x = self.linear(x) if self.with_bn: x = self.bn1d(x) x = self.leaky_relu(x) return x def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 5),) class LinearTanhModel(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear = nn.Linear(5, 5) self.tanh = nn.Tanh() def forward(self, x): x = self.linear(x) x = self.tanh(x) return x def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 5),) class ConvBnAddReluModel(torch.nn.Module): def __init__( self, with_bn=True, with_relu=True, left_conv=True, two_conv=True, use_torch_add=True, ): super().__init__() self.conv = nn.Conv2d(5, 5, (2, 2)) self.conv2 = nn.Conv2d(5, 5, (2, 2)) self.bn = nn.BatchNorm2d(5) self.relu = nn.ReLU() self.with_bn = with_bn self.with_relu = with_relu self.two_conv = two_conv self.left_conv = left_conv self.use_torch_add = use_torch_add def forward(self, x1, x2): if self.two_conv: 
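            # two_conv=True: add the outputs of self.conv and self.conv2 applied to
            # x1 (optionally with batch norm on the first conv), using either
            # torch.add or the "+" operator depending on use_torch_add.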
if self.use_torch_add: if self.with_bn: x = torch.add(self.bn(self.conv(x1)), self.conv2(x1)) else: x = torch.add(self.conv(x1), self.conv2(x1)) else: if self.with_bn: x = self.bn(self.conv(x1)) + self.conv2(x1) else: x = self.conv(x1) + self.conv2(x1) else: if self.use_torch_add: if self.left_conv: if self.with_bn: x = torch.add(self.bn(self.conv(x1)), x2) else: x = torch.add(self.conv(x1), x2) else: if self.with_bn: x = torch.add(x2, self.bn(self.conv(x1))) else: x = torch.add(x2, self.conv(x1)) else: if self.left_conv: if self.with_bn: x = self.bn(self.conv(x1)) + x2 else: x = self.conv(x1) + x2 else: if self.with_bn: x = x2 + self.bn(self.conv(x1)) else: x = x2 + self.conv(x1) if self.with_relu: x = self.relu(x) return x def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 5, 3, 3), torch.rand(1, 5, 2, 2)) # TODO: self.fc should be self.conv class ConvReluModel(torch.nn.Module): def __init__(self) -> None: super().__init__() self.fc = torch.nn.Conv2d(3, 5, 3).to(dtype=torch.float) self.relu = torch.nn.ReLU() def forward(self, x): x = self.relu(self.fc(x)) return x def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 3, 5, 5),) # TODO: self.fc should be self.conv class ConvReluConvModel(torch.nn.Module): def __init__(self) -> None: super().__init__() self.fc1 = torch.nn.Conv2d(3, 5, 3).to(dtype=torch.float) self.relu = torch.nn.ReLU() self.fc2 = torch.nn.Conv2d(5, 5, 1).to(dtype=torch.float) def forward(self, x): x = self.fc1(x) x = self.relu(x) x = self.fc2(x) return x def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 3, 5, 5),) # TODO: self.fc should be self.conv class ConvReluAddModel(torch.nn.Module): def __init__(self) -> None: super().__init__() self.fc1 = torch.nn.Conv2d(3, 5, 3).to(dtype=torch.float) self.relu = torch.nn.ReLU() self.fc2 = torch.nn.Conv2d(5, 5, 1).to(dtype=torch.float) def forward(self, x): x = self.fc1(x) x = self.relu(x) x = torch.add(x, 5) x = self.fc2(x) self.relu = torch.nn.ReLU() return x def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 3, 5, 5),) class NormalizationTestModel(torch.nn.Module): def __init__(self) -> None: super().__init__() self.quant = torch.ao.quantization.QuantStub() self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float) self.layer_norm = torch.nn.LayerNorm(8) self.group_norm = torch.nn.GroupNorm(2, 8) self.instance_norm1d = torch.nn.InstanceNorm1d(8) self.instance_norm2d = torch.nn.InstanceNorm2d(8) self.instance_norm3d = torch.nn.InstanceNorm3d(8) def forward(self, x): x = self.quant(x) x = self.fc1(x) x = self.layer_norm(x) x = self.group_norm(x.unsqueeze(-1).repeat(1, 1, 3)) x = self.instance_norm1d(x) x = self.instance_norm2d(x.unsqueeze(-1)) x = self.instance_norm3d(x.unsqueeze(-1)) return x class NestedModel(torch.nn.Module): def __init__(self) -> None: super().__init__() self.sub1 = LinearReluModel() self.sub2 = TwoLayerLinearModel() self.fc3 = torch.nn.Linear(5, 5).to(dtype=torch.float) def forward(self, x): x = self.sub1(x) x = self.sub2(x) x = self.fc3(x) return x class AnnotatedNestedModel(torch.nn.Module): def __init__(self, qengine): super().__init__() self.sub1 = LinearReluModel() self.sub2 = TwoLayerLinearModel() self.fc3 = QuantWrapper(torch.nn.Linear(5, 5).to(dtype=torch.float)) self.fc3.qconfig = default_qconfig self.sub2.fc1 = QuantWrapper(self.sub2.fc1) if qengine == "fbgemm": self.sub2.fc1.qconfig = default_per_channel_qconfig else: self.sub2.fc1.qconfig = default_qconfig def forward(self, x): x = self.sub1(x) x = self.sub2(x) x = 
self.fc3(x) return x class AnnotatedSubNestedModel(torch.nn.Module): def __init__(self) -> None: super().__init__() self.sub1 = LinearReluModel() self.sub2 = QuantWrapper(TwoLayerLinearModel()) self.fc3 = QuantWrapper(torch.nn.Linear(5, 5).to(dtype=torch.float)) self.fc3.qconfig = default_qconfig self.sub2.qconfig = default_qconfig def forward(self, x): x = self.sub1(x) x = self.sub2(x) x = self.fc3(x) return x class AnnotatedCustomConfigNestedModel(torch.nn.Module): def __init__(self) -> None: super().__init__() self.sub1 = LinearReluModel() self.sub2 = TwoLayerLinearModel() self.fc3 = QuantWrapper(torch.nn.Linear(5, 5).to(dtype=torch.float)) self.fc3.qconfig = default_qconfig self.sub2.qconfig = default_qconfig custom_options = {"dtype": torch.quint8, "qscheme": torch.per_tensor_affine} custom_qconfig = QConfig( activation=default_observer.with_args(**custom_options), weight=default_weight_observer, ) self.sub2.fc1.qconfig = custom_qconfig self.sub2.fc1 = QuantWrapper(self.sub2.fc1) self.sub2.fc2 = QuantWrapper(self.sub2.fc2) def forward(self, x): x = self.sub1(x) x = self.sub2(x) x = self.fc3(x) return x class QuantSubModel(torch.nn.Module): def __init__(self) -> None: super().__init__() self.sub1 = LinearReluModel() self.sub2 = QuantWrapper(TwoLayerLinearModel()) self.sub2.qconfig = default_qconfig self.fc3 = torch.nn.Linear(5, 5).to(dtype=torch.float) self.fc3.qconfig = default_qconfig def forward(self, x): x = self.sub1(x) x = self.sub2(x) x = self.fc3(x) return x class InnerModule(torch.nn.Module): def __init__(self) -> None: super().__init__() self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float) self.relu1 = torch.nn.ReLU() self.fc2 = torch.nn.Linear(8, 5).to(dtype=torch.float) self.relu2 = torch.nn.ReLU() def forward(self, x): return self.relu2(self.fc2(self.relu1(self.fc1(x)))) def fuse_modules(self): fusable_layers = [] named_children = list(self.named_children()) for idx, (current_name, layer) in enumerate(named_children): if isinstance(layer, torch.nn.Linear): if idx >= len(named_children) - 1: break if isinstance(named_children[idx + 1][1], torch.nn.ReLU): fusable_layers.append([current_name, named_children[idx + 1][0]]) # TODO: remove this check and define two fuse_modules function on this module if self.training: torch.ao.quantization.fuse_modules_qat(self, fusable_layers, inplace=True) else: torch.ao.quantization.fuse_modules(self, fusable_layers, inplace=True) class FunctionalLinear(torch.nn.Module): def __init__(self) -> None: super().__init__() self.weight = torch.rand((5, 5)) self.bias = torch.zeros(5) def forward(self, x): return F.linear(x, self.weight, self.bias) def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 5),) class SingleLayerFunctionalLinearModel(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear1 = FunctionalLinear() def forward(self, x): x = self.linear1(x) return x def get_example_inputs(self) -> tuple[Any, ...]: return self.linear1.get_example_inputs() class TwoLayerFunctionalLinearModel(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear1 = FunctionalLinear() self.linear2 = FunctionalLinear() def forward(self, x): x = self.linear1(x) x = self.linear2(x) return x def get_example_inputs(self) -> tuple[Any, ...]: return self.linear1.get_example_inputs() class FunctionalLinearAddModel(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear1 = FunctionalLinear() self.linear2 = FunctionalLinear() def forward(self, x): x = self.linear1(x) x = torch.add(x, 5) x = 
self.linear2(x) return x def get_example_inputs(self) -> tuple[Any, ...]: return self.linear1.get_example_inputs() class FunctionalLinearReluModel(nn.Module): def __init__(self) -> None: super().__init__() self.linear = FunctionalLinear() def forward(self, x): x = self.linear(x) x = F.relu(x) return x def get_example_inputs(self) -> tuple[Any, ...]: return self.linear.get_example_inputs() class FunctionalLinearReluLinearModel(nn.Module): def __init__(self) -> None: super().__init__() self.linear1 = FunctionalLinear() self.relu = nn.ReLU() self.linear2 = FunctionalLinear() def forward(self, x): x = self.linear1(x) x = self.relu(x) x = self.linear2(x) return x def get_example_inputs(self) -> tuple[Any, ...]: return self.linear1.get_example_inputs() class FunctionalConv2d(torch.nn.Module): def __init__(self) -> None: super().__init__() self.weight = torch.rand(3, 3, 3, 3) self.bias = torch.rand(3) self.stride = (1, 1) self.padding = (0, 0) self.dilation = (1, 1) self.groups = 1 def forward(self, x): return F.conv2d( x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups, ) def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 3, 5, 5),) class SingleLayerFunctionalConvModel(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv1 = FunctionalConv2d() def forward(self, x): x = self.conv1(x) return x def get_example_inputs(self) -> tuple[Any, ...]: return self.conv1.get_example_inputs() class TwoLayerFunctionalConvModel(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv1 = FunctionalConv2d() self.conv2 = FunctionalConv2d() def forward(self, x): x = self.conv1(x) x = self.conv2(x) return x def get_example_inputs(self) -> tuple[Any, ...]: return self.conv1.get_example_inputs() class FunctionalConvReluModel(nn.Module): def __init__(self) -> None: super().__init__() self.conv = FunctionalConv2d() def forward(self, x): x = self.conv(x) x = F.relu(x) return x def get_example_inputs(self) -> tuple[Any, ...]: return self.conv.get_example_inputs() class FunctionalConvReluConvModel(nn.Module): def __init__(self) -> None: super().__init__() self.conv1 = FunctionalConv2d() self.relu = nn.ReLU() self.conv2 = FunctionalConv2d() def forward(self, x): x = self.conv1(x) x = self.relu(x) x = self.conv2(x) return x def get_example_inputs(self) -> tuple[Any, ...]: return self.conv1.get_example_inputs() class SkipQuantModel(torch.nn.Module): r"""We can skip quantization by explicitly setting qconfig of a submodule to None """ def __init__(self) -> None: super().__init__() self.sub = InnerModule() self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float) def forward(self, x): return self.fc(self.sub(x)) def fuse_modules(self): self.sub.fuse_modules() class AnnotatedSkipQuantModel(torch.nn.Module): r"""We can skip quantization by explicitly setting qconfig of a submodule to None """ def __init__(self, qengine): super().__init__() self.qconfig = torch.ao.quantization.get_default_qconfig(qengine) self.sub = QuantWrapper(InnerModule()) self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float) # don't quantize this fc self.fc.qconfig = None def forward(self, x): return self.fc(self.sub(x)) def fuse_modules(self): self.sub.module.fuse_modules() class QuantStubModel(torch.nn.Module): r"""A Module with manually inserted `QuantStub` and `DeQuantStub`""" def __init__(self) -> None: super().__init__() self.qconfig = torch.ao.quantization.get_default_qconfig("qnnpack") self.quant = QuantStub() self.dequant = DeQuantStub() self.fc = torch.nn.Linear(5, 
5).to(dtype=torch.float) def forward(self, x): x = self.quant(x) x = self.fc(x) return self.dequant(x) class ManualLinearQATModel(torch.nn.Module): r"""A Module with manually inserted `QuantStub` and `DeQuantStub`""" def __init__(self, qengine): super().__init__() self.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine) self.quant = QuantStub() self.dequant = DeQuantStub() self.fc1 = torch.nn.Linear(5, 1).to(dtype=torch.float) self.fc2 = torch.nn.Linear(1, 10).to(dtype=torch.float) def forward(self, x): x = self.quant(x) x = self.fc1(x) x = self.fc2(x) return self.dequant(x) class ManualDropoutQATModel(torch.nn.Module): r"""A Module with manually inserted `QuantStub` and `DeQuantStub`""" def __init__(self, qengine): super().__init__() self.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine) self.quant = QuantStub() self.dequant = DeQuantStub() self.fc1 = torch.nn.Linear(5, 1).to(dtype=torch.float) self.dropout = torch.nn.Dropout(0.5) def forward(self, x): x = self.quant(x) x = self.fc1(x) x = self.dropout(x) return self.dequant(x) class ManualLinearDynamicQATModel(torch.nn.Module): r"""A Module that uses a dynamic QAT by default.""" def __init__(self, qconfig=None): super().__init__() self.qconfig = qconfig or default_dynamic_qat_qconfig self.fc1 = torch.nn.Linear(5, 1).to(dtype=torch.float) self.fc2 = torch.nn.Linear(1, 10).to(dtype=torch.float) def forward(self, x): x = self.fc1(x) x = self.fc2(x) return x class ManualConvLinearQATModel(torch.nn.Module): r"""A module with manually inserted `QuantStub` and `DeQuantStub` and contains both linear and conv modules """ def __init__(self, qconfig=None): super().__init__() self.qconfig = ( qconfig if qconfig else torch.ao.quantization.get_default_qat_qconfig("qnnpack") ) self.quant = QuantStub() self.dequant = DeQuantStub() self.conv = torch.nn.Conv2d(3, 1, kernel_size=3).to(dtype=torch.float) self.fc1 = torch.nn.Linear(64, 10).to(dtype=torch.float) self.fc2 = torch.nn.Linear(10, 10).to(dtype=torch.float) def forward(self, x): x = self.quant(x) x = self.conv(x) x = x.view(-1, 64).contiguous() x = self.fc1(x) x = self.fc2(x) return self.dequant(x) class ManualConvLinearSymmQATModel(ManualConvLinearQATModel): r"""Same as ManualConvLinearQATModule but with Symmetric Quantization. Supported only with qnnpack. """ def __init__(self) -> None: super().__init__(default_symmetric_qnnpack_qat_qconfig) class ManualEmbeddingBagLinear(nn.Module): def __init__(self) -> None: super().__init__() self.emb = nn.EmbeddingBag(num_embeddings=10, embedding_dim=12, mode="sum") self.emb.qconfig = default_embedding_qat_qconfig self.quant = QuantStub() self.dequant = DeQuantStub() self.linear = nn.Linear(12, 1).to(dtype=torch.float) self.qconfig = get_default_qat_qconfig("qnnpack") def forward( self, input: torch.Tensor, offsets: Optional[torch.Tensor] = None, per_sample_weights: Optional[torch.Tensor] = None, ): x = self.emb(input, offsets, per_sample_weights) x = self.quant(x) x = self.linear(x) return self.dequant(x) class DeFusedEmbeddingBagLinear(nn.Module): r"""A module to simulate QAT embedding bag with a linear layer, this module uses a separate embedding and bagging op, similar to that which is described in the EmbeddingBag documentation. 
https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html """ def __init__(self) -> None: super().__init__() self.emb = nn.Embedding(num_embeddings=10, embedding_dim=12) self.emb.qconfig = default_embedding_qat_qconfig self.bagging_op = torch.sum self.quant = QuantStub() self.dequant = DeQuantStub() self.linear = nn.Linear(12, 1).to(dtype=torch.float) self.qconfig = get_default_qat_qconfig("qnnpack") def forward(self, input: torch.Tensor) -> torch.Tensor: x = self.bagging_op(self.emb(input), dim=1) x = self.quant(x) x = self.linear(x) return self.dequant(x) class SubModelForFusion(nn.Module): def __init__(self) -> None: super().__init__() self.conv = nn.Conv2d(2, 2, 1, bias=None).to(dtype=torch.float) self.bn = nn.BatchNorm2d(2).to(dtype=torch.float) def forward(self, x): x = self.conv(x) x = self.bn(x) return x class SubModelWithoutFusion(nn.Module): def __init__(self) -> None: super().__init__() self.conv = nn.Conv2d(2, 2, 1, bias=None).to(dtype=torch.float) self.relu = nn.ReLU(inplace=False).to(dtype=torch.float) def forward(self, x): return self.relu(self.conv(x)) class ModelForFusion(nn.Module): def __init__(self, qconfig): super().__init__() self.conv1 = nn.Conv2d(3, 2, 1, bias=None).to(dtype=torch.float) self.bn1 = nn.BatchNorm2d(2).to(dtype=torch.float) self.relu1 = nn.ReLU(inplace=True).to(dtype=torch.float) self.sub1 = SubModelForFusion() self.sub2 = SubModelWithoutFusion() self.fc = nn.Linear(36, 10).to(dtype=torch.float) self.quant = QuantStub() self.dequant = DeQuantStub() self.qconfig = qconfig self.conv2 = nn.Conv3d(3, 2, (1, 1, 1), bias=None).to(dtype=torch.float) self.relu2 = nn.ReLU(inplace=False).to(dtype=torch.float) self.bn2 = nn.BatchNorm3d(2).to(dtype=torch.float) self.relu3 = nn.ReLU(inplace=True).to(dtype=torch.float) self.conv3 = nn.Conv1d(3, 3, 2).to(dtype=torch.float) self.bn3 = nn.BatchNorm1d(3).to(dtype=torch.float) self.relu4 = nn.ReLU(inplace=True).to(dtype=torch.float) # don't quantize sub2 self.sub2.qconfig = None self.fc.qconfig = None def forward(self, x): x = x.squeeze(2) x = self.quant(x) x = self.conv3(x) x = self.bn3(x) x = self.relu4(x) x = x.unsqueeze(2) y = x.unsqueeze(2) x = self.conv1(x) x = self.bn1(x) x = self.relu1(x) x = self.sub1(x) x = self.dequant(x) x = self.sub2(x) x = x.reshape(-1, 36).contiguous() x = self.fc(x) y = self.conv2(y) y = self.relu2(y) y = self.bn2(y) y = self.relu3(y) y = self.dequant(y) return x class ConvBNReLU(nn.Sequential): def __init__(self) -> None: super().__init__( nn.Conv2d(3, 3, 1, 1, bias=False), nn.BatchNorm2d(3), nn.ReLU(inplace=False) ) class ModelWithSequentialFusion(nn.Module): def __init__(self) -> None: super().__init__() self.conv1 = nn.Conv2d(3, 3, 1) self.relu1 = nn.ReLU(inplace=False) layers = [ConvBNReLU() for _ in range(3)] self.features = nn.Sequential(*layers) head = [nn.Linear(300, 10), nn.ReLU(inplace=False)] self.classifier = nn.Sequential(*head) self.seq = nn.Sequential() self.quant = QuantStub() self.dequant = DeQuantStub() def forward(self, x): x = self.quant(x) x = self.conv1(x) x = self.relu1(x) x = self.features(x) x = torch.reshape(x, (-1, 3 * 10 * 10)) x = self.classifier(x) x = self.seq(x) x = self.dequant(x) return x class ModelForFusionWithBias(nn.Module): def __init__(self) -> None: super().__init__() self.conv1 = nn.Conv2d(3, 2, 5, bias=True).to(dtype=torch.float) self.bn1 = nn.BatchNorm2d(2).to(dtype=torch.float) self.relu1 = nn.ReLU(inplace=True).to(dtype=torch.float) self.conv2 = nn.Conv2d(2, 2, 1, bias=True).to(dtype=torch.float) self.bn2 = 
nn.BatchNorm2d(2).to(dtype=torch.float) self.quant = QuantStub() self.dequant = DeQuantStub() def forward(self, x): x = self.quant(x) x = self.conv1(x) x = self.bn1(x) x = self.relu1(x) x = self.conv2(x) x = self.bn2(x) x = self.dequant(x) return x class ModelForLinearBNFusion(nn.Module): def __init__(self) -> None: super().__init__() self.fc = nn.Linear(20, 10) self.bn = nn.BatchNorm1d(10) nn.init.uniform_(self.bn.weight) nn.init.uniform_(self.bn.bias) def forward(self, x): return self.bn(self.fc(x)) class DummyObserver(torch.nn.Module): def calculate_qparams(self): return 1.0, 0 def forward(self, x): return x class ModelForConvTransposeBNFusion(nn.Module): def __init__(self) -> None: super().__init__() self.conv1 = nn.ConvTranspose1d(3, 3, 1) self.bn1 = nn.BatchNorm1d(3) self.conv2 = nn.ConvTranspose2d(3, 3, 1) self.bn2 = nn.BatchNorm2d(3) self.conv3 = nn.ConvTranspose3d(3, 3, 1) self.bn3 = nn.BatchNorm3d(3) def forward(self, x): x = self.conv1(x) x = self.bn1(x) x = x.unsqueeze(2) x = self.conv2(x) x = self.bn2(x) x = x.unsqueeze(2) x = self.conv3(x) x = self.bn3(x) return x class ModelWithFunctionals(torch.nn.Module): def __init__(self) -> None: super().__init__() self.mycat = nnq.FloatFunctional() self.myadd = nnq.FloatFunctional() self.myadd_relu = nnq.FloatFunctional() self.mymatmul = nnq.FloatFunctional() # Tracing doesn't work yet for c10 ops with scalar inputs # https://github.com/pytorch/pytorch/issues/27097 # self.my_scalar_add = nnq.FloatFunctional() # self.my_scalar_mul = nnq.FloatFunctional() def forward(self, x): y = self.mycat.cat([x, x, x]) z = self.myadd.add(y, y) w = self.myadd_relu.add_relu(z, z) u = self.mymatmul.matmul(w, w.T) # Tracing doesn't work yet for c10 ops with scalar inputs # https://github.com/pytorch/pytorch/issues/27097 # w = self.my_scalar_add.add_scalar(w, -0.5) # w = self.my_scalar_mul.mul_scalar(w, 0.5) return u class ResNetBase(torch.nn.Module): def __init__(self) -> None: super().__init__() norm_layer = nn.BatchNorm2d inplanes = 3 self.conv1 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False) self.bn1 = norm_layer(inplanes) self.relu1 = nn.ReLU() self.relu2 = nn.ReLU() self.downsample = torch.nn.Identity() self.myop = nn.quantized.FloatFunctional() self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.fc = torch.nn.Linear(inplanes, 1) def forward(self, x): out = self.conv1(x) out = self.bn1(out) out = self.relu1(out) identity = self.downsample(x) out = self.myop.add(out, identity) out = self.relu2(out) out = self.avgpool(out) out = torch.flatten(out, 1) out = self.fc(out) return out def fuse_model(self): # TODO: remove this check and define two fuse_model function on this module if self.training: torch.ao.quantization.fuse_modules_qat( self, [["conv1", "bn1", "relu1"]], inplace=True ) else: torch.ao.quantization.fuse_modules( self, [["conv1", "bn1", "relu1"]], inplace=True ) class ModelMultipleOps(torch.nn.Module): def __init__(self) -> None: super().__init__() norm_layer = nn.BatchNorm2d inplanes = 3 self.conv1 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False) self.conv2 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False) self.bn1 = norm_layer(inplanes) self.relu1 = nn.ReLU() self.relu2 = nn.ReLU() self.downsample = torch.nn.Identity() self.skip_add = nn.quantized.FloatFunctional() self.cat = nn.quantized.FloatFunctional() self.avgpool = nn.AdaptiveAvgPool2d((4, 4)) self.fc = nn.Linear(12, 6) def forward(self, x): out = self.conv1(x) out = self.bn1(out) out = self.relu1(out) identity = self.downsample(x) out = self.skip_add.add(out, identity) out = 
self.relu2(out) out = self.avgpool(out) out = self.conv2(out) out = torch.nn.functional.max_pool2d(out, 2, 2) out = self.cat.cat([out, out]) out = out.reshape(-1, 3 * 2 * 2) out = self.fc(out) return out # Model to ensure consistency of fake quant with true quant # Average pooling and mean operations are not modelled # accurately with fake-quant so this model does not # contain those operations class ModelMultipleOpsNoAvgPool(torch.nn.Module): def __init__(self) -> None: super().__init__() norm_layer = nn.BatchNorm2d inplanes = 3 self.conv1 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False) self.conv2 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False) self.bn1 = norm_layer(inplanes) self.relu1 = nn.ReLU() self.relu2 = nn.ReLU() self.skip_add = nn.quantized.FloatFunctional() self.cat = nn.quantized.FloatFunctional() self.maxpool = nn.MaxPool2d((4, 4)) self.fc = nn.Linear(12, 6) def forward(self, x): out = self.conv1(x) out = self.bn1(out) out = self.relu1(out) skip = self.conv2(x) out = self.skip_add.add(out, skip) out = self.relu2(out) out = self.maxpool(out) out = self.conv2(out) out = torch.nn.functional.max_pool2d(out, 2, 2) out = self.cat.cat([out, out]) out = out.reshape(-1, 3 * 2 * 2) out = self.fc(out) return out class EmbeddingBagModule(torch.nn.Module): def __init__(self) -> None: super().__init__() self.emb = torch.nn.EmbeddingBag( num_embeddings=10, embedding_dim=12, include_last_offset=True, scale_grad_by_freq=False, mode="sum", ) def forward(self, indices, offsets, per_sample_weights): return self.emb(indices, offsets, per_sample_weights) class EmbeddingModule(torch.nn.Module): def __init__(self) -> None: super().__init__() self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12) def forward(self, indices): return self.emb(indices) class EmbeddingWithStaticLinear(torch.nn.Module): def __init__(self) -> None: super().__init__() self.emb = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12) self.fc = torch.nn.Linear(4, 2) self.emb.qconfig = float_qparams_weight_only_qconfig self.qconfig = default_qconfig self.quant = QuantStub() self.dequant = DeQuantStub() def forward(self, indices, offsets, linear_in): emb = self.emb(indices, offsets) q_x = self.quant(linear_in) fc = self.fc(q_x) fc = self.dequant(fc) features = torch.cat([fc] + [emb], dim=1) return features class DenseTopMLP(nn.Module): def __init__( self, dense_dim, dense_out, embedding_dim, top_out_in, top_out_out ) -> None: super().__init__() self.dense_mlp = nn.Sequential( nn.Linear(dense_dim, dense_out), ) self.top_mlp = nn.Sequential( nn.Linear(dense_out + embedding_dim, top_out_in), nn.Linear(top_out_in, top_out_out), ) def forward( self, sparse_feature: torch.Tensor, dense: torch.Tensor, ) -> torch.Tensor: dense_feature = self.dense_mlp(dense) features = torch.cat([dense_feature] + [sparse_feature], dim=1) out = self.top_mlp(features) return out # thin wrapper around embedding bag, because tracing inside nn.Embedding # bag is not supported at the moment and this is top level class EmbBagWrapper(nn.Module): def __init__(self, num_embeddings, embedding_dim): super().__init__() self.emb_bag = nn.EmbeddingBag(num_embeddings, embedding_dim, mode="sum") def forward(self, indices, offsets): return self.emb_bag(indices, offsets) class SparseNNModel(nn.Module): _NUM_EMBEDDINGS = 10 _EMBEDDING_DIM = 5 _DENSE_DIM = 4 _DENSE_OUTPUT = 2 _TOP_OUT_IN = 2 _TOP_OUT_OUT = 2 _TOP_MLP_DIM = 1 def __init__(self) -> None: super().__init__() self.model_sparse = EmbBagWrapper(self._NUM_EMBEDDINGS, self._EMBEDDING_DIM) 
self.dense_top = DenseTopMLP( self._DENSE_DIM, self._DENSE_OUTPUT, self._EMBEDDING_DIM, self._TOP_OUT_IN, self._TOP_OUT_OUT, ) def forward( self, sparse_indices: torch.Tensor, sparse_offsets: torch.Tensor, dense: torch.Tensor, ) -> torch.Tensor: sparse_feature = self.model_sparse(sparse_indices, sparse_offsets) out = self.dense_top(sparse_feature, dense) return out class TestHelperModules: class ControlFlow(torch.nn.Module): def forward( self, xs: torch.Tensor, pred1: torch.Tensor, pred2: torch.Tensor, y: torch.Tensor, ) -> torch.Tensor: def true_nested(y: torch.Tensor) -> torch.Tensor: y = y + y y = torch.mm(y, y) return y def false_nested(y: torch.Tensor) -> torch.Tensor: return torch.mm(y, y) def true_fn(x: torch.Tensor, pred2: torch.Tensor) -> torch.Tensor: z = control_flow.cond(pred2, true_nested, false_nested, [x]) return x + z def false_fn(x: torch.Tensor, _) -> torch.Tensor: return x.cos() def map_fn( x: torch.Tensor, pred1: torch.Tensor, pred2: torch.Tensor, y: torch.Tensor, ) -> torch.Tensor: x = x.cos() y = control_flow.cond(pred1, true_fn, false_fn, [y, pred2]) x = x + y return x.sin() y = torch.mm(y, y) return control_flow.map(map_fn, xs, pred1, pred2, y) def example_inputs(self): return ( torch.ones(2, 2), torch.tensor([False]), torch.tensor([False]), torch.ones(2, 2), ) class Conv2dPropAnnotaton(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv = torch.nn.Conv2d(3, 3, 3) self.linear = torch.nn.Linear(3, 3) def forward(self, x): x = self.conv(x) x = x.view(-1, 3) x = torch.nn.functional.hardtanh(x, -0.5, 0.5) x = self.linear(x) return x class Conv2dWithObsSharingOps(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv = torch.nn.Conv2d(3, 3, 3) self.hardtanh = torch.nn.Hardtanh() self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1)) def forward(self, x): x = self.conv(x) x = self.adaptive_avg_pool2d(x) x = self.hardtanh(x) x = torch.mean(x) return x class Conv2dWithTwoLinearPermute(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv = torch.nn.Conv2d(3, 16, 3) self.linear1 = torch.nn.Linear(16, 8, bias=False) self.linear2 = torch.nn.Linear(8, 8) def forward(self, x): conv_out = self.conv(x) permute_out = torch.permute(conv_out, (0, 2, 3, 1)) return self.linear2(self.linear1(permute_out)) class Conv2dWithTwoLinear(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv = torch.nn.Conv2d(3, 16, 3) self.linear1 = torch.nn.Linear(64, 8, bias=False) self.linear2 = torch.nn.Linear(8, 8) def forward(self, x): conv_out = self.conv(x) reshape_out = torch.reshape(conv_out, (2, 64)) return self.linear2(self.linear1(reshape_out)) class ConvLinearWPermute(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv = torch.nn.Conv2d(3, 8, 3) self.linear1 = torch.nn.Linear(8, 8) def forward(self, x): conv_out = self.conv(x) permute_out = torch.permute(conv_out, (0, 2, 3, 1)) return self.linear1(permute_out) class TwoLinearModule(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear1 = torch.nn.Linear(8, 16, bias=False) self.linear2 = torch.nn.Linear(16, 8) def forward(self, x): return self.linear2(self.linear1(x)) def example_inputs(self): return (torch.randn(2, 8),) class ConvMaxPool2d(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv = torch.nn.Conv2d(2, 2, 1) self.pool = torch.nn.MaxPool2d(1, 1) def forward(self, x): x = self.conv(x) x = self.pool(x) return x class ConvWithAdaptiveAvgPool2d(torch.nn.Module): def __init__(self) -> 
None: super().__init__() self.conv = torch.nn.Conv2d(3, 3, 3) self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1)) def forward(self, x): x = self.conv(x) x = self.adaptive_avg_pool2d(x) return x class ConvWithBNRelu(torch.nn.Module): def __init__(self, relu, dim=2, bn=True, bias=True, padding=0): super().__init__() convs = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d} bns = {1: torch.nn.BatchNorm1d, 2: torch.nn.BatchNorm2d, 3: torch.nn.BatchNorm3d} self.conv = convs[dim](3, 3, 3, bias=bias, padding=padding) if bn: self.bn = bns[dim](3) else: self.bn = torch.nn.Identity() if relu: self.relu = torch.nn.ReLU() else: self.relu = torch.nn.Identity() def forward(self, x): x = self.conv(x) x = self.bn(x) return self.relu(x) class ConvTWithBNRelu(torch.nn.Module): def __init__(self, relu, dim=2, bn=True, bias=True): super().__init__() convts = {1: torch.nn.ConvTranspose1d, 2: torch.nn.ConvTranspose2d} bns = {1: torch.nn.BatchNorm1d, 2: torch.nn.BatchNorm2d} self.convt = convts[dim](3, 3, 3, bias=bias) if bn: self.bn = bns[dim](3) else: self.bn = torch.nn.Identity() if relu: self.relu = torch.nn.ReLU() else: self.relu = torch.nn.Identity() def forward(self, x): x = self.convt(x) x = self.bn(x) return self.relu(x) class Conv2dThenConv1d(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv1d = torch.nn.Conv1d(3, 3, 3) self.conv2d = torch.nn.Conv2d(3, 3, 3) def forward(self, x): x = self.conv2d(x) x = x.squeeze(0) x = self.conv1d(x) return x def example_inputs(self): return (torch.randn(1, 3, 5, 5),) class Conv2dWithCat(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv1 = torch.nn.Conv2d(3, 3, 3) self.conv2 = torch.nn.Conv2d(3, 3, 3) def forward(self, x, y): x = self.conv1(x) y = self.conv2(y) z = torch.cat([x, y], dim=1) return z class Conv2dWithTwoCat(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv1 = torch.nn.Conv2d(3, 3, 3) self.conv2 = torch.nn.Conv2d(3, 3, 3) def forward(self, x1, x2, x3, x4): x1 = self.conv1(x1) x2 = self.conv2(x2) y = torch.cat([x1, x2], dim=1) z = x3 + x4 w = torch.cat([z, y]) return w class Conv2dWithSplit(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv1 = torch.nn.Conv2d(3, 3, 3) self.conv2 = torch.nn.Conv2d(3, 3, 3) def forward(self, x): x = self.conv1(x) # use split so we get a list of Tensors x1, x2 = torch.split(x, 2, dim=1) y = torch.cat([x1, x2], dim=1) return y def example_inputs(self): return (torch.randn(1, 3, 16, 16),) class ThreeAdd(torch.nn.Module): def forward(self, x1, x2, x3, x4): y = x1 + x2 z = x3 + x4 w = y + z return w class EmbeddingModule(torch.nn.Module): def __init__(self) -> None: super().__init__() self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12) def forward(self, indices): return self.emb(indices) class EmbeddingConvLinearModule(torch.nn.Module): def __init__(self) -> None: super().__init__() self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=8) self.conv = torch.nn.Conv2d(8, 16, (1, 3)) self.linear = torch.nn.Linear(16, 8) def forward(self, indices): embeddings = self.emb(indices) embeddings = torch.unsqueeze(embeddings, dim=0) embeddings = torch.permute(embeddings, (0, 3, 1, 2)) conv_out = self.conv(embeddings) conv_out = torch.permute(conv_out, (0, 2, 3, 1)) conv_out = torch.squeeze(conv_out, dim=0) return self.linear(conv_out) class AddInplaceAdd(torch.nn.Module): def forward(self, x, y): x = x + y x += y return x class MulInplaceMul(torch.nn.Module): def forward(self, x, y): x = x * y x *= y 
            return x

    class AddMulScalar(torch.nn.Module):
        def forward(self, x):
            x = x + 3
            x = x * 3
            x += 3
            x *= 3
            return x

    class ConvBnReLU2dAndLinearReLU(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv_bn_relu = TestHelperModules.ConvWithBNRelu(relu=True)
            self.linear = torch.nn.Linear(3, 8, bias=False)
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            x = self.conv_bn_relu(x)
            permute_out = torch.permute(x, (0, 2, 3, 1))
            linear_out = self.linear(permute_out)
            return linear_out

    class GroupwiseConv2d(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv = torch.nn.Conv2d(4, 4, 3, groups=2)

        def forward(self, x):
            return self.conv(x)

        def example_inputs(self):
            return (torch.randn(2, 4, 10, 10),)

    class LinearReluModel(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float)
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            x = self.relu(self.fc(x))
            return x


def _generate_qdq_quantized_model(
    mod, inputs, is_qat=False, is_dynamic=False, quantizer=None
):
    def get_default_quantizer(is_qat, is_dynamic, inputs):
        has_xpu = any(
            isinstance(input, torch.Tensor) and input.device.type == "xpu"
            for input in inputs
        )
        if has_xpu:
            quantizer = XPUInductorQuantizer()
            assert (not is_qat) and (
                not is_dynamic
            ), "QAT and dynamic quantization are not currently supported on the XPU backend"
            quantizer.set_global(xpuiq.get_default_xpu_inductor_quantization_config())
        else:
            quantizer = X86InductorQuantizer()
            quantizer.set_global(
                xiq.get_default_x86_inductor_quantization_config(
                    is_qat=is_qat, is_dynamic=is_dynamic
                )
            )
        return quantizer

    maybe_no_grad = contextlib.nullcontext() if is_qat else torch.no_grad()
    with maybe_no_grad:
        export_model = export_for_training(mod, inputs, strict=True).module()
        quantizer = (
            quantizer if quantizer else get_default_quantizer(is_qat, is_dynamic, inputs)
        )
        prepare_model = (
            prepare_qat_pt2e(export_model, quantizer)
            if is_qat
            else prepare_pt2e(export_model, quantizer)
        )
        prepare_model(*inputs)
        torch.ao.quantization.move_exported_model_to_eval(prepare_model)
        convert_model = convert_pt2e(prepare_model)
        return convert_model
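

# Illustrative usage sketch (not invoked by any test in this file): it shows how
# `_generate_qdq_quantized_model` above can be driven with one of the helper
# modules to produce a Q/DQ-annotated model through the PT2E flow. The function
# name `_example_qdq_usage`, the choice of helper module, and the input shape are
# assumptions made purely for illustration; on CPU tensors the helper falls back
# to the default X86InductorQuantizer configuration.
def _example_qdq_usage():
    # A small conv-bn-relu helper module defined earlier in this file.
    model = TestHelperModules.ConvWithBNRelu(relu=True).eval()
    example_inputs = (torch.randn(1, 3, 8, 8),)

    # Post-training static quantization: the helper exports the module, prepares
    # it with the selected quantizer, runs one calibration forward pass, and
    # converts it to the Q/DQ representation.
    converted = _generate_qdq_quantized_model(model, example_inputs)

    # The converted model is an ordinary callable GraphModule.
    return converted(*example_inputs)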