Some minor fix to unblock the Bert model quantization (#20787)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/20787

Set requires_grad=False for the bias: leaving it as a gradient-requiring Parameter blocks JIT tracing.

The type_as fix: the input tensor shape and the output tensor shape can differ, which triggers the assertion failure at https://fburl.com/0m8xy7tc.

Reviewed By: jamesr66a
Differential Revision: D15445092
fbshipit-source-id: 22da41a56ecb9ac092585d0cc1ff0658fb9d631b
Committed by: Facebook Github Bot
Parent: a501e7d5be
Commit: 70ecddfd76
@@ -20,7 +20,7 @@ class QuantizedLinear(torch.jit.ScriptModule):
         self.weight = torch.nn.Parameter(self.weight, requires_grad=False)
         self.col_offsets = torch.nn.Parameter(self.col_offsets, requires_grad=False)
         assert other.bias is not None, 'QuantizedLinear requires a bias'
-        self.bias = torch.nn.Parameter(other.bias.clone().float())
+        self.bias = torch.nn.Parameter(other.bias.clone().float(), requires_grad=False)

         self.register_buffer(
             'packed_tensor_ptr',
@@ -42,7 +42,7 @@ class QuantizedLinear(torch.jit.ScriptModule):
         out = torch.fbgemm_linear_int8_weight(
             input.float(), self.weight, self.packed_tensor_ptr, self.col_offsets,
             self.scale, self.zero_point, self.bias)
-        return out.type_as(input)
+        return out.float()

     def extra_repr(self):
         repr = 'in_features={in_features}, out_features={out_features}, ' \
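
For context, a minimal sketch (not part of the commit) of the workflow this fix unblocks, assuming a PyTorch build of that era with FBGEMM enabled: a float nn.Linear is wrapped in the QuantizedLinear defined in torch/jit/quantized.py and the containing module is traced. Constructing QuantizedLinear directly from the float layer is inferred from the diff (it reads other.bias); the TinyBertHead wrapper and tensor sizes are illustrative, not from the commit.

import torch
import torch.nn as nn
from torch.jit.quantized import QuantizedLinear  # defined in torch/jit/quantized.py

class TinyBertHead(nn.Module):  # hypothetical stand-in for a BERT sub-module
    def __init__(self):
        super().__init__()
        # Replace the float layer with its int8, FBGEMM-backed wrapper.
        self.proj = QuantizedLinear(nn.Linear(768, 768))

    def forward(self, x):
        return self.proj(x)

model = TinyBertHead()
example = torch.randn(2, 768)

# Before this commit, tracing could break here: the bias was a Parameter with
# requires_grad=True, and the forward's out.type_as(input) could hit the shape
# assertion when the input and output shapes differ. With the bias registered
# as requires_grad=False and the return changed to out.float(), the trace goes
# through and the output stays float32.
traced = torch.jit.trace(model, example)
print(traced(example).shape)  # torch.Size([2, 768])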