[AO] Fix layer setup in individually observed LSTM (#101299)
Summary: We found that `_get_lstm_with_individually_observed_parts()` was missing the setup step that initializes the LSTM layer state, i.e. the weights and biases of each layer. This diff fixes the numerical discrepancy the CTRL team observed when using the above API.

Test Plan: N3358643

Differential Revision: D45821681

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101299
Approved by: https://github.com/andrewor14
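For context, a rough sketch of how this helper is typically invoked (the import path matches the diff below; the example-input shape and the positional `example_inputs` argument are assumptions, since the full signature is not shown in this commit):

    import torch
    from torch.ao.quantization import get_default_qconfig
    from torch.ao.quantization.fx.lstm_utils import (
        _get_lstm_with_individually_observed_parts,
    )

    # Float LSTM whose inner ops (gates, activations) get observed individually.
    float_lstm = torch.nn.LSTM(input_size=8, hidden_size=16, num_layers=2)
    float_lstm.qconfig = get_default_qconfig("fbgemm")  # read by the helper, per the diff

    # (seq_len, batch, input_size), since the layers are built with batch_first=False
    example_inputs = (torch.randn(5, 3, 8),)
    observed = _get_lstm_with_individually_observed_parts(float_lstm, example_inputs)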
Commit 556bb691fd (parent 2fa1b563da), committed by PyTorch MergeBot.
@@ -392,6 +392,7 @@ class LSTM(torch.nn.Module):
        for idx in range(other.num_layers):
            observed.layers[idx] = _LSTMLayer.from_float(other, idx, qconfig,
                                                         batch_first=False)
        # TODO: Remove setting observed to eval to enable QAT.
        observed.eval()
        observed = torch.ao.quantization.prepare(observed, inplace=True)
        return observed
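The `from_float` calls above are what copy each layer's weights and biases out of the float module. A minimal sketch of why a skipped setup step surfaces as a numerical discrepancy, using only the public `torch.ao.nn.quantizable.LSTM` module (freshly constructed here, rather than converted, to mimic the missing setup):

    import torch

    float_lstm = torch.nn.LSTM(input_size=4, hidden_size=4, num_layers=1)
    # A freshly constructed quantizable LSTM keeps its own random initialization;
    # without a from_float()-style setup, its outputs diverge from the float module's.
    fresh = torch.ao.nn.quantizable.LSTM(input_size=4, hidden_size=4, num_layers=1)

    x = torch.randn(2, 1, 4)  # (seq_len, batch, input_size)
    out_float, _ = float_lstm(x)
    out_fresh, _ = fresh(x)
    print(torch.allclose(out_float, out_fresh))  # False in practice: weights never copied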
@@ -72,6 +72,12 @@ def _get_lstm_with_individually_observed_parts(
        float_lstm.batch_first, float_lstm.dropout, float_lstm.bidirectional)
    quantizable_lstm.qconfig = float_lstm.qconfig

    for idx in range(float_lstm.num_layers):
        quantizable_lstm.layers[idx] = torch.ao.nn.quantizable.modules.rnn._LSTMLayer.from_float(float_lstm,
                                                                                                 idx,
                                                                                                 float_lstm.qconfig,
                                                                                                 batch_first=False)

    # Build QConfigMapping for the LSTM cell
    # Note: FloatFunctional qconfigs will be configured separately below
    cell_qm = QConfigMapping().set_global(float_lstm.qconfig)  # type: ignore[arg-type]
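The `QConfigMapping` built at the end of the hunk follows the standard FX quantization pattern: one global qconfig, with narrower rules layered on top. A small sketch of that pattern using known public APIs (the "fbgemm" backend and the `Sigmoid` override are assumptions for illustration, not from this commit):

    import torch
    from torch.ao.quantization import QConfigMapping, get_default_qconfig

    qconfig = get_default_qconfig("fbgemm")  # assumed backend, for illustration
    # set_global() applies one qconfig everywhere; more specific rules override it.
    qm = QConfigMapping().set_global(qconfig)
    qm.set_object_type(torch.nn.Sigmoid, qconfig)  # e.g. a per-module-type override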