Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Revert "Add option to split Linear gates for Quantizable LSTM into separate ops (#140868)"
This reverts commit 3fcf66f61fbc8f760fc0d34356a60b76c3f2e27c.

Reverted https://github.com/pytorch/pytorch/pull/140868 on behalf of https://github.com/huydhn due to Sorry for reverting your change but I think lint is failing on this in trunk ([comment](https://github.com/pytorch/pytorch/pull/140868#issuecomment-2494076202))
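For context, the reverted PR added an optional `split_gates` flag to the quantizable LSTM (routed through `from_float`, as the diff below shows) so that the gate projections run as separate Linear ops instead of one fused op. A minimal sketch of the idea, assuming nothing about the actual `torch.ao.nn.quantizable` internals; the `FusedGates`/`SplitGates` class names here are illustrative only:

```python
import torch
import torch.nn as nn

# A fused cell computes all four LSTM gates with one Linear of width
# 4 * hidden_size; the split variant uses one Linear per gate, so each
# projection can be observed and quantized as its own op.

class FusedGates(nn.Module):
    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.igates = nn.Linear(input_size, 4 * hidden_size)  # single fused op

    def forward(self, x):
        # Gates are separated only after the fused matmul.
        return self.igates(x).chunk(4, dim=-1)

class SplitGates(nn.Module):
    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        # Four independent ops, each with its own quantization parameters.
        self.gates = nn.ModuleList(
            nn.Linear(input_size, hidden_size) for _ in range(4)
        )

    def forward(self, x):
        return tuple(g(x) for g in self.gates)
```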
@@ -2917,11 +2917,6 @@ class TestQuantizedOps(TestCase):
 
     @override_qengines
     def test_custom_module_lstm(self):
-        class QuantizableLSTMSplitGates(torch.ao.nn.quantizable.LSTM):
-            @classmethod
-            def from_float(cls, other, qconfig=None):
-                return super().from_float(other, qconfig, split_gates=True)
-
         qengine = torch.backends.quantized.engine
 
         batch_size = 4
@@ -2936,7 +2931,6 @@ class TestQuantizedOps(TestCase):
         Bias = [False, True]
         Batch_first = [False, True]
         Bidirectional = [False, True]
-        Split_gates = [False, True]
 
         dtype = np.uint8
         qtype = torch.quint8
@@ -2949,8 +2943,8 @@ class TestQuantizedOps(TestCase):
         x = qx.dequantize()
 
         with torch.no_grad():
-            for bias, batch_first, bidirectional, split_gates in itertools.product(
-                    Bias, Batch_first, Bidirectional, Split_gates):
+            for bias, batch_first, bidirectional in itertools.product(
+                    Bias, Batch_first, Bidirectional):
                 # Assume 12dB is sufficient for functional equivalence
                 # Without the bias, linear performs poorly
                 min_power = 10 if bias else 5
@@ -2974,36 +2968,17 @@ class TestQuantizedOps(TestCase):
 
                 # Prepare
                 lstm.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
-                custom_config_dict = (
-                    None
-                    if not split_gates
-                    else {  # switch to class with split_gates True via from_float
-                        "float_to_observed_custom_module_class": {
-                            torch.nn.LSTM: QuantizableLSTMSplitGates
-                        },
-                        "observed_to_quantized_custom_module_class": {
-                            QuantizableLSTMSplitGates: torch.ao.nn.quantized.LSTM,
-                        },
-                    }
-                )
-                lstm_prepared = torch.ao.quantization.prepare(
-                    lstm, prepare_custom_config_dict=custom_config_dict
-                )
+                lstm_prepared = torch.ao.quantization.prepare(lstm)
                 self.assertTrue(hasattr(lstm_prepared[0], 'layers'))
                 self.assertEqual(num_layers, len(lstm_prepared[0].layers))
-                self.assertEqual(
-                    lstm_prepared[0].layers[0].layer_fw.cell.split_gates, split_gates
-                )
-                assert isinstance(lstm_prepared[0], torch.ao.nn.quantizable.LSTM)
+                assert type(lstm_prepared[0]) == torch.ao.nn.quantizable.LSTM
 
                 # Calibrate
                 y = lstm_prepared(x)
                 self.assertEqual(y_ref, y)
 
                 # Quantize
-                lstm_quantized = torch.ao.quantization.convert(
-                    lstm_prepared, convert_custom_config_dict=custom_config_dict
-                )
+                lstm_quantized = torch.ao.quantization.convert(lstm_prepared)
                 assert type(lstm_quantized[0]) == torch.ao.nn.quantized.LSTM
                 qy = lstm_quantized(qx)
 
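For reference, the test code removed above exercised eager-mode custom-module swapping through `prepare_custom_config_dict` and `convert_custom_config_dict`. Below is a minimal standalone sketch of that flow, assuming the `split_gates` keyword from the reverted PR is present (with this revert applied, `from_float` no longer accepts it) and a qengine with quantized LSTM support; the shapes and quantization parameters are arbitrary:

```python
import torch

# Observed subclass that opts into split gates during the float -> observed
# swap. The split_gates keyword comes from the reverted PR (#140868); with
# the revert applied, from_float does not accept it.
class QuantizableLSTMSplitGates(torch.ao.nn.quantizable.LSTM):
    @classmethod
    def from_float(cls, other, qconfig=None):
        return super().from_float(other, qconfig, split_gates=True)

lstm = torch.nn.Sequential(torch.nn.LSTM(16, 32, num_layers=2))
lstm.qconfig = torch.ao.quantization.get_default_qconfig(
    torch.backends.quantized.engine
)

custom_config_dict = {
    # Used by prepare(): float nn.LSTM -> observed subclass.
    "float_to_observed_custom_module_class": {
        torch.nn.LSTM: QuantizableLSTMSplitGates
    },
    # Used by convert(): observed subclass -> quantized LSTM.
    "observed_to_quantized_custom_module_class": {
        QuantizableLSTMSplitGates: torch.ao.nn.quantized.LSTM
    },
}

prepared = torch.ao.quantization.prepare(
    lstm, prepare_custom_config_dict=custom_config_dict
)
prepared(torch.randn(5, 3, 16))  # calibrate with a float sample batch

quantized = torch.ao.quantization.convert(
    prepared, convert_custom_config_dict=custom_config_dict
)
# The quantized module consumes quantized tensors, as in the test above.
qx = torch.quantize_per_tensor(torch.randn(5, 3, 16), 0.1, 128, torch.quint8)
qy = quantized(qx)
```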