[AO] Update qLSTM implementation to remove unsupported backend ops (#96436)

Summary:
The reference quantized LSTM implementation uses `unbind` and in-place `squeeze_`, both of which are not supported when building BoltNN's Espresso IR graph.

This change adjusts the reference AO Quantizable LSTM implementation without affecting numerics, while enabling removal of the unsupported ops in BoltNN.

Modifications & Adjustments
1. `unbind` ops appear when unstacking a tensor in a loop. Replaced this by reading the first dimension from the tensor's shape and looping with a ranged index.
2. Removed `unbind` calls where the pattern `x = t.unbind(0)` followed by `x[i]` can simply be replaced by `t[i]`, since materializing a tuple from `unbind` is unnecessary.
3. In-place squeeze calls (`squeeze_`) that were not required have been replaced by `squeeze`. A minimal sketch of all three transformations is shown after this list.
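
To make the three patterns concrete, here is a minimal standalone sketch on plain tensors (tensor names and shapes are made up for illustration; the actual changes are in the diff hunks below):

```python
import torch

t = torch.randn(4, 2, 3)  # hypothetical stacked tensor; the first dim is the axis being unstacked

# 1. Unstacking in a loop: replace `for xx in t: ...` (which lowers to unbind)
#    with an index-based loop over the first dimension.
outputs = []
seq_len = t.shape[0]
for i in range(seq_len):
    outputs.append(t[i] * 2)          # same values as iterating `for xx in t`

# 2. `x = t.unbind(0)` followed by `x[i]` becomes a direct index `t[i]`;
#    materializing the intermediate tuple is unnecessary.
assert torch.equal(t.unbind(0)[1], t[1])

# 3. In-place `squeeze_` replaced by out-of-place `squeeze` where mutating
#    the original tensor is not required.
h = torch.randn(1, 2, 3)
h_out = h.squeeze(0)                  # instead of h.squeeze_(0)
assert h_out.shape == (2, 3)
```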

See notebook N3235193, which was used to test the quantization flow and to inspect the TorchScript-ed quantized model for the set of ops used (see the last cell).
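
For reference, the ops a scripted model uses can be listed from its TorchScript graph roughly like this (a minimal sketch with a made-up stand-in module, not the notebook's code):

```python
import torch

class TinyModule(torch.nn.Module):
    # Hypothetical stand-in for the quantized LSTM; any scriptable module works.
    def forward(self, x):
        return x.unbind(0)[0].squeeze(0)

scripted = torch.jit.script(TinyModule())

# Collect the distinct op kinds appearing in the TorchScript graph.
ops = {node.kind() for node in scripted.graph.nodes()}
print(sorted(ops))  # look for entries such as 'aten::unbind' or 'aten::squeeze_'
```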

Test Plan: N3235193

Reviewed By: andrewor14

Differential Revision: D43935389

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96436
Approved by: https://github.com/andrewor14
Author: Nitin Jain
Date: 2023-03-14 17:58:34 +00:00
Committed by: PyTorch MergeBot
Parent: 7ec0d6f006
Commit: 40df3b41aa


@@ -147,8 +147,9 @@ class _LSTMSingleLayer(torch.nn.Module):
     def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
         result = []
-        for xx in x:
-            hidden = self.cell(xx, hidden)
+        seq_len = x.shape[0]
+        for i in range(seq_len):
+            hidden = self.cell(x[i], hidden)
             result.append(hidden[0])  # type: ignore[index]
         result_tensor = torch.stack(result, 0)
         return result_tensor, hidden
@@ -350,11 +351,11 @@ class LSTM(torch.nn.Module):
             if isinstance(hidden_non_opt[0], Tensor):
                 hx = hidden_non_opt[0].reshape(self.num_layers, num_directions,
                                                max_batch_size,
-                                               self.hidden_size).unbind(0)
+                                               self.hidden_size)
                 cx = hidden_non_opt[1].reshape(self.num_layers, num_directions,
                                                max_batch_size,
-                                               self.hidden_size).unbind(0)
-                hxcx = [(hx[idx].squeeze_(0), cx[idx].squeeze_(0)) for idx in range(self.num_layers)]
+                                               self.hidden_size)
+                hxcx = [(hx[idx].squeeze(0), cx[idx].squeeze(0)) for idx in range(self.num_layers)]
             else:
                 hxcx = hidden_non_opt
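
As an informal sanity check that the hidden-state handling above is numerically unchanged, the old `unbind`/`squeeze_` path and the new indexing/`squeeze` path can be compared on a standalone tensor (sizes here are made up; this is not part of the actual change):

```python
import torch

num_layers, num_directions, max_batch_size, hidden_size = 2, 1, 3, 4
h = torch.randn(num_layers * num_directions, max_batch_size, hidden_size)
shape = (num_layers, num_directions, max_batch_size, hidden_size)

# Old path: reshape -> unbind(0) -> in-place squeeze_ per layer.
old = [h.reshape(shape).unbind(0)[idx].squeeze_(0) for idx in range(num_layers)]

# New path: reshape once, index per layer, out-of-place squeeze.
reshaped = h.reshape(shape)
new = [reshaped[idx].squeeze(0) for idx in range(num_layers)]

for a, b in zip(old, new):
    assert torch.equal(a, b)  # identical values and shapes, layer by layer
```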