[AO] Fix observed LSTM layer setup in individually observed LSTM (#101299)

Summary: We found that `_get_lstm_with_individually_observed_parts()` was missing a setup step that initializes the LSTM layer state, i.e. the weights and biases of each layer. This diff fixes the numerical discrepancy the CTRL team observed when using the above API.
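
To illustrate the failure mode, here is a minimal sketch (not part of the diff; the input shapes, the qconfig choice, and the private helper's signature are assumptions and may differ across PyTorch versions):

```python
import torch
from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization.fx.lstm_utils import (
    _get_lstm_with_individually_observed_parts,
)

# Float LSTM whose weights/biases the observed module must inherit.
float_lstm = torch.nn.LSTM(input_size=8, hidden_size=8, num_layers=2)
float_lstm.qconfig = get_default_qconfig("fbgemm")

# Assumed example inputs: (input, (h0, c0)) for a batch of 1.
example_inputs = (
    torch.rand(5, 1, 8),
    (torch.zeros(2, 1, 8), torch.zeros(2, 1, 8)),
)

observed = _get_lstm_with_individually_observed_parts(float_lstm, example_inputs)

# Before this fix, observed.layers kept their random initialization, so
# eval-mode outputs diverged from float_lstm; after the fix each layer is
# rebuilt from the float module's weights via _LSTMLayer.from_float.
```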

Test Plan: N3358643

Differential Revision: D45821681

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101299
Approved by: https://github.com/andrewor14
Author: Nitin Jain
Date: 2023-05-18 19:15:01 +00:00
Committed by: PyTorch MergeBot
Parent: 2fa1b563da
Commit: 556bb691fd
2 changed files with 7 additions and 0 deletions

torch/ao/nn/quantizable/modules/rnn.py

@@ -392,6 +392,7 @@ class LSTM(torch.nn.Module):
         for idx in range(other.num_layers):
             observed.layers[idx] = _LSTMLayer.from_float(other, idx, qconfig,
                                                          batch_first=False)
+        # TODO: Remove setting observed to eval to enable QAT.
         observed.eval()
         observed = torch.ao.quantization.prepare(observed, inplace=True)
         return observed
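
The hunk above is the module-level `from_float`; the fix below performs the analogous per-layer setup inside `_get_lstm_with_individually_observed_parts`. A short sketch of invoking the module-level path directly (the float module construction here is illustrative, not from the diff):

```python
import torch
from torch.ao.nn.quantizable import LSTM as QuantizableLSTM
from torch.ao.quantization import get_default_qconfig

float_lstm = torch.nn.LSTM(input_size=4, hidden_size=4, num_layers=1)
float_lstm.qconfig = get_default_qconfig("fbgemm")

# from_float copies each layer's weights/biases via _LSTMLayer.from_float,
# then (per the TODO above) forces eval mode and inserts observers.
observed = QuantizableLSTM.from_float(float_lstm)
assert not observed.training  # eval() was applied inside from_float
```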

torch/ao/quantization/fx/lstm_utils.py

@@ -72,6 +72,12 @@ def _get_lstm_with_individually_observed_parts(
         float_lstm.batch_first, float_lstm.dropout, float_lstm.bidirectional)
     quantizable_lstm.qconfig = float_lstm.qconfig
 
+    for idx in range(float_lstm.num_layers):
+        quantizable_lstm.layers[idx] = torch.ao.nn.quantizable.modules.rnn._LSTMLayer.from_float(float_lstm,
+                                                                                                 idx,
+                                                                                                 float_lstm.qconfig,
+                                                                                                 batch_first=False)
+
     # Build QConfigMapping for the LSTM cell
     # Note: FloatFunctional qconfigs will be configured separately below
     cell_qm = QConfigMapping().set_global(float_lstm.qconfig)  # type: ignore[arg-type]
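
The `QConfigMapping` built here is then used to prepare each cell with FX. The same pattern in isolation looks like this (the stand-in module and example inputs are assumptions, not from the diff):

```python
import torch
from torch.ao.quantization import QConfigMapping, get_default_qconfig
from torch.ao.quantization.quantize_fx import prepare_fx

qconfig = get_default_qconfig("fbgemm")
# Same pattern as cell_qm above: one global qconfig for the whole module.
qm = QConfigMapping().set_global(qconfig)

cell = torch.nn.Linear(8, 8).eval()  # stand-in for the LSTM cell submodule
example_inputs = (torch.rand(1, 8),)
observed_cell = prepare_fx(cell, qm, example_inputs)
```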