Add proper casting to fuse_linear_bn_weights (#134105)

As per the title, this PR adds proper dtype casting to fuse_linear_bn_weights, in the same style as the conv case above. The missing casts previously caused numerical issues on my end, which is why I am fixing them.

Also cleans up the docstring.
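
For context, here is a minimal sketch of the kind of mixed-dtype call that motivated the change (assuming a float16 Linear followed by a float32 BatchNorm1d; this is an illustration, not the exact repro from my setup):

```python
import torch
from torch.nn.utils.fusion import fuse_linear_bn_weights

# Assumed setup for illustration: Linear weights in float16, while the
# BatchNorm1d parameters and running statistics stay in float32.
linear = torch.nn.Linear(8, 4).half()
bn = torch.nn.BatchNorm1d(4).eval()

fused_w, fused_b = fuse_linear_bn_weights(
    linear.weight, linear.bias,
    bn.running_mean, bn.running_var, bn.eps,
    bn.weight, bn.bias,
)

# With the casts added here, the fused parameters keep the Linear dtype
# instead of being promoted to float32 by the BatchNorm statistics.
assert fused_w.dtype == linear.weight.dtype
assert fused_b.dtype == linear.bias.dtype
```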
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134105
Approved by: https://github.com/mikaylagawarecki
Author: Vladimir Monakhov
Date: 2024-08-22 14:26:12 +00:00
Committed by: PyTorch MergeBot
Parent: b459ca78eb
Commit: afc2615d33


@@ -172,17 +172,18 @@ def fuse_linear_bn_weights(
         bn_eps (float): BatchNorm epsilon.
         bn_w (torch.Tensor): BatchNorm weight.
         bn_b (torch.Tensor): BatchNorm bias.
-        transpose (bool, optional): If True, transpose the conv weight. Defaults to False.
     Returns:
         Tuple[torch.nn.Parameter, torch.nn.Parameter]: Fused linear weight and bias.
     """
+    linear_weight_dtype = linear_w.dtype
+    linear_bias_dtype = linear_b.dtype if linear_b is not None else linear_weight_dtype
     if linear_b is None:
         linear_b = torch.zeros_like(bn_rm)
     bn_scale = bn_w * torch.rsqrt(bn_rv + bn_eps)
-    fused_w = linear_w * bn_scale.unsqueeze(-1)
-    fused_b = (linear_b - bn_rm) * bn_scale + bn_b
+    fused_w = linear_w * bn_scale.unsqueeze(-1).to(dtype=linear_weight_dtype)
+    fused_b = ((linear_b - bn_rm) * bn_scale + bn_b).to(dtype=linear_bias_dtype)
     return torch.nn.Parameter(fused_w, linear_w.requires_grad), torch.nn.Parameter(
         fused_b, linear_b.requires_grad