Add proper casting to fuse_linear_bn_weights (#134105)
As per the title, this PR adds proper casting to fuse_linear_bn_weights, in the same style as the conv case above. The missing casts previously caused numerical issues on my end, which is why I am fixing it. This PR also cleans up the docstring.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134105
Approved by: https://github.com/mikaylagawarecki
Committed by: PyTorch MergeBot
Parent: b459ca78eb
Commit: afc2615d33
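To make the effect of the casts concrete, here is a minimal sketch, assuming the import path torch.nn.utils.fusion.fuse_linear_bn_weights (the module being patched). With this change, the fused parameters keep the linear layer's dtype instead of being promoted by the float32 BatchNorm tensors:

```python
import torch
from torch.nn.utils.fusion import fuse_linear_bn_weights

# A half-precision linear layer paired with a BatchNorm whose affine
# parameters and running statistics are float32 (the PyTorch default).
linear = torch.nn.Linear(4, 8).half()
bn = torch.nn.BatchNorm1d(8)

fused_w, fused_b = fuse_linear_bn_weights(
    linear.weight, linear.bias,
    bn.running_mean, bn.running_var, bn.eps,
    bn.weight, bn.bias,
)

# Without the explicit .to(dtype=...) casts, type promotion against the
# float32 BatchNorm tensors would silently yield float32 parameters.
assert fused_w.dtype == torch.float16
assert fused_b.dtype == torch.float16
```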
```diff
@@ -172,17 +172,18 @@ def fuse_linear_bn_weights(
         bn_eps (float): BatchNorm epsilon.
         bn_w (torch.Tensor): BatchNorm weight.
         bn_b (torch.Tensor): BatchNorm bias.
-        transpose (bool, optional): If True, transpose the conv weight. Defaults to False.
 
     Returns:
         Tuple[torch.nn.Parameter, torch.nn.Parameter]: Fused linear weight and bias.
     """
+    linear_weight_dtype = linear_w.dtype
+    linear_bias_dtype = linear_b.dtype if linear_b is not None else linear_weight_dtype
     if linear_b is None:
         linear_b = torch.zeros_like(bn_rm)
     bn_scale = bn_w * torch.rsqrt(bn_rv + bn_eps)
 
-    fused_w = linear_w * bn_scale.unsqueeze(-1)
-    fused_b = (linear_b - bn_rm) * bn_scale + bn_b
+    fused_w = linear_w * bn_scale.unsqueeze(-1).to(dtype=linear_weight_dtype)
+    fused_b = ((linear_b - bn_rm) * bn_scale + bn_b).to(dtype=linear_bias_dtype)
 
     return torch.nn.Parameter(fused_w, linear_w.requires_grad), torch.nn.Parameter(
         fused_b, linear_b.requires_grad
```
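For reference, the identity the fused expressions implement can be checked end to end. This standalone sketch (not part of the patch) folds a BatchNorm1d in eval mode into a Linear using the same formulas as the patched lines and verifies that the outputs match:

```python
import torch

torch.manual_seed(0)
linear = torch.nn.Linear(4, 8)
bn = torch.nn.BatchNorm1d(8).eval()

# Give the BatchNorm non-trivial statistics so the check is meaningful.
with torch.no_grad():
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 2.0)

# Fold the BatchNorm into the linear layer, mirroring the patched lines.
bn_scale = bn.weight * torch.rsqrt(bn.running_var + bn.eps)
fused_w = linear.weight * bn_scale.unsqueeze(-1)
fused_b = (linear.bias - bn.running_mean) * bn_scale + bn.bias

# bn(linear(x)) and the fused linear layer agree in eval mode.
x = torch.randn(3, 4)
expected = bn(linear(x))
actual = torch.nn.functional.linear(x, fused_w, fused_b)
torch.testing.assert_close(actual, expected)
```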