diff --git a/torch/jit/quantized.py b/torch/jit/quantized.py
index 8236a9ca7e4b..8b71ef85d678 100644
--- a/torch/jit/quantized.py
+++ b/torch/jit/quantized.py
@@ -20,7 +20,7 @@ class QuantizedLinear(torch.jit.ScriptModule):
         self.weight = torch.nn.Parameter(self.weight, requires_grad=False)
         self.col_offsets = torch.nn.Parameter(self.col_offsets, requires_grad=False)
         assert other.bias is not None, 'QuantizedLinear requires a bias'
-        self.bias = torch.nn.Parameter(other.bias.clone().float())
+        self.bias = torch.nn.Parameter(other.bias.clone().float(), requires_grad=False)
         self.register_buffer(
             'packed_tensor_ptr',
@@ -42,7 +42,7 @@ class QuantizedLinear(torch.jit.ScriptModule):
         out = torch.fbgemm_linear_int8_weight(
             input.float(), self.weight, self.packed_tensor_ptr,
             self.col_offsets, self.scale, self.zero_point, self.bias)
-        return out.type_as(input)
+        return out.to(input.dtype)
 
     def extra_repr(self):
         repr = 'in_features={in_features}, out_features={out_features}, ' \
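
For context, both hunks are behavioral rather than cosmetic. Passing `requires_grad=False` makes the bias a frozen parameter, matching how `weight` and `col_offsets` are already registered, so autograd will not track it. And `out.type_as(input)` casts to `input`'s full `torch.Type`, which encodes dtype *and* device class together, so it can trigger an implicit device transfer; `out.to(input.dtype)` converts only the dtype and leaves the tensor on its current device. A minimal sketch of the second point, using placeholder tensors (`inp`, `out` are illustrative names, not from the module):

```python
import torch

# e.g. fp16 activations coming into QuantizedLinear.forward
inp = torch.randn(4, 8, dtype=torch.half)
# fbgemm kernels emit float32, so the result must be cast back
out = torch.randn(4, 8)

# Old: type_as casts to inp's torch.Type (dtype and device class),
# which can silently move `out` between CPU and GPU.
old = out.type_as(inp)

# New: to(dtype) converts only the dtype; `out` stays on its device.
new = out.to(inp.dtype)

assert old.dtype == new.dtype == torch.half
```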