Revert "Add option to split Linear gates for Quantizable LSTM into separate ops (#140868)"

This reverts commit 3fcf66f61fbc8f760fc0d34356a60b76c3f2e27c.

Reverted https://github.com/pytorch/pytorch/pull/140868 on behalf of https://github.com/huydhn due to Sorry for reverting your change but I think lint is failing on this in trunk ([comment](https://github.com/pytorch/pytorch/pull/140868#issuecomment-2494076202))
Commit: cf1d95a965 (parent 080f992d68)
Author: PyTorch MergeBot
Date: 2024-11-22 15:54:05 +00:00

4 changed files with 48 additions and 183 deletions


@@ -2917,11 +2917,6 @@ class TestQuantizedOps(TestCase):
     @override_qengines
     def test_custom_module_lstm(self):
-        class QuantizableLSTMSplitGates(torch.ao.nn.quantizable.LSTM):
-            @classmethod
-            def from_float(cls, other, qconfig=None):
-                return super().from_float(other, qconfig, split_gates=True)
-
         qengine = torch.backends.quantized.engine
 
         batch_size = 4
@@ -2936,7 +2931,6 @@ class TestQuantizedOps(TestCase):
         Bias = [False, True]
         Batch_first = [False, True]
         Bidirectional = [False, True]
-        Split_gates = [False, True]
 
         dtype = np.uint8
         qtype = torch.quint8
@@ -2949,8 +2943,8 @@ class TestQuantizedOps(TestCase):
         x = qx.dequantize()
 
         with torch.no_grad():
-            for bias, batch_first, bidirectional, split_gates in itertools.product(
-                    Bias, Batch_first, Bidirectional, Split_gates):
+            for bias, batch_first, bidirectional in itertools.product(
+                    Bias, Batch_first, Bidirectional):
                 # Assume 12dB is sufficient for functional equivalence
                 # Without the bias, linear performs poorly
                 min_power = 10 if bias else 5
@@ -2974,36 +2968,17 @@ class TestQuantizedOps(TestCase):
                 # Prepare
                 lstm.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
-                custom_config_dict = (
-                    None
-                    if not split_gates
-                    else {  # switch to class with split_gates True via from_float
-                        "float_to_observed_custom_module_class": {
-                            torch.nn.LSTM: QuantizableLSTMSplitGates
-                        },
-                        "observed_to_quantized_custom_module_class": {
-                            QuantizableLSTMSplitGates: torch.ao.nn.quantized.LSTM,
-                        },
-                    }
-                )
-                lstm_prepared = torch.ao.quantization.prepare(
-                    lstm, prepare_custom_config_dict=custom_config_dict
-                )
+                lstm_prepared = torch.ao.quantization.prepare(lstm)
                 self.assertTrue(hasattr(lstm_prepared[0], 'layers'))
                 self.assertEqual(num_layers, len(lstm_prepared[0].layers))
-                self.assertEqual(
-                    lstm_prepared[0].layers[0].layer_fw.cell.split_gates, split_gates
-                )
-                assert isinstance(lstm_prepared[0], torch.ao.nn.quantizable.LSTM)
+                assert type(lstm_prepared[0]) == torch.ao.nn.quantizable.LSTM
 
                 # Calibrate
                 y = lstm_prepared(x)
                 self.assertEqual(y_ref, y)
 
                 # Quantize
-                lstm_quantized = torch.ao.quantization.convert(
-                    lstm_prepared, convert_custom_config_dict=custom_config_dict
-                )
+                lstm_quantized = torch.ao.quantization.convert(lstm_prepared)
                 assert type(lstm_quantized[0]) == torch.ao.nn.quantized.LSTM
                 qy = lstm_quantized(qx)
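
For context, the removed test drove eager-mode quantization through its custom-module hooks (prepare_custom_config_dict / convert_custom_config_dict) to substitute a split-gates LSTM subclass. A minimal sketch of the baseline flow the reverted test falls back to, assuming the stock nn.LSTM -> torch.ao.nn.quantizable.LSTM -> torch.ao.nn.quantized.LSTM default mappings; shapes and model wrapping are illustrative:

    import torch
    import torch.ao.quantization

    # Float model wrapping nn.LSTM, mirroring the test setup (shapes illustrative).
    model = torch.nn.Sequential(
        torch.nn.LSTM(input_size=12, hidden_size=8, num_layers=2)
    )
    model.eval()
    model.qconfig = torch.ao.quantization.get_default_qconfig(
        torch.backends.quantized.engine
    )

    # Prepare: the default custom-module mapping swaps nn.LSTM for the
    # observed torch.ao.nn.quantizable.LSTM.
    prepared = torch.ao.quantization.prepare(model)
    assert type(prepared[0]) == torch.ao.nn.quantizable.LSTM

    # Calibrate with representative float data: (seq_len, batch, input_size).
    with torch.no_grad():
        prepared(torch.randn(8, 4, 12))

    # Convert: the observed module becomes torch.ao.nn.quantized.LSTM.
    quantized = torch.ao.quantization.convert(prepared)
    assert type(quantized[0]) == torch.ao.nn.quantized.LSTM

The reverted change layered split_gates on top of this flow by routing nn.LSTM through a QuantizableLSTMSplitGates subclass via the two custom-module dicts, as seen in the removed lines above.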